diff --git a/src/ggml-cpu/ggml-cpu.c b/src/ggml-cpu/ggml-cpu.c
index e809f05d2..52f6210ae 100644
--- a/src/ggml-cpu/ggml-cpu.c
+++ b/src/ggml-cpu/ggml-cpu.c
@@ -4930,13 +4930,18 @@ static void ggml_compute_forward_acc_f32(
     bool   inplace = (bool) ((int32_t *) dst->op_params)[4];
 
     if (!inplace) {
-        if (params->ith == 0) {
+        size_t total  = ggml_nbytes(dst);
+        size_t stride = (total + params->nth - 1) / params->nth;
+        size_t start  = params->ith * stride;
+        if (total > start)
+        {
+            size_t rest  = total - start;
+            size_t bytes = rest < stride ? rest : stride;
+            char * dstp = ((char *) dst->data)  + start;
+            char * srcp = ((char *) src0->data) + start;
             // memcpy needs to be synchronized across threads to avoid race conditions.
             // => do it in INIT phase
-            memcpy(
-                ((char *)  dst->data),
-                ((char *) src0->data),
-                ggml_nbytes(dst));
+            memcpy(dstp, srcp, bytes);
         }
         ggml_barrier(params->threadpool);
     }
@@ -7827,8 +7832,15 @@ static void ggml_compute_forward_out_prod_f32(
     // nb01 >= nb00 - src0 is not transposed
     //   compute by src0 rows
 
-    if (ith == 0) {
-        ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0);
+    size_t total  = ne0*ne1*ne2*ne3;
+    size_t offset = (total + params->nth - 1) / params->nth;
+    size_t start  = params->ith * offset;
+    if (total > start)
+    {
+        size_t rest   = total - start;
+        size_t floats = rest < offset ? rest : offset;
+        void * dstp = (float *)(dst->data) + start;
+        ggml_vec_set_f32(floats, dstp, 0);
     }
 
     ggml_barrier(params->threadpool);
@@ -7949,8 +7961,15 @@ static void ggml_compute_forward_out_prod_q_f32(
     // nb01 >= nb00 - src0 is not transposed
     //   compute by src0 rows
 
-    if (ith == 0) {
-        ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0);
+    size_t total  = ne0*ne1*ne2*ne3;
+    size_t offset = (total + params->nth - 1) / params->nth;
+    size_t start  = params->ith * offset;
+    if (total > start)
+    {
+        size_t rest   = total - start;
+        size_t floats = rest < offset ? rest : offset;
+        void * dstp = (float *)(dst->data) + start;
+        ggml_vec_set_f32(floats, dstp, 0);
    }
 
     ggml_barrier(params->threadpool);
@@ -8128,14 +8147,20 @@ static void ggml_compute_forward_set_f32(
     size_t offset  = ((int32_t *) dst->op_params)[3];
     bool   inplace = (bool) ((int32_t *) dst->op_params)[4];
 
-    if (!inplace) {
-        if (params->ith == 0) {
+    if (!inplace)
+    {
+        size_t total  = ggml_nbytes(dst);
+        size_t stride = (total + params->nth - 1) / params->nth;
+        size_t start  = params->ith * stride;
+        if (total > start)
+        {
+            size_t rest  = total - start;
+            size_t bytes = rest < stride ? rest : stride;
+            char * dstp = ((char *) dst->data)  + start;
+            char * srcp = ((char *) src0->data) + start;
             // memcpy needs to be synchronized across threads to avoid race conditions.
             // => do it in INIT phase
-            memcpy(
-                ((char *)  dst->data),
-                ((char *) src0->data),
-                ggml_nbytes(dst));
+            memcpy(dstp, srcp, bytes);
         }
         ggml_barrier(params->threadpool);
     }
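These four hunks (and several more below) repeat one pattern: instead of thread 0 doing the whole copy or clear while the other threads idle at the barrier, the byte range is split by ceiling division across `params->nth` threads, with the last chunk truncated and surplus threads skipping work entirely. A standalone sketch of the arithmetic, with a serial loop standing in for the thread pool (the harness is illustrative, not ggml code):

```c
#include <assert.h>
#include <stdio.h>
#include <string.h>

// Mirror of the patch's partition: thread `ith` of `nth` clears
// [start, start + bytes) of a `total`-byte buffer.
static void clear_chunk(char * data, size_t total, int ith, int nth) {
    size_t stride = (total + (size_t) nth - 1) / (size_t) nth; // ceil(total/nth)
    size_t start  = (size_t) ith * stride;
    if (total > start) {                 // trailing threads may get no work
        size_t rest  = total - start;
        size_t bytes = rest < stride ? rest : stride; // last chunk is truncated
        memset(data + start, 0, bytes);
    }
}

int main(void) {
    char buf[10];
    memset(buf, 0xff, sizeof(buf));
    for (int ith = 0; ith < 4; ith++) {  // stand-in for the thread pool
        clear_chunk(buf, sizeof(buf), ith, 4);
    }
    for (size_t i = 0; i < sizeof(buf); i++) {
        assert(buf[i] == 0);             // every byte covered exactly once
    }
    printf("all %zu bytes cleared\n", sizeof(buf));
    return 0;
}
```

For total = 10 and nth = 4 the chunks are [0,3), [3,6), [6,9), [9,10); for total = 2 and nth = 4, threads 2 and 3 see start >= total and fall straight through to the barrier. Note the `if (inplace)` that appeared in the set_f32 hunk above was inverted relative to the line it replaced; it is restored to `if (!inplace)` here.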
@@ -8200,13 +8225,18 @@ static void ggml_compute_forward_set_i32(
     bool   inplace = (bool) ((int32_t *) dst->op_params)[4];
 
     if (!inplace) {
-        if (params->ith == 0) {
+        size_t total  = ggml_nbytes(dst);
+        size_t stride = (total + params->nth - 1) / params->nth;
+        size_t start  = params->ith * stride;
+        if (total > start)
+        {
+            size_t rest  = total - start;
+            size_t bytes = rest < stride ? rest : stride;
+            char * dstp = ((char *) dst->data)  + start;
+            char * srcp = ((char *) src0->data) + start;
             // memcpy needs to be synchronized across threads to avoid race conditions.
             // => do it in INIT phase
-            memcpy(
-                ((char *)  dst->data),
-                ((char *) src0->data),
-                ggml_nbytes(dst));
+            memcpy(dstp, srcp, bytes);
         }
         ggml_barrier(params->threadpool);
     }
@@ -8328,7 +8358,7 @@ static void ggml_compute_forward_reshape(
 
 static void ggml_compute_forward_view(
         const struct ggml_compute_params * params,
-        const struct ggml_tensor * dst) {
+        struct ggml_tensor * dst) {
     // NOP
     UNUSED(params);
     UNUSED(dst);
@@ -8338,7 +8368,7 @@ static void ggml_compute_forward_view(
 
 static void ggml_compute_forward_permute(
         const struct ggml_compute_params * params,
-        const struct ggml_tensor * dst) {
+        struct ggml_tensor * dst) {
     // NOP
     UNUSED(params);
     UNUSED(dst);
@@ -8348,7 +8378,7 @@ static void ggml_compute_forward_permute(
 
 static void ggml_compute_forward_transpose(
         const struct ggml_compute_params * params,
-        const struct ggml_tensor * dst) {
+        struct ggml_tensor * dst) {
     // NOP
     UNUSED(params);
     UNUSED(dst);
@@ -8779,15 +8809,20 @@ static void ggml_compute_forward_diag_mask_f32(
     GGML_ASSERT(n_past >= 0);
 
     if (!inplace) {
-        if (ith == 0) {
+        GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
+        GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
+        size_t total  = ggml_nbytes(dst);
+        size_t offset = (total + params->nth - 1) / params->nth;
+        size_t start  = params->ith * offset;
+        if (total > start)
+        {
+            size_t rest  = total - start;
+            size_t bytes = rest < offset ? rest : offset;
+            char * dstp = ((char *) dst->data)  + start;
+            char * srcp = ((char *) src0->data) + start;
             // memcpy needs to be synchronized across threads to avoid race conditions.
             // => do it in INIT phase
-            GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
-            GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
-            memcpy(
-                ((char *)  dst->data),
-                ((char *) src0->data),
-                ggml_nbytes(dst));
+            memcpy(dstp, srcp, bytes);
         }
         ggml_barrier(params->threadpool);
     }
@@ -11029,11 +11064,13 @@ static void ggml_compute_forward_flash_attn_ext_f16(
 
 static void ggml_compute_forward_flash_attn_ext(
         const struct ggml_compute_params * params,
-        const struct ggml_tensor * q,
-        const struct ggml_tensor * k,
-        const struct ggml_tensor * v,
-        const struct ggml_tensor * mask,
         struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * q    = dst->src[0];
+    const struct ggml_tensor * k    = dst->src[1];
+    const struct ggml_tensor * v    = dst->src[2];
+    const struct ggml_tensor * mask = dst->src[3];
+
     switch (dst->op_params[3]) {
         case GGML_PREC_DEFAULT:
         case GGML_PREC_F32:
@@ -11106,8 +11143,17 @@ static void ggml_compute_forward_flash_attn_back_f32(
     GGML_ASSERT(nb1 <= nb2);
     GGML_ASSERT(nb2 <= nb3);
 
-    if (ith == 0) {
-        memset(dst->data, 0, nb0*ne0*ne1*ne2*ne3);
+    size_t total  = nb0*ne0*ne1*ne2*ne3;
+    size_t offset = (total + params->nth - 1) / params->nth;
+    size_t start  = params->ith * offset;
+    if (total > start)
+    {
+        size_t rest  = total - start;
+        size_t bytes = rest < offset ? rest : offset;
+        char * dstp = ((char *) dst->data) + start;
+        // memset needs to be synchronized across threads to avoid race conditions.
+        // => do it in INIT phase
+        memset(dstp, 0, bytes);
     }
 
     ggml_barrier(params->threadpool);
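Hoisting q/k/v/mask out of the parameter list and into `dst->src[...]` reads is what gives every forward function the same two-argument shape, the precondition for the dispatch table added at the bottom of this patch. A reduced standalone model of the convention (hypothetical mini op, not ggml code):

```c
#include <assert.h>
#include <stddef.h>

// Reduced model of a ggml-style graph node: the output tensor carries
// pointers to its own inputs.
struct tensor {
    float val;
    struct tensor * src[4];
};

// Uniform signature: every operand is reachable from dst, so a table-driven
// dispatcher needs no per-op argument lists.
static void forward_madd(struct tensor * dst) {
    const struct tensor * a = dst->src[0];
    const struct tensor * b = dst->src[1];
    const struct tensor * c = dst->src[2];
    dst->val = a->val * b->val + c->val;
}

int main(void) {
    struct tensor a = { 2.0f, {0} }, b = { 3.0f, {0} }, c = { 4.0f, {0} };
    struct tensor d = { 0.0f, { &a, &b, &c, NULL } };
    forward_madd(&d);
    assert(d.val == 10.0f); // 2*3 + 4
    return 0;
}
```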
@@ -11367,9 +11413,10 @@ static void ggml_compute_forward_flash_attn_back_f32(
 
 static void ggml_compute_forward_flash_attn_back(
         const struct ggml_compute_params * params,
-        const bool masked,
         struct ggml_tensor * dst) {
-
+    int32_t t = ggml_get_op_params_i32(dst, 0);
+    GGML_ASSERT(t == 0 || t == 1);
+    const bool masked = t != 0;
     const struct ggml_tensor * q = dst->src[0];
 
     switch (q->type) {
@@ -11801,8 +11848,16 @@ static void ggml_compute_forward_add_rel_pos_f32(
     const bool inplace = (bool) ((int32_t *) dst->op_params)[0];
     if (!inplace) {
-        if (params->ith == 0) {
-            memcpy((char *) dst->data, (char *) src0->data, ggml_nbytes(dst));
+        size_t total  = ggml_nbytes(dst);
+        size_t offset = (total + params->nth - 1) / params->nth;
+        size_t start  = params->ith * offset;
+        if (total > start)
+        {
+            size_t rest  = total - start;
+            size_t bytes = rest < offset ? rest : offset;
+            char * dstp = ((char *) dst->data)  + start;
+            char * srcp = ((char *) src0->data) + start;
+            memcpy(dstp, srcp, bytes);
         }
         ggml_barrier(params->threadpool);
     }
 
@@ -11907,8 +11962,15 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
     GGML_ASSERT(C % HEADS == 0); // C must be divisible by HEADS
     size_t h_stride_2d = head_size * head_size;
 
-    if (ith == 0) {
-        memset(dst_data, 0, T * C * sizeof(float));
+    size_t total  = T * C * sizeof(float);
+    size_t offset = (total + params->nth - 1) / params->nth;
+    size_t start  = params->ith * offset;
+    if (total > start)
+    {
+        size_t rest  = total - start;
+        size_t bytes = rest < offset ? rest : offset;
+        char * dstp = ((char *) dst->data) + start;
+        memset(dstp, 0, bytes);
     }
 
     ggml_barrier(params->threadpool);
@@ -12109,8 +12171,15 @@ static void ggml_compute_forward_gla_f32(
     GGML_ASSERT(C % HEADS == 0); // C must be divisible by HEADS
     size_t h_stride_2d = head_size * head_size;
 
-    if (ith == 0) {
-        memset(dst_data, 0, T * C * sizeof(float));
+    size_t total  = T * C * sizeof(float);
+    size_t offset = (total + params->nth - 1) / params->nth;
+    size_t start  = params->ith * offset;
+    if (total > start)
+    {
+        size_t rest  = total - start;
+        size_t bytes = rest < offset ? rest : offset;
+        char * dstp = ((char *) dst->data) + start;
+        memset(dstp, 0, bytes);
     }
 
     ggml_barrier(params->threadpool);
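Same motive in flash_attn_back: the `masked` flag now travels in the tensor's op_params instead of the argument list. `ggml_get_op_params_i32()` is the existing internal helper (ggml-impl.h) used in the hunk above; the mock below shows just the round-trip, with every `mock_*` name invented for the sketch:

```c
#include <assert.h>
#include <stdint.h>

#define MOCK_MAX_OP_PARAMS 64

// Stand-in for ggml's per-op scalar storage on the tensor.
struct mock_tensor {
    int32_t op_params[MOCK_MAX_OP_PARAMS / sizeof(int32_t)];
};

static void mock_set_op_params_i32(struct mock_tensor * t, uint32_t i, int32_t v) {
    t->op_params[i] = v;
}

static int32_t mock_get_op_params_i32(const struct mock_tensor * t, uint32_t i) {
    return t->op_params[i];
}

int main(void) {
    struct mock_tensor t = {0};
    mock_set_op_params_i32(&t, 0, 1);          // graph-build time: record masked=true
    int32_t v = mock_get_op_params_i32(&t, 0); // compute time: read it back
    assert(v == 0 || v == 1);                  // same sanity check as the patch
    const bool_recovery: ;
    const int masked = v != 0;
    assert(masked);
    return 0;
}
```

(Delete the stray `bool_recovery` label line if copying this; the recovery is just `const int masked = v != 0;`, matching the patch.)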
@@ -12292,9 +12361,10 @@ static void ggml_compute_forward_map_unary_f32(
 
 static void ggml_compute_forward_map_unary(
         const struct ggml_compute_params * params,
-        struct ggml_tensor * dst,
-        const ggml_unary_op_f32_t fun) {
+        struct ggml_tensor * dst) {
 
+    ggml_unary_op_f32_t fun;
+    memcpy(&fun, dst->op_params, sizeof(fun));
     const struct ggml_tensor * src0 = dst->src[0];
 
     switch (src0->type) {
@@ -12341,9 +12411,10 @@ static void ggml_compute_forward_map_binary_f32(
 
 static void ggml_compute_forward_map_binary(
         const struct ggml_compute_params * params,
-        struct ggml_tensor * dst,
-        const ggml_binary_op_f32_t fun) {
-
+        struct ggml_tensor * dst) {
+
+    ggml_binary_op_f32_t fun;
+    memcpy(&fun, dst->op_params, sizeof(fun));
     const struct ggml_tensor * src0 = dst->src[0];
 
     switch (src0->type) {
@@ -12362,9 +12433,10 @@ static void ggml_compute_forward_map_binary(
 
 static void ggml_compute_forward_map_custom1_f32(
         const struct ggml_compute_params * params,
-        struct ggml_tensor * dst,
-        const ggml_custom1_op_f32_t fun) {
+        struct ggml_tensor * dst) {
 
+    ggml_custom1_op_f32_t fun;
+    memcpy(&fun, dst->op_params, sizeof(fun));
     const struct ggml_tensor * a = dst->src[0];
 
     if (params->ith != 0) {
@@ -12378,9 +12450,10 @@ static void ggml_compute_forward_map_custom1_f32(
 
 static void ggml_compute_forward_map_custom2_f32(
         const struct ggml_compute_params * params,
-        struct ggml_tensor * dst,
-        const ggml_custom2_op_f32_t fun) {
+        struct ggml_tensor * dst) {
 
+    ggml_custom2_op_f32_t fun;
+    memcpy(&fun, dst->op_params, sizeof(fun));
     const struct ggml_tensor * a = dst->src[0];
     const struct ggml_tensor * b = dst->src[1];
@@ -12395,9 +12468,10 @@ static void ggml_compute_forward_map_custom2_f32(
 
 static void ggml_compute_forward_map_custom3_f32(
         const struct ggml_compute_params * params,
-        struct ggml_tensor * dst,
-        const ggml_custom3_op_f32_t fun) {
+        struct ggml_tensor * dst) {
 
+    ggml_custom3_op_f32_t fun;
+    memcpy(&fun, dst->op_params, sizeof(fun));
     const struct ggml_tensor * a = dst->src[0];
     const struct ggml_tensor * b = dst->src[1];
     const struct ggml_tensor * c = dst->src[2];
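All five MAP_* wrappers recover their callback the same way: the ggml_map_* constructors store the function pointer's bytes into op_params (via ggml_set_op_params), and the compute side memcpy's them back out, because C offers no portable cast between function and object pointers. A minimal standalone illustration (the local buffer is a stand-in for op_params):

```c
#include <assert.h>
#include <string.h>

typedef float (*unary_f32_t)(float);

static float negate(float x) { return -x; }

int main(void) {
    char params[32] = {0};             // stand-in for tensor->op_params
    unary_f32_t in = negate;
    memcpy(params, &in, sizeof(in));   // store at op-creation time

    unary_f32_t out;
    memcpy(&out, params, sizeof(out)); // recover at compute time
    assert(out(2.0f) == -2.0f);
    return 0;
}
```

Also note the last context line above reads `c = dst->src[2]`; the copy of this hunk in circulation showed `dst->src[1]` for `c`, which does not match the function's three operands.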
@@ -12718,7 +12792,121 @@ static void ggml_compute_forward_opt_step_adamw(
 }
 
 /////////////////////////////////
 
-static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
+static void ggml_compute_forward_none(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+    UNUSED(params);
+    UNUSED(dst);
+}
+
+static void ggml_compute_forward_error(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+    UNUSED(params);
+    UNUSED(dst);
+    GGML_ABORT("fatal error");
+}
+
+typedef void (*ggml_op_function)(const struct ggml_compute_params *, struct ggml_tensor *);
+
+static ggml_op_function ggml_op_functions[GGML_OP_COUNT + 1] = {
+    /*GGML_OP_NONE*/                    ggml_compute_forward_none,
+
+    /*GGML_OP_DUP*/                     ggml_compute_forward_dup,
+    /*GGML_OP_ADD*/                     ggml_compute_forward_add,
+    /*GGML_OP_ADD1*/                    ggml_compute_forward_add1,
+    /*GGML_OP_ACC*/                     ggml_compute_forward_acc,
+    /*GGML_OP_SUB*/                     ggml_compute_forward_sub,
+    /*GGML_OP_MUL*/                     ggml_compute_forward_mul,
+    /*GGML_OP_DIV*/                     ggml_compute_forward_div,
+    /*GGML_OP_SQR*/                     ggml_compute_forward_sqr,
+    /*GGML_OP_SQRT*/                    ggml_compute_forward_sqrt,
+    /*GGML_OP_LOG*/                     ggml_compute_forward_log,
+    /*GGML_OP_SIN*/                     ggml_compute_forward_sin,
+    /*GGML_OP_COS*/                     ggml_compute_forward_cos,
+    /*GGML_OP_SUM*/                     ggml_compute_forward_sum,
+    /*GGML_OP_SUM_ROWS*/                ggml_compute_forward_sum_rows,
+    /*GGML_OP_MEAN*/                    ggml_compute_forward_mean,
+    /*GGML_OP_ARGMAX*/                  ggml_compute_forward_argmax,
+    /*GGML_OP_COUNT_EQUAL*/             ggml_compute_forward_count_equal,
+    /*GGML_OP_REPEAT*/                  ggml_compute_forward_repeat,
+    /*GGML_OP_REPEAT_BACK*/             ggml_compute_forward_repeat_back,
+    /*GGML_OP_CONCAT*/                  ggml_compute_forward_concat,
+    /*GGML_OP_SILU_BACK*/               ggml_compute_forward_silu_back,
+    /*GGML_OP_NORM*/                    ggml_compute_forward_norm, // normalize
+    /*GGML_OP_RMS_NORM*/                ggml_compute_forward_rms_norm,
+    /*GGML_OP_RMS_NORM_BACK*/           ggml_compute_forward_rms_norm_back,
+    /*GGML_OP_GROUP_NORM*/              ggml_compute_forward_group_norm,
+
+    /*GGML_OP_MUL_MAT*/                 ggml_compute_forward_mul_mat,
+    /*GGML_OP_MUL_MAT_ID*/              ggml_compute_forward_mul_mat_id,
+    /*GGML_OP_OUT_PROD*/                ggml_compute_forward_out_prod,
+
+    /*GGML_OP_SCALE*/                   ggml_compute_forward_scale,
+    /*GGML_OP_SET*/                     ggml_compute_forward_set,
+    /*GGML_OP_CPY*/                     ggml_compute_forward_cpy,
+    /*GGML_OP_CONT*/                    ggml_compute_forward_cont,
+    /*GGML_OP_RESHAPE*/                 ggml_compute_forward_reshape,
+    /*GGML_OP_VIEW*/                    ggml_compute_forward_view,
+    /*GGML_OP_PERMUTE*/                 ggml_compute_forward_permute,
+    /*GGML_OP_TRANSPOSE*/               ggml_compute_forward_transpose,
+    /*GGML_OP_GET_ROWS*/                ggml_compute_forward_get_rows,
+    /*GGML_OP_GET_ROWS_BACK*/           ggml_compute_forward_get_rows_back,
+    /*GGML_OP_DIAG*/                    ggml_compute_forward_diag,
+    /*GGML_OP_DIAG_MASK_INF*/           ggml_compute_forward_diag_mask_inf,
+    /*GGML_OP_DIAG_MASK_ZERO*/          ggml_compute_forward_diag_mask_zero,
+    /*GGML_OP_SOFT_MAX*/                ggml_compute_forward_soft_max,
+    /*GGML_OP_SOFT_MAX_BACK*/           ggml_compute_forward_soft_max_ext_back,
+    /*GGML_OP_ROPE*/                    ggml_compute_forward_rope,
+    /*GGML_OP_ROPE_BACK*/               ggml_compute_forward_rope_back,
+    /*GGML_OP_CLAMP*/                   ggml_compute_forward_clamp,
+    /*GGML_OP_CONV_TRANSPOSE_1D*/       ggml_compute_forward_conv_transpose_1d,
+    /*GGML_OP_IM2COL*/                  ggml_compute_forward_im2col,
+    /*GGML_OP_IM2COL_BACK*/             ggml_compute_forward_im2col_back_f32,
+    /*GGML_OP_CONV_TRANSPOSE_2D*/       ggml_compute_forward_conv_transpose_2d,
+    /*GGML_OP_POOL_1D*/                 ggml_compute_forward_pool_1d,
+    /*GGML_OP_POOL_2D*/                 ggml_compute_forward_pool_2d,
+    /*GGML_OP_POOL_2D_BACK*/            ggml_compute_forward_pool_2d_back,
+    /*GGML_OP_UPSCALE*/                 ggml_compute_forward_upscale, // nearest interpolate
+    /*GGML_OP_PAD*/                     ggml_compute_forward_pad,
+    /*GGML_OP_PAD_REFLECT_1D*/          ggml_compute_forward_pad_reflect_1d,
+    /*GGML_OP_ARANGE*/                  ggml_compute_forward_arange,
+    /*GGML_OP_TIMESTEP_EMBEDDING*/      ggml_compute_forward_timestep_embedding,
+    /*GGML_OP_ARGSORT*/                 ggml_compute_forward_argsort,
+    /*GGML_OP_LEAKY_RELU*/              ggml_compute_forward_leaky_relu,
+
+    /*GGML_OP_FLASH_ATTN_EXT*/          ggml_compute_forward_flash_attn_ext,
+    /*GGML_OP_FLASH_ATTN_BACK*/         ggml_compute_forward_flash_attn_back,
+    /*GGML_OP_SSM_CONV*/                ggml_compute_forward_ssm_conv,
+    /*GGML_OP_SSM_SCAN*/                ggml_compute_forward_ssm_scan,
+    /*GGML_OP_WIN_PART*/                ggml_compute_forward_win_part,
+    /*GGML_OP_WIN_UNPART*/              ggml_compute_forward_win_unpart,
+    /*GGML_OP_GET_REL_POS*/             ggml_compute_forward_get_rel_pos,
+    /*GGML_OP_ADD_REL_POS*/             ggml_compute_forward_add_rel_pos,
+    /*GGML_OP_RWKV_WKV6*/               ggml_compute_forward_rwkv_wkv6,
+    /*GGML_OP_GATED_LINEAR_ATTN*/       ggml_compute_forward_gla,
+
+    /*GGML_OP_UNARY*/                   ggml_compute_forward_unary,
+
+    /*GGML_OP_MAP_UNARY*/               ggml_compute_forward_map_unary,
+    /*GGML_OP_MAP_BINARY*/              ggml_compute_forward_map_binary,
+
+    /*GGML_OP_MAP_CUSTOM1_F32*/         ggml_compute_forward_map_custom1_f32,
+    /*GGML_OP_MAP_CUSTOM2_F32*/         ggml_compute_forward_map_custom2_f32,
+    /*GGML_OP_MAP_CUSTOM3_F32*/         ggml_compute_forward_map_custom3_f32,
+
+    /*GGML_OP_MAP_CUSTOM1*/             ggml_compute_forward_map_custom1,
+    /*GGML_OP_MAP_CUSTOM2*/             ggml_compute_forward_map_custom2,
+    /*GGML_OP_MAP_CUSTOM3*/             ggml_compute_forward_map_custom3,
+
+    /*GGML_OP_CROSS_ENTROPY_LOSS*/      ggml_compute_forward_cross_entropy_loss,
+    /*GGML_OP_CROSS_ENTROPY_LOSS_BACK*/ ggml_compute_forward_cross_entropy_loss_back,
+    /*GGML_OP_OPT_STEP_ADAMW*/          ggml_compute_forward_opt_step_adamw,
+
+    /*GGML_OP_COUNT*/                   ggml_compute_forward_error // keep this as the last entry
+};
+
+static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
     GGML_ASSERT(params);
 
     if (tensor->op == GGML_OP_NONE || ggml_is_empty(tensor)) {
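One thing the table loses relative to the old switch is `-Wswitch`-style exhaustiveness checking: a new GGML_OP_* value, or a reordered enum, silently shifts every later entry (the comments above are purely decorative; a duplicated `/*GGML_OP_SQR*/` label next to the sqrt handler in an earlier draft of this table is exactly that failure mode, fixed here to `/*GGML_OP_SQRT*/`). If a compile-time guard is wanted, designated initializers plus a C11 _Static_assert would provide one; a sketch with made-up names, not part of the patch:

```c
enum op { OP_NONE, OP_ADD, OP_MUL, OP_COUNT };

typedef void (*op_fn)(void);

static void fn_none(void) {}
static void fn_add(void)  {}
static void fn_mul(void)  {}
static void fn_err(void)  {}

// Designated initializers pin each handler to its enum value, so reordering
// the enum cannot silently shift entries.
static const op_fn table[] = {
    [OP_NONE]  = fn_none,
    [OP_ADD]   = fn_add,
    [OP_MUL]   = fn_mul,
    [OP_COUNT] = fn_err, // error handler stays last, as in the patch
};

// Fails the build if the table and enum fall out of sync in length.
_Static_assert(sizeof(table) / sizeof(table[0]) == OP_COUNT + 1,
               "op table out of sync with enum");

int main(void) {
    table[OP_ADD]();
    return 0;
}
```

A gap in the middle of such a table still leaves a NULL entry, so a dispatcher wanting full safety would also reject NULL at call time.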
@@ -12726,370 +12914,16 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
         return;
     }
 
     // extra_buffer op?
-    if (ggml_cpu_extra_compute_forward(params, tensor)) return;
+    if (ggml_cpu_extra_compute_forward(params, tensor))
+        return;
 
-    switch (tensor->op) {
-        case GGML_OP_DUP:
-            {
-                ggml_compute_forward_dup(params, tensor);
-            } break;
-        case GGML_OP_ADD:
-            {
-                ggml_compute_forward_add(params, tensor);
-            } break;
-        case GGML_OP_ADD1:
-            {
-                ggml_compute_forward_add1(params, tensor);
-            } break;
-        case GGML_OP_ACC:
-            {
-                ggml_compute_forward_acc(params, tensor);
-            } break;
-        case GGML_OP_SUB:
-            {
-                ggml_compute_forward_sub(params, tensor);
-            } break;
-        case GGML_OP_MUL:
-            {
-                ggml_compute_forward_mul(params, tensor);
-            } break;
-        case GGML_OP_DIV:
-            {
-                ggml_compute_forward_div(params, tensor);
-            } break;
-        case GGML_OP_SQR:
-            {
-                ggml_compute_forward_sqr(params, tensor);
-            } break;
-        case GGML_OP_SQRT:
-            {
-                ggml_compute_forward_sqrt(params, tensor);
-            } break;
-        case GGML_OP_LOG:
-            {
-                ggml_compute_forward_log(params, tensor);
-            } break;
-        case GGML_OP_SIN:
-            {
-                ggml_compute_forward_sin(params, tensor);
-            } break;
-        case GGML_OP_COS:
-            {
-                ggml_compute_forward_cos(params, tensor);
-            } break;
-        case GGML_OP_SUM:
-            {
-                ggml_compute_forward_sum(params, tensor);
-            } break;
-        case GGML_OP_SUM_ROWS:
-            {
-                ggml_compute_forward_sum_rows(params, tensor);
-            } break;
-        case GGML_OP_MEAN:
-            {
-                ggml_compute_forward_mean(params, tensor);
-            } break;
-        case GGML_OP_ARGMAX:
-            {
-                ggml_compute_forward_argmax(params, tensor);
-            } break;
-        case GGML_OP_COUNT_EQUAL:
-            {
-                ggml_compute_forward_count_equal(params, tensor);
-            } break;
-        case GGML_OP_REPEAT:
-            {
-                ggml_compute_forward_repeat(params, tensor);
-            } break;
-        case GGML_OP_REPEAT_BACK:
-            {
-                ggml_compute_forward_repeat_back(params, tensor);
-            } break;
-        case GGML_OP_CONCAT:
-            {
-                ggml_compute_forward_concat(params, tensor);
-            } break;
-        case GGML_OP_SILU_BACK:
-            {
-                ggml_compute_forward_silu_back(params, tensor);
-            } break;
-        case GGML_OP_NORM:
-            {
-                ggml_compute_forward_norm(params, tensor);
-            } break;
-        case GGML_OP_RMS_NORM:
-            {
-                ggml_compute_forward_rms_norm(params, tensor);
-            } break;
-        case GGML_OP_RMS_NORM_BACK:
-            {
-                ggml_compute_forward_rms_norm_back(params, tensor);
-            } break;
-        case GGML_OP_GROUP_NORM:
-            {
-                ggml_compute_forward_group_norm(params, tensor);
-            } break;
-        case GGML_OP_MUL_MAT:
-            {
-                ggml_compute_forward_mul_mat(params, tensor);
-            } break;
-        case GGML_OP_MUL_MAT_ID:
-            {
-                ggml_compute_forward_mul_mat_id(params, tensor);
-            } break;
-        case GGML_OP_OUT_PROD:
-            {
-                ggml_compute_forward_out_prod(params, tensor);
-            } break;
-        case GGML_OP_SCALE:
-            {
-                ggml_compute_forward_scale(params, tensor);
-            } break;
-        case GGML_OP_SET:
-            {
-                ggml_compute_forward_set(params, tensor);
-            } break;
-        case GGML_OP_CPY:
-            {
-                ggml_compute_forward_cpy(params, tensor);
-            } break;
-        case GGML_OP_CONT:
-            {
-                ggml_compute_forward_cont(params, tensor);
-            } break;
-        case GGML_OP_RESHAPE:
-            {
-                ggml_compute_forward_reshape(params, tensor);
-            } break;
-        case GGML_OP_VIEW:
-            {
-                ggml_compute_forward_view(params, tensor);
-            } break;
-        case GGML_OP_PERMUTE:
-            {
-                ggml_compute_forward_permute(params, tensor);
-            } break;
-        case GGML_OP_TRANSPOSE:
-            {
-                ggml_compute_forward_transpose(params, tensor);
-            } break;
-        case GGML_OP_GET_ROWS:
-            {
-                ggml_compute_forward_get_rows(params, tensor);
-            } break;
-        case GGML_OP_GET_ROWS_BACK:
-            {
-                ggml_compute_forward_get_rows_back(params, tensor);
-            } break;
-        case GGML_OP_DIAG:
-            {
-                ggml_compute_forward_diag(params, tensor);
-            } break;
-        case GGML_OP_DIAG_MASK_INF:
-            {
-                ggml_compute_forward_diag_mask_inf(params, tensor);
-            } break;
-        case GGML_OP_DIAG_MASK_ZERO:
-            {
-                ggml_compute_forward_diag_mask_zero(params, tensor);
-            } break;
-        case GGML_OP_SOFT_MAX:
-            {
-                ggml_compute_forward_soft_max(params, tensor);
-            } break;
-        case GGML_OP_SOFT_MAX_BACK:
-            {
-                ggml_compute_forward_soft_max_ext_back(params, tensor);
-            } break;
-        case GGML_OP_ROPE:
-            {
-                ggml_compute_forward_rope(params, tensor);
-            } break;
-        case GGML_OP_ROPE_BACK:
-            {
-                ggml_compute_forward_rope_back(params, tensor);
-            } break;
-        case GGML_OP_CLAMP:
-            {
-                ggml_compute_forward_clamp(params, tensor);
-            } break;
-        case GGML_OP_CONV_TRANSPOSE_1D:
-            {
-                ggml_compute_forward_conv_transpose_1d(params, tensor);
-            } break;
-        case GGML_OP_IM2COL:
-            {
-                ggml_compute_forward_im2col(params, tensor);
-            } break;
-        case GGML_OP_IM2COL_BACK:
-            {
-                ggml_compute_forward_im2col_back_f32(params, tensor);
-            } break;
-        case GGML_OP_CONV_TRANSPOSE_2D:
-            {
-                ggml_compute_forward_conv_transpose_2d(params, tensor);
-            } break;
-        case GGML_OP_POOL_1D:
-            {
-                ggml_compute_forward_pool_1d(params, tensor);
-            } break;
-        case GGML_OP_POOL_2D:
-            {
-                ggml_compute_forward_pool_2d(params, tensor);
-            } break;
-        case GGML_OP_POOL_2D_BACK:
-            {
-                ggml_compute_forward_pool_2d_back(params, tensor);
-            } break;
-        case GGML_OP_UPSCALE:
-            {
-                ggml_compute_forward_upscale(params, tensor);
-            } break;
-        case GGML_OP_PAD:
-            {
-                ggml_compute_forward_pad(params, tensor);
-            } break;
-        case GGML_OP_PAD_REFLECT_1D:
-            {
-                ggml_compute_forward_pad_reflect_1d(params, tensor);
-            } break;
-        case GGML_OP_ARANGE:
-            {
-                ggml_compute_forward_arange(params, tensor);
-            } break;
-        case GGML_OP_TIMESTEP_EMBEDDING:
-            {
-                ggml_compute_forward_timestep_embedding(params, tensor);
-            } break;
-        case GGML_OP_ARGSORT:
-            {
-                ggml_compute_forward_argsort(params, tensor);
-            } break;
-        case GGML_OP_LEAKY_RELU:
-            {
-                ggml_compute_forward_leaky_relu(params, tensor);
-            } break;
-        case GGML_OP_FLASH_ATTN_EXT:
-            {
-                ggml_compute_forward_flash_attn_ext(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor);
-            } break;
-        case GGML_OP_FLASH_ATTN_BACK:
-            {
-                int32_t t = ggml_get_op_params_i32(tensor, 0);
-                GGML_ASSERT(t == 0 || t == 1);
-                bool masked = t != 0;
-                ggml_compute_forward_flash_attn_back(params, masked, tensor);
-            } break;
-        case GGML_OP_SSM_CONV:
-            {
-                ggml_compute_forward_ssm_conv(params, tensor);
-            } break;
-        case GGML_OP_SSM_SCAN:
-            {
-                ggml_compute_forward_ssm_scan(params, tensor);
-            } break;
-        case GGML_OP_WIN_PART:
-            {
-                ggml_compute_forward_win_part(params, tensor);
-            } break;
-        case GGML_OP_WIN_UNPART:
-            {
-                ggml_compute_forward_win_unpart(params, tensor);
-            } break;
-        case GGML_OP_UNARY:
-            {
-                ggml_compute_forward_unary(params, tensor);
-            } break;
-        case GGML_OP_GET_REL_POS:
-            {
-                ggml_compute_forward_get_rel_pos(params, tensor);
-            } break;
-        case GGML_OP_ADD_REL_POS:
-            {
-                ggml_compute_forward_add_rel_pos(params, tensor);
-            } break;
-        case GGML_OP_RWKV_WKV6:
-            {
-                ggml_compute_forward_rwkv_wkv6(params, tensor);
-            } break;
-        case GGML_OP_GATED_LINEAR_ATTN:
-            {
-                ggml_compute_forward_gla(params, tensor);
-            } break;
-        case GGML_OP_MAP_UNARY:
-            {
-                ggml_unary_op_f32_t fun;
-                memcpy(&fun, tensor->op_params, sizeof(fun));
-                ggml_compute_forward_map_unary(params, tensor, fun);
-            }
-            break;
-        case GGML_OP_MAP_BINARY:
-            {
-                ggml_binary_op_f32_t fun;
-                memcpy(&fun, tensor->op_params, sizeof(fun));
-                ggml_compute_forward_map_binary(params, tensor, fun);
-            }
-            break;
-        case GGML_OP_MAP_CUSTOM1_F32:
-            {
-                ggml_custom1_op_f32_t fun;
-                memcpy(&fun, tensor->op_params, sizeof(fun));
-                ggml_compute_forward_map_custom1_f32(params, tensor, fun);
-            }
-            break;
-        case GGML_OP_MAP_CUSTOM2_F32:
-            {
-                ggml_custom2_op_f32_t fun;
-                memcpy(&fun, tensor->op_params, sizeof(fun));
-                ggml_compute_forward_map_custom2_f32(params, tensor, fun);
-            }
-            break;
-        case GGML_OP_MAP_CUSTOM3_F32:
-            {
-                ggml_custom3_op_f32_t fun;
-                memcpy(&fun, tensor->op_params, sizeof(fun));
-                ggml_compute_forward_map_custom3_f32(params, tensor, fun);
-            }
-            break;
-        case GGML_OP_MAP_CUSTOM1:
-            {
-                ggml_compute_forward_map_custom1(params, tensor);
-            }
-            break;
-        case GGML_OP_MAP_CUSTOM2:
-            {
-                ggml_compute_forward_map_custom2(params, tensor);
-            }
-            break;
-        case GGML_OP_MAP_CUSTOM3:
-            {
-                ggml_compute_forward_map_custom3(params, tensor);
-            }
-            break;
-        case GGML_OP_CROSS_ENTROPY_LOSS:
-            {
-                ggml_compute_forward_cross_entropy_loss(params, tensor);
-            }
-            break;
-        case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
-            {
-                ggml_compute_forward_cross_entropy_loss_back(params, tensor);
-            }
-            break;
-        case GGML_OP_OPT_STEP_ADAMW:
-            {
-                ggml_compute_forward_opt_step_adamw(params, tensor);
-            }
-            break;
-        case GGML_OP_NONE:
-            {
-                // nop
-            } break;
-        case GGML_OP_COUNT:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
+    if ((size_t)(tensor->op) > GGML_OP_COUNT)
+    {
+        ggml_compute_forward_error(params, tensor);
+        return;
+    }
+
+    ggml_op_functions[tensor->op](params, tensor);
 }
 
 // Android's libc implementation "bionic" does not support setting affinity
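A last note on the new dispatch guard: the cast in `if ((size_t)(tensor->op) > GGML_OP_COUNT)` folds the `op < 0` and `op > GGML_OP_COUNT` checks into a single comparison, since a negative enum value wraps to a huge unsigned one, and `op == GGML_OP_COUNT` deliberately lands on the trailing error entry. Standalone illustration (the `83` is a made-up stand-in for GGML_OP_COUNT):

```c
#include <assert.h>
#include <stddef.h>

enum { COUNT = 83 }; // hypothetical op count

// Mirrors the patch's guard: indices 0..COUNT are valid table slots
// (COUNT itself maps to the error handler); anything else is rejected.
static int in_table(int op) {
    return (size_t) op <= COUNT; // rejects op < 0 and op > COUNT alike
}

int main(void) {
    assert(!in_table(-1));
    assert(!in_table(COUNT + 1));
    assert(in_table(0) && in_table(COUNT));
    return 0;
}
```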