diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h
index 5bbaae0383f0b..b847a85a85401 100644
--- a/ggml/include/ggml.h
+++ b/ggml/include/ggml.h
@@ -509,6 +509,7 @@ extern "C" {
         GGML_OP_RMS_NORM_BACK,
         GGML_OP_GROUP_NORM,
         GGML_OP_FUSED_RMS_NORM,
+        GGML_OP_FUSED_MUL_UNARY,

         GGML_OP_MUL_MAT,
         GGML_OP_MUL_MAT_ID,
@@ -908,6 +909,18 @@ extern "C" {
             struct ggml_tensor  * a,
             struct ggml_tensor  * b);

+    GGML_API struct ggml_tensor * ggml_fused_mul_unary(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b,
+            enum ggml_unary_op    op);
+
+    GGML_API struct ggml_tensor * ggml_fused_mul_unary_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b,
+            enum ggml_unary_op    op);
+
     GGML_API struct ggml_tensor * ggml_div(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
diff --git a/ggml/src/ggml-alloc.c b/ggml/src/ggml-alloc.c
index 85ac50e978af7..19f120648c141 100644
--- a/ggml/src/ggml-alloc.c
+++ b/ggml/src/ggml-alloc.c
@@ -46,6 +46,7 @@ static bool ggml_op_can_inplace(enum ggml_op op) {
         case GGML_OP_ADD1:
         case GGML_OP_SUB:
         case GGML_OP_MUL:
+        case GGML_OP_FUSED_MUL_UNARY:
         case GGML_OP_DIV:
         case GGML_OP_SQR:
         case GGML_OP_SQRT:
diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c
index 852c72ebfbfc6..06a87d2b6e37b 100644
--- a/ggml/src/ggml-cpu/ggml-cpu.c
+++ b/ggml/src/ggml-cpu/ggml-cpu.c
@@ -2457,6 +2457,30 @@ static void ggml_vec_silu_f32(const int n, float * y, const float * x) {
     }
 }

+static void ggml_vec_mul_silu_f32(const int n, float * z, const float * x, const float * y) {
+    int i = 0;
+#if defined(__AVX512F__) && defined(__AVX512DQ__)
+    for (; i + 15 < n; i += 16) {
+        _mm512_storeu_ps(z + i, _mm512_mul_ps(ggml_v_silu(_mm512_loadu_ps(x + i)), _mm512_loadu_ps(y + i)));
+    }
+#elif defined(__AVX2__) && defined(__FMA__)
+    for (; i + 7 < n; i += 8) {
+        _mm256_storeu_ps(z + i, _mm256_mul_ps(ggml_v_silu(_mm256_loadu_ps(x + i)), _mm256_loadu_ps(y + i)));
+    }
+#elif defined(__SSE2__)
+    for (; i + 3 < n; i += 4) {
+        _mm_storeu_ps(z + i, _mm_mul_ps(ggml_v_silu(_mm_loadu_ps(x + i)), _mm_loadu_ps(y + i)));
+    }
+#elif defined(__ARM_NEON) && defined(__aarch64__)
+    for (; i + 3 < n; i += 4) {
+        vst1q_f32(z + i, vmulq_f32(ggml_v_silu(vld1q_f32(x + i)), vld1q_f32(y + i)));
+    }
+#endif
+    for (; i < n; ++i) {
+        z[i] = ggml_silu_f32(x[i]) * y[i];
+    }
+}
+
 static void ggml_vec_tanh_f32(const int n, float * y, const float * x) {
     int i = 0;
 #if defined(__AVX512F__) && defined(__AVX512DQ__)
@@ -2713,6 +2737,48 @@ inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) {
 #endif
 }

+inline static void ggml_vec_mul_gelu_f32(const int n, float * z, const float * x, const float * y) {
+    int i = 0;
+#if defined(__AVX512F__) && defined(__AVX512DQ__)
+    __m512 c1 = _mm512_set1_ps(GELU_COEF_A);
+    __m512 c2 = _mm512_set1_ps(2.f*SQRT_2_OVER_PI);
+    for (; i + 15 < n; i += 16) {
+        _mm512_storeu_ps(z + i, _mm512_mul_ps(ggml_v_gelu(_mm512_loadu_ps(x + i), c1, c2), _mm512_loadu_ps(y + i)));
+    }
+#elif defined __AVX2__ && defined __FMA__
+    __m256 c1 = _mm256_set1_ps(GELU_COEF_A);
+    __m256 c2 = _mm256_set1_ps(2.f*SQRT_2_OVER_PI);
+    for (; i + 7 < n; i += 8) {
+        _mm256_storeu_ps(z + i, _mm256_mul_ps(ggml_v_gelu(_mm256_loadu_ps(x + i), c1, c2), _mm256_loadu_ps(y + i)));
+    }
+#endif
+#ifdef GGML_GELU_FP16
+    uint16_t t;
+    for (; i < n; ++i) {
+        if (x[i] <= -10.0f) {
+            z[i] = 0.0f;
+        } else if (x[i] >= 10.0f) {
+            z[i] = x[i]*y[i];
+        } else {
+            ggml_fp16_t fp16 = GGML_FP32_TO_FP16(x[i]);
+            memcpy(&t, &fp16, sizeof(uint16_t));
+            z[i] = GGML_FP16_TO_FP32(ggml_table_gelu_f16[t])*y[i];
+        }
+    }
+#else
+#if defined __ARM_NEON
+    float32x4_t c1 = vdupq_n_f32(GELU_COEF_A);
+    float32x4_t c2 = vdupq_n_f32(2.f*SQRT_2_OVER_PI);
+    for (; i + 3 < n; i += 4) {
+        vst1q_f32(z + i, vmulq_f32(ggml_v_gelu(vld1q_f32(x + i), c1, c2), vld1q_f32(y + i)));
+    }
+#endif
+    for (; i < n; ++i) {
+        z[i] = ggml_gelu_f32(x[i])*y[i];
+    }
+#endif
+}
+
 static ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x, float max) {
     int i = 0;
     ggml_float sum = 0;
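For reference, the two new vector helpers are the existing SiLU/GELU kernels with the gate multiplication folded into the store. The scalar model below is illustrative only (not part of the patch); the constant names are local stand-ins for ggml's GELU_COEF_A and SQRT_2_OVER_PI, and the SIMD and FP16-table paths above are expected to agree with it up to rounding.

    /* Illustrative scalar reference for the fused helpers (assumption: same math as ggml's silu/gelu). */
    #include <math.h>

    static const float kGeluCoefA   = 0.044715f;                              /* mirrors GELU_COEF_A */
    static const float kSqrt2OverPi = 0.79788456080286535587989211986876f;    /* mirrors SQRT_2_OVER_PI */

    void ref_mul_silu_f32(int n, float * z, const float * x, const float * y) {
        for (int i = 0; i < n; ++i) {
            z[i] = (x[i] / (1.0f + expf(-x[i]))) * y[i];   /* silu(x) * y */
        }
    }

    void ref_mul_gelu_f32(int n, float * z, const float * x, const float * y) {
        for (int i = 0; i < n; ++i) {
            const float v = x[i];
            /* tanh-approximation GELU, multiplied by the gate value y */
            z[i] = 0.5f * v * (1.0f + tanhf(kSqrt2OverPi * v * (1.0f + kGeluCoefA * v * v))) * y[i];
        }
    }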
@@ -7380,6 +7446,67 @@ static void ggml_compute_forward_silu(
             }
     }
 }
+
+// ggml_compute_forward_fused_mul_unary
+
+static void ggml_compute_forward_fused_mul_unary_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+    enum ggml_unary_op op = (enum ggml_unary_op)dst->op_params[0];
+
+    assert(ggml_is_contiguous_1(src0));
+    assert(ggml_are_same_shape(src0, dst));
+    assert(ggml_are_same_shape(src0, src1));
+    assert(op == GGML_UNARY_OP_GELU || op == GGML_UNARY_OP_RELU || op == GGML_UNARY_OP_SILU);
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc = dst->ne[0];
+    const int nr = ggml_nrows(src0);
+
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        float       * z = (float       *) ((char *)  dst->data + i1*( dst->nb[1]));
+        const float * x = (const float *) ((char *) src0->data + i1*(src0->nb[1]));
+        const float * y = (const float *) ((char *) src1->data + i1*(src1->nb[1]));
+        switch (op) {
+            case GGML_UNARY_OP_GELU: ggml_vec_gelu_f32(nc, z, x); ggml_vec_mul_f32(nc, z, z, y); break;
+            case GGML_UNARY_OP_RELU: ggml_vec_relu_f32(nc, z, x); ggml_vec_mul_f32(nc, z, z, y); break;
+            case GGML_UNARY_OP_SILU: ggml_vec_mul_silu_f32(nc, z, x, y); break;
+            default: GGML_ABORT("fatal error");
+        }
+    }
+}
+
+static void ggml_compute_forward_fused_mul_unary(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_fused_mul_unary_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
 // ggml_compute_forward_leaky_relu

 static void ggml_compute_forward_leaky_relu_f32(
@@ -13788,6 +13915,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
                 break;
             case GGML_OP_SILU_BACK:
             case GGML_OP_MUL:
+            case GGML_OP_FUSED_MUL_UNARY:
             case GGML_OP_DIV:
             case GGML_OP_NORM:
             case GGML_OP_RMS_NORM:
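The CPU forward pass above splits work by rows: each of the nth threads takes a contiguous chunk of ceil(nr/nth) rows and clamps its end to nr. A standalone sketch of that arithmetic (illustrative only, with made-up example values):

    /* Sketch of the per-thread row partitioning used by ggml_compute_forward_fused_mul_unary_f32. */
    #include <stdio.h>

    #define MIN(a, b) ((a) < (b) ? (a) : (b))

    int main(void) {
        const int nr  = 10;   /* total rows (example value) */
        const int nth = 4;    /* number of threads (example value) */
        for (int ith = 0; ith < nth; ++ith) {
            const int dr  = (nr + nth - 1)/nth;     /* rows per thread, rounded up */
            const int ir0 = dr*ith;                 /* first row of this thread */
            const int ir1 = MIN(ir0 + dr, nr);      /* one past the last row     */
            printf("thread %d: rows [%d, %d)\n", ith, ir0, ir1);
        }
        return 0;
    }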
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index 12cac32db4d4c..81f87da53d182 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -2257,6 +2257,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_MUL:
             ggml_cuda_op_mul(ctx, dst);
             break;
+        case GGML_OP_FUSED_MUL_UNARY:
+            ggml_cuda_op_fused_mul_unary(ctx, dst);
+            break;
         case GGML_OP_DIV:
             ggml_cuda_op_div(ctx, dst);
             break;
@@ -3057,6 +3060,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
                     return false;
                 }
             } break;
+        case GGML_OP_FUSED_MUL_UNARY: return ggml_is_contiguous(op->src[0]);
         case GGML_OP_MUL_MAT:
         case GGML_OP_MUL_MAT_ID:
             {
diff --git a/ggml/src/ggml-cuda/unary.cu b/ggml/src/ggml-cuda/unary.cu
index 81fc92202f25a..c95af84c74f59 100644
--- a/ggml/src/ggml-cuda/unary.cu
+++ b/ggml/src/ggml-cuda/unary.cu
@@ -51,6 +51,36 @@ static __global__ void silu_f32(const float * x, float * dst, const int k) {
     dst[i] = x[i] / (1.0f + expf(-x[i]));
 }

+static __global__ void fused_mul_silu_f32(const float * x, const float * y, float * dst, const int k) {
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+    dst[i] = x[i] * y[i] / (1.0f + expf(-x[i]));
+}
+
+static __global__ void fused_mul_relu_f32(const float * x, const float * y, float * dst, const int k) {
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+    dst[i] = fmaxf(x[i], 0) * y[i];
+}
+
+static __global__ void fused_mul_gelu_f32(const float * x, const float * y, float * dst, const int k) {
+    constexpr float GELU_COEF_A = 0.044715f;
+    constexpr float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+    float xi = x[i];
+    dst[i] = 0.5f*xi*y[i]*(1.0f + tanhf(SQRT_2_OVER_PI*xi*(1.0f + GELU_COEF_A*xi*xi)));
+}
+
 static __global__ void tanh_f32(const float * x, float * dst, int k) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
     if (i >= k) {
@@ -173,6 +203,21 @@ static void silu_f32_cuda(const float * x, float * dst, const int k, cudaStream_
     silu_f32<<<num_blocks, CUDA_SILU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
 }

+static void fused_mul_silu_f32_cuda(const float * x, const float * y, float * dst, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_SILU_BLOCK_SIZE - 1) / CUDA_SILU_BLOCK_SIZE;
+    fused_mul_silu_f32<<<num_blocks, CUDA_SILU_BLOCK_SIZE, 0, stream>>>(x, y, dst, k);
+}
+
+static void fused_mul_relu_f32_cuda(const float * x, const float * y, float * dst, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_RELU_BLOCK_SIZE - 1) / CUDA_RELU_BLOCK_SIZE;
+    fused_mul_relu_f32<<<num_blocks, CUDA_RELU_BLOCK_SIZE, 0, stream>>>(x, y, dst, k);
+}
+
+static void fused_mul_gelu_f32_cuda(const float * x, const float * y, float * dst, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_GELU_BLOCK_SIZE - 1) / CUDA_GELU_BLOCK_SIZE;
+    fused_mul_gelu_f32<<<num_blocks, CUDA_GELU_BLOCK_SIZE, 0, stream>>>(x, y, dst, k);
+}
+
 static void tanh_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_TANH_BLOCK_SIZE - 1) / CUDA_TANH_BLOCK_SIZE;
     tanh_f32<<<num_blocks, CUDA_TANH_BLOCK_SIZE, 0, stream>>>(x, dst, k);
@@ -284,6 +329,28 @@ void ggml_cuda_op_silu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     silu_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
 }

+void ggml_cuda_op_fused_mul_unary(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+    GGML_ASSERT(ggml_are_same_shape(src0, src1));
+
+    cudaStream_t stream = ctx.stream();
+    ggml_unary_op op = (ggml_unary_op)dst->op_params[0];
+
+    const float * src0_d = (const float *)src0->data;
+    const float * src1_d = (const float *)src1->data;
+    float * dst_d = (float *)dst->data;
+
+    switch (op) {
+        case GGML_UNARY_OP_SILU: fused_mul_silu_f32_cuda(src0_d, src1_d, dst_d, ggml_nelements(dst), stream); break;
+        case GGML_UNARY_OP_RELU: fused_mul_relu_f32_cuda(src0_d, src1_d, dst_d, ggml_nelements(dst), stream); break;
+        case GGML_UNARY_OP_GELU: fused_mul_gelu_f32_cuda(src0_d, src1_d, dst_d, ggml_nelements(dst), stream); break;
+        default: GGML_ASSERT(false);
+    }
+}
+
 void ggml_cuda_op_gelu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
     const float * src0_d = (const float *)src0->data;
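Each new launcher follows the existing pattern in this file: one thread per element, with the grid size rounded up so that ceil(k / block_size) blocks cover all k elements exactly once. A quick host-side sanity check of that arithmetic (illustrative only; 256 is a stand-in for the actual CUDA_*_BLOCK_SIZE constants):

    /* Check the ceil-division used to size the grid: every element gets a thread,
       and the last block is never entirely past the end of the data. */
    #include <assert.h>

    int main(void) {
        const int block_size = 256;   /* stand-in for CUDA_SILU_BLOCK_SIZE etc. (assumption) */
        for (int k = 1; k <= 100000; k += 997) {
            const int num_blocks = (k + block_size - 1) / block_size;
            assert((long)num_blocks * block_size >= k);        /* full coverage        */
            assert((long)(num_blocks - 1) * block_size < k);   /* no fully idle block  */
        }
        return 0;
    }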
diff --git a/ggml/src/ggml-cuda/unary.cuh b/ggml/src/ggml-cuda/unary.cuh
index c91936728bab1..a09bc77035921 100644
--- a/ggml/src/ggml-cuda/unary.cuh
+++ b/ggml/src/ggml-cuda/unary.cuh
@@ -43,6 +43,8 @@ void ggml_cuda_op_sqr(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

 void ggml_cuda_op_sqrt(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

+void ggml_cuda_op_fused_mul_unary(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
 void ggml_cuda_op_sin(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

 void ggml_cuda_op_cos(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m
index 4285833dd3f79..8336dd0113ffa 100644
--- a/ggml/src/ggml-metal/ggml-metal.m
+++ b/ggml/src/ggml-metal/ggml-metal.m
@@ -1017,12 +1017,11 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
         case GGML_OP_SOFTCAP:
             return true; //ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op);
         case GGML_OP_SOFT_MAX:
-        case GGML_OP_RMS_NORM:
-        case GGML_OP_FUSED_RMS_NORM:
         case GGML_OP_GROUP_NORM:
             return has_simdgroup_reduction;
         case GGML_OP_RMS_NORM:
             return has_simdgroup_reduction && (op->ne[0] % 4 == 0);
+        case GGML_OP_FUSED_RMS_NORM:
         case GGML_OP_NORM:
         case GGML_OP_ROPE:
             return true;
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index eae5b29158a96..e1dcb6d30554c 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -1090,6 +1090,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "RMS_NORM_BACK",
     "GROUP_NORM",
     "FUSED_RMS_NORM",
+    "FUSED_MUL_UNARY",

     "MUL_MAT",
     "MUL_MAT_ID",
@@ -1156,7 +1157,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "OPT_STEP_ADAMW",
 };

-static_assert(GGML_OP_COUNT == 83, "GGML_OP_COUNT != 83");
+static_assert(GGML_OP_COUNT == 84, "GGML_OP_COUNT != 84");

 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -1187,6 +1188,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "rms_norm_back(x)",
     "group_norm(x)",
     "fused_rms_norm(x)",
+    "fused_mul_unary(x)",

     "X*Y",
     "X[i]*Y",
@@ -1253,7 +1255,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "adamw(x)",
 };

-static_assert(GGML_OP_COUNT == 83, "GGML_OP_COUNT != 83");
+static_assert(GGML_OP_COUNT == 84, "GGML_OP_COUNT != 84");

 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");

@@ -2244,6 +2246,56 @@ struct ggml_tensor * ggml_mul_inplace(
     return ggml_mul_impl(ctx, a, b, true);
 }

+// ggml_fused_mul
+
+static struct ggml_tensor * ggml_fused_mul_unary_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
+        enum ggml_unary_op    op,
+        bool                  inplace) {
+    assert(ggml_are_same_shape(b, a));
+    assert(ggml_is_contiguous(a));
+    assert(op == GGML_UNARY_OP_GELU || op == GGML_UNARY_OP_RELU || op == GGML_UNARY_OP_SILU);
+
+    // bool is_node = false;
+
+    // if (!inplace && (a->grad || b->grad)) {
+    //     is_node = true;
+    // }
+
+    // if (inplace) {
+    //     GGML_ASSERT(!is_node);
+    // }
+
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    ggml_set_op_params_i32(result, 0, (int32_t) op);
+
+    result->op     = GGML_OP_FUSED_MUL_UNARY;
+    // result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src[0] = a;
+    result->src[1] = b;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_fused_mul_unary(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
+        enum ggml_unary_op    op) {
+    return ggml_fused_mul_unary_impl(ctx, a, b, op, false);
+}
+
+struct ggml_tensor * ggml_fused_mul_unary_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
+        enum ggml_unary_op    op) {
+    return ggml_fused_mul_unary_impl(ctx, a, b, op, true);
+}
+
 // ggml_div

 static struct ggml_tensor * ggml_div_impl(
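The new op is exposed through the usual graph-building API, so a caller can replace a separate unary node plus ggml_mul with a single node. A minimal CPU-only usage sketch follows; it is illustrative only, and it assumes a tree where ggml_graph_compute_with_ctx is declared in ggml-cpu.h (older trees declare it in ggml.h).

    #include "ggml.h"
    #include "ggml-cpu.h"   /* for ggml_graph_compute_with_ctx; header location varies by ggml version (assumption) */

    int main(void) {
        struct ggml_init_params ip = { /*.mem_size =*/ 16*1024*1024, /*.mem_buffer =*/ NULL, /*.no_alloc =*/ false };
        struct ggml_context * ctx = ggml_init(ip);

        struct ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 8);
        struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 8);
        for (int i = 0; i < 8; ++i) {
            ((float *) a->data)[i] = 0.5f*i;   /* example inputs */
            ((float *) b->data)[i] = 1.0f;
        }

        /* one graph node computing silu(a) * b, instead of ggml_mul(ctx, ggml_silu(ctx, a), b) */
        struct ggml_tensor * out = ggml_fused_mul_unary(ctx, a, b, GGML_UNARY_OP_SILU);

        struct ggml_cgraph * gf = ggml_new_graph(ctx);
        ggml_build_forward_expand(gf, out);
        ggml_graph_compute_with_ctx(ctx, gf, /*n_threads =*/ 1);

        /* ((float *) out->data)[i] now holds silu(a[i]) * b[i] */
        ggml_free(ctx);
        return 0;
    }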
diff --git a/src/llama.cpp b/src/llama.cpp
index 32ac04e7d9016..ec4f829f60c78 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -9818,52 +9818,60 @@ static struct ggml_tensor * llm_build_ffn(
         cur = tmp;
     }

-    switch (type_op) {
-        case LLM_FFN_SILU:
-            {
-                cur = ggml_silu(ctx, cur);
-                cb(cur, "ffn_silu", il);
-            } break;
-        case LLM_FFN_GELU:
-            {
-                cur = ggml_gelu(ctx, cur);
-                cb(cur, "ffn_gelu", il);
-                if (act_scales != NULL) {
-                    cur = ggml_div(ctx, cur, act_scales);
-                    cb(cur, "ffn_act", il);
-                }
-            } break;
-        case LLM_FFN_RELU:
-            {
-                cur = ggml_relu(ctx, cur);
-                cb(cur, "ffn_relu", il);
-            } break;
-        case LLM_FFN_RELU_SQR:
-            {
-                cur = ggml_relu(ctx, cur);
-                cb(cur, "ffn_relu", il);
-
-                cur = ggml_sqr(ctx, cur);
-                cb(cur, "ffn_sqr(relu)", il);
-            } break;
-        case LLM_FFN_SWIGLU:
-            {
-                // Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf
-                int64_t split_point = cur->ne[0] / 2;
-                struct ggml_tensor * x0 = ggml_cont(ctx, ggml_view_2d(ctx, cur, split_point, cur->ne[1], cur->nb[1], 0));
-                struct ggml_tensor * x1 = ggml_cont(ctx, ggml_view_2d(ctx, cur, split_point, cur->ne[1], cur->nb[1], split_point * ggml_element_size(cur)));
-
-                x0 = ggml_silu(ctx, x0);
-                cb(cur, "ffn_silu", il);
-
-                cur = ggml_mul(ctx, x0, x1);
-                cb(cur, "ffn_mul", il);
-            } break;
+    if (type_gate == LLM_FFN_PAR &&
+        (type_op == LLM_FFN_SILU || type_op == LLM_FFN_RELU || (type_op == LLM_FFN_GELU && !act_scales))) {
+        cur = ggml_fused_mul_unary(ctx, cur, tmp, type_op == LLM_FFN_SILU ? GGML_UNARY_OP_SILU :
+                                                  type_op == LLM_FFN_RELU ? GGML_UNARY_OP_RELU : GGML_UNARY_OP_GELU);
     }
+    else {

-    if (type_gate == LLM_FFN_PAR) {
-        cur = ggml_mul(ctx, cur, tmp);
-        cb(cur, "ffn_gate_par", il);
+        switch (type_op) {
+            case LLM_FFN_SILU:
+                {
+                    cur = ggml_silu(ctx, cur);
+                    cb(cur, "ffn_silu", il);
+                } break;
+            case LLM_FFN_GELU:
+                {
+                    cur = ggml_gelu(ctx, cur);
+                    cb(cur, "ffn_gelu", il);
+                    if (act_scales != NULL) {
+                        cur = ggml_div(ctx, cur, act_scales);
+                        cb(cur, "ffn_act", il);
+                    }
+                } break;
+            case LLM_FFN_RELU:
+                {
+                    cur = ggml_relu(ctx, cur);
+                    cb(cur, "ffn_relu", il);
+                } break;
+            case LLM_FFN_RELU_SQR:
+                {
+                    cur = ggml_relu(ctx, cur);
+                    cb(cur, "ffn_relu", il);
+
+                    cur = ggml_sqr(ctx, cur);
+                    cb(cur, "ffn_sqr(relu)", il);
+                } break;
+            case LLM_FFN_SWIGLU:
+                {
+                    // Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf
+                    int64_t split_point = cur->ne[0] / 2;
+                    struct ggml_tensor * x0 = ggml_cont(ctx, ggml_view_2d(ctx, cur, split_point, cur->ne[1], cur->nb[1], 0));
+                    struct ggml_tensor * x1 = ggml_cont(ctx, ggml_view_2d(ctx, cur, split_point, cur->ne[1], cur->nb[1], split_point * ggml_element_size(cur)));
+
+                    x0 = ggml_silu(ctx, x0);
+                    cb(cur, "ffn_silu", il);
+
+                    cur = ggml_mul(ctx, x0, x1);
+                    cb(cur, "ffn_mul", il);
+                } break;
+        }
+
+        if (type_gate == LLM_FFN_PAR) {
+            cur = ggml_mul(ctx, cur, tmp);
+            cb(cur, "ffn_gate_par", il);
+        }
     }

     if (down) {
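The effect of the new branch is easiest to see in isolation: for the parallel-gate (LLM_FFN_PAR) case, two graph nodes collapse into one. The helpers below are hypothetical (the names are not from the patch) and only sketch the replacement; x is the branch that receives the activation and g is the gate it is multiplied with.

    #include "ggml.h"

    /* before: activation node followed by an element-wise multiply (two ops in the graph) */
    struct ggml_tensor * ffn_silu_mul_unfused(struct ggml_context * ctx,
                                              struct ggml_tensor * x,
                                              struct ggml_tensor * g) {
        return ggml_mul(ctx, ggml_silu(ctx, x), g);
    }

    /* after: a single GGML_OP_FUSED_MUL_UNARY node computing silu(x) * g */
    struct ggml_tensor * ffn_silu_mul_fused(struct ggml_context * ctx,
                                            struct ggml_tensor * x,
                                            struct ggml_tensor * g) {
        return ggml_fused_mul_unary(ctx, x, g, GGML_UNARY_OP_SILU);
    }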
@@ -9943,6 +9951,7 @@ static struct ggml_tensor * llm_build_moe_ffn(
     ggml_tensor * gate = llm_build_lora_mm_id(lctx, ctx, gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
     cb(gate, "ffn_moe_gate", il);

+
     switch (type_op) {
         case LLM_FFN_SILU:
             {