From 50b5e90112766dc4de276ccb0d0abf0f9a974b84 Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Wed, 2 Oct 2024 17:05:56 +0300 Subject: [PATCH] Fused unary(x)*y (#70) * Adding fused y*unary(x) op * Fused y*unary(x) op: CUDA * Fused y*unary(x) op: dedicated CPU implementation for silu and gelu * Fused y*unary(x) op: Metal --------- Co-authored-by: Iwan Kawrakow --- ggml/include/ggml.h | 13 +++ ggml/src/ggml-cuda.cu | 4 + ggml/src/ggml-cuda/unary.cu | 67 +++++++++++++ ggml/src/ggml-cuda/unary.cuh | 2 + ggml/src/ggml-metal.m | 32 ++++++ ggml/src/ggml-metal.metal | 50 +++++++++ ggml/src/ggml.c | 189 ++++++++++++++++++++++++++++++++++- src/llama.cpp | 8 ++ 8 files changed, 363 insertions(+), 2 deletions(-) diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 08fe6a3e0ebc8..b1aebd2159ab5 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -487,6 +487,7 @@ extern "C" { GGML_OP_RMS_NORM_BACK, GGML_OP_GROUP_NORM, GGML_OP_FUSED_RMS_NORM, + GGML_OP_FUSED_MUL_UNARY, GGML_OP_MUL_MAT, GGML_OP_MUL_MAT_ID, @@ -963,6 +964,18 @@ extern "C" { struct ggml_tensor * a, struct ggml_tensor * b); + GGML_API struct ggml_tensor * ggml_fused_mul_unary( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + enum ggml_unary_op op); + + GGML_API struct ggml_tensor * ggml_fused_mul_unary_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + enum ggml_unary_op op); + GGML_API struct ggml_tensor * ggml_div( struct ggml_context * ctx, struct ggml_tensor * a, diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu index 64cc759229c11..871d4007dc85c 100644 --- a/ggml/src/ggml-cuda.cu +++ b/ggml/src/ggml-cuda.cu @@ -2222,6 +2222,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_OP_MUL: ggml_cuda_op_mul(ctx, dst); break; + case GGML_OP_FUSED_MUL_UNARY: + ggml_cuda_op_fused_mul_unary(ctx, dst); + break; case GGML_OP_DIV: ggml_cuda_op_div(ctx, dst); break; @@ -2788,6 +2791,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons return false; } break; + case GGML_OP_FUSED_MUL_UNARY: return ggml_is_contiguous(op->src[0]); case GGML_OP_MUL_MAT: case GGML_OP_MUL_MAT_ID: { diff --git a/ggml/src/ggml-cuda/unary.cu b/ggml/src/ggml-cuda/unary.cu index 51582ed530e36..7bc43d0f45a70 100644 --- a/ggml/src/ggml-cuda/unary.cu +++ b/ggml/src/ggml-cuda/unary.cu @@ -43,6 +43,36 @@ static __global__ void swiglu_f32(const float * x, float * dst, const int k, con dst[i] = x[j] * x[j + ne0] / (1.0f + expf(-x[j])); } +static __global__ void fused_mul_silu_f32(const float * x, const float * y, float * dst, const int k) { + const int i = blockDim.x*blockIdx.x + threadIdx.x; + + if (i >= k) { + return; + } + dst[i] = x[i] * y[i] / (1.0f + expf(-x[i])); +} + +static __global__ void fused_mul_relu_f32(const float * x, const float * y, float * dst, const int k) { + const int i = blockDim.x*blockIdx.x + threadIdx.x; + + if (i >= k) { + return; + } + dst[i] = fmaxf(x[i], 0) * y[i]; +} + +static __global__ void fused_mul_gelu_f32(const float * x, const float * y, float * dst, const int k) { + constexpr float GELU_COEF_A = 0.044715f; + constexpr float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; + const int i = blockDim.x*blockIdx.x + threadIdx.x; + + if (i >= k) { + return; + } + float xi = x[i]; + dst[i] = 0.5f*xi*y[i]*(1.0f + tanhf(SQRT_2_OVER_PI*xi*(1.0f + GELU_COEF_A*xi*xi))); +} + static __global__ void tanh_f32(const float * x, float * dst, int k) { const int i = blockDim.x*blockIdx.x + threadIdx.x; if (i >= k) { @@ -133,6 +163,21 @@ static void swiglu_f32_cuda(const float * x, float * dst, const int k, const int swiglu_f32<<>>(x, dst, k, ne0, nb1); } +static void fused_mul_silu_f32_cuda(const float * x, const float * y, float * dst, const int k, cudaStream_t stream) { + const int num_blocks = (k + CUDA_SILU_BLOCK_SIZE - 1) / CUDA_SILU_BLOCK_SIZE; + fused_mul_silu_f32<<>>(x, y, dst, k); +} + +static void fused_mul_relu_f32_cuda(const float * x, const float * y, float * dst, const int k, cudaStream_t stream) { + const int num_blocks = (k + CUDA_RELU_BLOCK_SIZE - 1) / CUDA_RELU_BLOCK_SIZE; + fused_mul_relu_f32<<>>(x, y, dst, k); +} + +static void fused_mul_gelu_f32_cuda(const float * x, const float * y, float * dst, const int k, cudaStream_t stream) { + const int num_blocks = (k + CUDA_GELU_BLOCK_SIZE - 1) / CUDA_GELU_BLOCK_SIZE; + fused_mul_gelu_f32<<>>(x, y, dst, k); +} + static void tanh_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) { const int num_blocks = (k + CUDA_TANH_BLOCK_SIZE - 1) / CUDA_TANH_BLOCK_SIZE; tanh_f32<<>>(x, dst, k); @@ -216,6 +261,28 @@ void ggml_cuda_op_swiglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { swiglu_f32_cuda(src0_d, dst_d, ggml_nelements(dst), dst->ne[0], src0->nb[1]/sizeof(float), stream); } +void ggml_cuda_op_fused_mul_unary(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + GGML_ASSERT(ggml_are_same_shape(src0, src1)); + + cudaStream_t stream = ctx.stream(); + ggml_unary_op op = (ggml_unary_op)dst->op_params[0]; + + const float * src0_d = (const float *)src0->data; + const float * src1_d = (const float *)src1->data; + float * dst_d = (float *)dst->data; + + switch (op) { + case GGML_UNARY_OP_SILU: fused_mul_silu_f32_cuda(src0_d, src1_d, dst_d, ggml_nelements(dst), stream); break; + case GGML_UNARY_OP_RELU: fused_mul_relu_f32_cuda(src0_d, src1_d, dst_d, ggml_nelements(dst), stream); break; + case GGML_UNARY_OP_GELU: fused_mul_gelu_f32_cuda(src0_d, src1_d, dst_d, ggml_nelements(dst), stream); break; + default: GGML_ASSERT(false); + } +} + void ggml_cuda_op_gelu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; const float * src0_d = (const float *)src0->data; diff --git a/ggml/src/ggml-cuda/unary.cuh b/ggml/src/ggml-cuda/unary.cuh index be3d6f1561076..d2d478b49c53e 100644 --- a/ggml/src/ggml-cuda/unary.cuh +++ b/ggml/src/ggml-cuda/unary.cuh @@ -33,3 +33,5 @@ void ggml_cuda_op_sqr(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_sqrt(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_swiglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +void ggml_cuda_op_fused_mul_unary(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m index dcdd0efe1a114..4badc7a75f7d1 100644 --- a/ggml/src/ggml-metal.m +++ b/ggml/src/ggml-metal.m @@ -56,13 +56,18 @@ GGML_METAL_KERNEL_TYPE_CLAMP, GGML_METAL_KERNEL_TYPE_TANH, GGML_METAL_KERNEL_TYPE_RELU, + GGML_METAL_KERNEL_TYPE_MUL_RELU, GGML_METAL_KERNEL_TYPE_SIGMOID, GGML_METAL_KERNEL_TYPE_GELU, GGML_METAL_KERNEL_TYPE_GELU_4, + GGML_METAL_KERNEL_TYPE_MUL_GELU, + GGML_METAL_KERNEL_TYPE_MUL_GELU_4, GGML_METAL_KERNEL_TYPE_GELU_QUICK, GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, GGML_METAL_KERNEL_TYPE_SILU, GGML_METAL_KERNEL_TYPE_SILU_4, + GGML_METAL_KERNEL_TYPE_MUL_SILU, + GGML_METAL_KERNEL_TYPE_MUL_SILU_4, GGML_METAL_KERNEL_TYPE_SWIGLU, GGML_METAL_KERNEL_TYPE_SWIGLU_4, GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16, @@ -584,13 +589,18 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CLAMP, clamp, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TANH, tanh, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU, relu, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_RELU, mul_relu, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIGMOID, sigmoid, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU, gelu, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_4, gelu_4, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_GELU, mul_gelu, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_GELU_4, mul_gelu_4, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK, gelu_quick, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, gelu_quick_4, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU_4, silu_4, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_SILU, mul_silu, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_SILU_4, mul_silu_4, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SWIGLU, swiglu, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SWIGLU_4, swiglu_4, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16, soft_max_f16, ctx->support_simdgroup_reduction); @@ -921,6 +931,8 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx case GGML_OP_SQR: case GGML_OP_SUM_ROWS: return true; + case GGML_OP_FUSED_MUL_UNARY: + return ggml_is_contiguous(op->src[0]); case GGML_OP_SOFTCAP: case GGML_OP_SOFT_CAP_MAX: return true; //ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op); @@ -1648,6 +1660,26 @@ static enum ggml_status ggml_metal_graph_compute( GGML_ABORT("fatal error"); } } break; + case GGML_OP_FUSED_MUL_UNARY: + { + int64_t n = ggml_nelements(dst); + enum ggml_unary_op op = (enum ggml_unary_op)dst->op_params[0]; + id pipeline = nil; + if (n % 4 == 0 && op != GGML_UNARY_OP_RELU) { + pipeline = op == GGML_UNARY_OP_GELU ? ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_GELU_4].pipeline + : ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_SILU_4].pipeline; + n /= 4; + } else { + pipeline = op == GGML_UNARY_OP_GELU ? ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_GELU].pipeline + : op == GGML_UNARY_OP_SILU ? ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_SILU].pipeline + : ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_RELU].pipeline; + } + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; case GGML_OP_SQR: { GGML_ASSERT(ggml_is_contiguous(src0)); diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index 225fa5f1cbf04..4dbfa089ff1f6 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -323,6 +323,14 @@ kernel void kernel_relu( dst[tpig] = max(0.0f, src0[tpig]); } +kernel void kernel_mul_relu( + device const float * src0, + device const float * src1, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = max(0.0f, src0[tpig]) * src1[tpig]; +} + kernel void kernel_sigmoid( device const float * src0, device float * dst, @@ -364,6 +372,30 @@ kernel void kernel_gelu_4( dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); } +kernel void kernel_mul_gelu( + device const float * src0, + device const float * src1, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + device const float & x = src0[tpig]; + + dst[tpig] = 0.5f*x*src1[tpig]*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); +} + +kernel void kernel_mul_gelu_4( + device const float4 * src0, + device const float4 * src1, + device float4 * dst, + uint tpig[[thread_position_in_grid]]) { + device const float4 & x = src0[tpig]; + + // BEWARE !!! + // Simply using "tanh" instead of "precise::tanh" will sometimes results in NaNs! + // This was observed with Falcon 7B and 40B models + // + dst[tpig] = 0.5f*x*src1[tpig]*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); +} + kernel void kernel_gelu_quick( device const float * src0, device float * dst, @@ -398,6 +430,24 @@ kernel void kernel_silu_4( dst[tpig] = x / (1.0f + exp(-x)); } +kernel void kernel_mul_silu( + device const float * src0, + device const float * src1, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + device const float & x = src0[tpig]; + dst[tpig] = x * src1[tpig] / (1.0f + exp(-x)); +} + +kernel void kernel_mul_silu_4( + device const float4 * src0, + device const float4 * src1, + device float4 * dst, + uint tpig[[thread_position_in_grid]]) { + device const float4 & x = src0[tpig]; + dst[tpig] = x * src1[tpig] / (1.0f + exp(-x)); +} + kernel void kernel_swiglu( device const float * src0, device float * dst, diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index d31713dfc228c..08eab23b4d602 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -2888,6 +2888,30 @@ static void ggml_vec_silu_f32(const int n, float * y, const float * x) { } } +static void ggml_vec_mul_silu_f32(const int n, float * z, const float * x, const float * y) { + int i = 0; +#if defined(__AVX512F__) && defined(__AVX512DQ__) + for (; i + 15 < n; i += 16) { + _mm512_storeu_ps(z + i, _mm512_mul_ps(ggml_v_silu(_mm512_loadu_ps(x + i)), _mm512_loadu_ps(y + i))); + } +#elif defined(__AVX2__) && defined(__FMA__) + for (; i + 7 < n; i += 8) { + _mm256_storeu_ps(z + i, _mm256_mul_ps(ggml_v_silu(_mm256_loadu_ps(x + i)), _mm256_loadu_ps(y + i))); + } +#elif defined(__SSE2__) + for (; i + 3 < n; i += 4) { + _mm_storeu_ps(z + i, _mm_mul_ps(ggml_v_silu(_mm_loadu_ps(x + i)), _mm_loadu_ps(y + i))); + } +#elif defined(__ARM_NEON) && defined(__aarch64__) + for (; i + 3 < n; i += 4) { + vst1q_f32(z + i, vmulq_f32(ggml_v_silu(vld1q_f32(x + i)), vld1q_f32(y + i))); + } +#endif + for (; i < n; ++i) { + z[i] = ggml_silu_f32(x[i]) * y[i]; + } +} + static void ggml_vec_swiglu_f32(const int n, float * y, const float * x) { int i = 0; #if defined(__AVX512F__) && defined(__AVX512DQ__) @@ -3100,6 +3124,47 @@ inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) { } #endif } +inline static void ggml_vec_mul_gelu_f32(const int n, float * z, const float * x, const float * y) { + int i = 0; +#if defined(__AVX512F__) && defined(__AVX512DQ__) + __m512 c1 = _mm512_set1_ps(GELU_COEF_A); + __m512 c2 = _mm512_set1_ps(2.f*SQRT_2_OVER_PI); + for (; i + 15 < n; i += 16) { + _mm512_storeu_ps(z + i, _mm512_mul_ps(ggml_v_gelu(_mm512_loadu_ps(x + i), c1, c2), _mm512_loadu_ps(y + i))); + } +#elif defined __AVX2__ && defined __FMA__ + __m256 c1 = _mm256_set1_ps(GELU_COEF_A); + __m256 c2 = _mm256_set1_ps(2.f*SQRT_2_OVER_PI); + for (; i + 7 < n; i += 8) { + _mm256_storeu_ps(z + i, _mm256_mul_ps(ggml_v_gelu(_mm256_loadu_ps(x + i), c1, c2), _mm256_loadu_ps(y + i))); + } +#endif +#ifdef GGML_GELU_FP16 + uint16_t t; + for (; i < n; ++i) { + if (x[i] <= -10.0f) { + z[i] = 0.0f; + } else if (x[i] >= 10.0f) { + z[i] = x[i]*y[i]; + } else { + ggml_fp16_t fp16 = GGML_FP32_TO_FP16(x[i]); + memcpy(&t, &fp16, sizeof(uint16_t)); + z[i] = GGML_FP16_TO_FP32(ggml_table_gelu_f16[t])*y[i]; + } + } +#else +#if defined __ARM_NEON + float32x4_t c1 = vdupq_n_f32(GELU_COEF_A); + float32x4_t c2 = vdupq_n_f32(2.f*SQRT_2_OVER_PI); + for (; i + 3 < n; i += 4) { + vst1q_f32(z + i, vmulq_f32(ggml_v_gelu(vld1q_f32(x + i), c1, c2), vld1q_f32(y + i))); + } +#endif + for (; i < n; ++i) { + z[i] = ggml_gelu_f32(x[i])*y[i]; + } +#endif +} static ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x, float max) { int i = 0; @@ -3258,6 +3323,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "RMS_NORM_BACK", "GROUP_NORM", "FUSED_RMS_NORM", + "FUSED_MUL_UNARY", "MUL_MAT", "MUL_MAT_ID", @@ -3321,7 +3387,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "CROSS_ENTROPY_LOSS_BACK", }; -static_assert(GGML_OP_COUNT == 77, "GGML_OP_COUNT != 77"); +static_assert(GGML_OP_COUNT == 78, "GGML_OP_COUNT != 78"); static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "none", @@ -3349,6 +3415,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "rms_norm_back(x)", "group_norm(x)", "fused_rms_norm(x)", + "fused_mul_unary(x)", "X*Y", "X[i]*Y", @@ -3412,7 +3479,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "cross_entropy_loss_back(x,y)", }; -static_assert(GGML_OP_COUNT == 77, "GGML_OP_COUNT != 77"); +static_assert(GGML_OP_COUNT == 78, "GGML_OP_COUNT != 78"); static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2"); @@ -5246,6 +5313,55 @@ struct ggml_tensor * ggml_mul_inplace( struct ggml_tensor * b) { return ggml_mul_impl(ctx, a, b, true); } +// ggml_mul + +static struct ggml_tensor * ggml_fused_mul_unary_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + enum ggml_unary_op op, + bool inplace) { + GGML_ASSERT(ggml_are_same_shape(b, a)); + GGML_ASSERT(ggml_is_contiguous(a)); + GGML_ASSERT(op == GGML_UNARY_OP_GELU || op == GGML_UNARY_OP_RELU || op == GGML_UNARY_OP_SILU); + + bool is_node = false; + + if (!inplace && (a->grad || b->grad)) { + is_node = true; + } + + if (inplace) { + GGML_ASSERT(!is_node); + } + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + ggml_set_op_params_i32(result, 0, (int32_t) op); + + result->op = GGML_OP_FUSED_MUL_UNARY; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + result->src[1] = b; + + return result; +} + +struct ggml_tensor * ggml_fused_mul_unary( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + enum ggml_unary_op op) { + return ggml_fused_mul_unary_impl(ctx, a, b, op, false); +} + +struct ggml_tensor * ggml_fused_mul_unary_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + enum ggml_unary_op op) { + return ggml_fused_mul_unary_impl(ctx, a, b, op, true); +} // ggml_div @@ -12374,6 +12490,66 @@ static void ggml_compute_forward_swiglu( } } +// ggml_compute_forward_fused_mul_unary + +static void ggml_compute_forward_fused_mul_unary_f32( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + const struct ggml_tensor * src1 = dst->src[1]; + enum ggml_unary_op op = (enum ggml_unary_op)dst->op_params[0]; + + GGML_ASSERT(ggml_is_contiguous_1(src0)); + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + GGML_ASSERT(ggml_are_same_shape(src0, src1)); + GGML_ASSERT(op == GGML_UNARY_OP_GELU || op == GGML_UNARY_OP_RELU || op == GGML_UNARY_OP_SILU); + + const int ith = params->ith; + const int nth = params->nth; + + const int nc = dst->ne[0]; + const int nr = ggml_nrows(src0); + + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int i1 = ir0; i1 < ir1; i1++) { + float * z = (float *) ((char *) dst->data + i1*( dst->nb[1])); + const float * x = (const float *) ((char *) src0->data + i1*(src0->nb[1])); + const float * y = (const float *) ((char *) src1->data + i1*(src1->nb[1])); + switch (op) { + case GGML_UNARY_OP_GELU: ggml_vec_gelu_f32(nc, z, x); ggml_vec_mul_f32(nc, z, z, y); break; + case GGML_UNARY_OP_RELU: ggml_vec_relu_f32(nc, z, x); ggml_vec_mul_f32(nc, z, z, y); break; + case GGML_UNARY_OP_SILU: ggml_vec_mul_silu_f32(nc, z, x, y); break; + default: GGML_ABORT("fatal error"); + } + } +} + +static void ggml_compute_forward_fused_mul_unary( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_fused_mul_unary_f32(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + // ggml_compute_forward_leaky_relu static void ggml_compute_forward_leaky_relu_f32( @@ -17990,6 +18166,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm { ggml_compute_forward_mul(params, tensor); } break; + case GGML_OP_FUSED_MUL_UNARY: + { + ggml_compute_forward_fused_mul_unary(params, tensor); + } break; case GGML_OP_DIV: { ggml_compute_forward_div(params, tensor); @@ -18715,6 +18895,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor zero_table); } } break; + case GGML_OP_FUSED_MUL_UNARY: + { + GGML_ABORT("fatal error"); // TODO: implement + } case GGML_OP_CONCAT: { GGML_ABORT("fatal error"); // TODO: implement @@ -19813,6 +19997,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { break; case GGML_OP_SILU_BACK: case GGML_OP_MUL: + case GGML_OP_FUSED_MUL_UNARY: case GGML_OP_DIV: case GGML_OP_NORM: case GGML_OP_RMS_NORM: diff --git a/src/llama.cpp b/src/llama.cpp index eb9821258dfa1..9ed109c6725a2 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -8083,6 +8083,13 @@ static struct ggml_tensor * llm_build_ffn( cur = tmp; } + if (type_gate == LLM_FFN_PAR && + (type_op == LLM_FFN_SILU || type_op == LLM_FFN_RELU || (type_op == LLM_FFN_GELU && !act_scales))) { + cur = ggml_fused_mul_unary(ctx, cur, tmp, type_op == LLM_FFN_SILU ? GGML_UNARY_OP_SILU : + type_op == LLM_FFN_RELU ? GGML_UNARY_OP_RELU : GGML_UNARY_OP_GELU); + } + else { + switch (type_op) { case LLM_FFN_SILU: { @@ -8122,6 +8129,7 @@ static struct ggml_tensor * llm_build_ffn( cur = ggml_mul(ctx, cur, tmp); cb(cur, "ffn_gate_par", il); } + } if (down) { cur = llm_build_lora_mm(lctx, ctx, down, cur);