Skip to content

Commit

Permalink
Fused unary(x)*y (#70)
Browse files Browse the repository at this point in the history
* Adding fused y*unary(x) op

* Fused y*unary(x) op: CUDA

* Fused y*unary(x) op: dedicated CPU implementation for silu and gelu

* Fused y*unary(x) op: Metal

---------

Co-authored-by: Iwan Kawrakow <[email protected]>
  • Loading branch information
ikawrakow and Kawrakow authored Oct 2, 2024
1 parent cce4983 commit 50b5e90
Show file tree
Hide file tree
Showing 8 changed files with 363 additions and 2 deletions.
13 changes: 13 additions & 0 deletions ggml/include/ggml.h
Original file line number Diff line number Diff line change
Expand Up @@ -487,6 +487,7 @@ extern "C" {
GGML_OP_RMS_NORM_BACK,
GGML_OP_GROUP_NORM,
GGML_OP_FUSED_RMS_NORM,
GGML_OP_FUSED_MUL_UNARY,

GGML_OP_MUL_MAT,
GGML_OP_MUL_MAT_ID,
Expand Down Expand Up @@ -963,6 +964,18 @@ extern "C" {
struct ggml_tensor * a,
struct ggml_tensor * b);

GGML_API struct ggml_tensor * ggml_fused_mul_unary(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
enum ggml_unary_op op);

GGML_API struct ggml_tensor * ggml_fused_mul_unary_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
enum ggml_unary_op op);

GGML_API struct ggml_tensor * ggml_div(
struct ggml_context * ctx,
struct ggml_tensor * a,
Expand Down
4 changes: 4 additions & 0 deletions ggml/src/ggml-cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -2222,6 +2222,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_MUL:
ggml_cuda_op_mul(ctx, dst);
break;
case GGML_OP_FUSED_MUL_UNARY:
ggml_cuda_op_fused_mul_unary(ctx, dst);
break;
case GGML_OP_DIV:
ggml_cuda_op_div(ctx, dst);
break;
Expand Down Expand Up @@ -2788,6 +2791,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
return false;
}
break;
case GGML_OP_FUSED_MUL_UNARY: return ggml_is_contiguous(op->src[0]);
case GGML_OP_MUL_MAT:
case GGML_OP_MUL_MAT_ID:
{
Expand Down
67 changes: 67 additions & 0 deletions ggml/src/ggml-cuda/unary.cu
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,36 @@ static __global__ void swiglu_f32(const float * x, float * dst, const int k, con
dst[i] = x[j] * x[j + ne0] / (1.0f + expf(-x[j]));
}

static __global__ void fused_mul_silu_f32(const float * x, const float * y, float * dst, const int k) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;

if (i >= k) {
return;
}
dst[i] = x[i] * y[i] / (1.0f + expf(-x[i]));
}

static __global__ void fused_mul_relu_f32(const float * x, const float * y, float * dst, const int k) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;

if (i >= k) {
return;
}
dst[i] = fmaxf(x[i], 0) * y[i];
}

static __global__ void fused_mul_gelu_f32(const float * x, const float * y, float * dst, const int k) {
constexpr float GELU_COEF_A = 0.044715f;
constexpr float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
const int i = blockDim.x*blockIdx.x + threadIdx.x;

if (i >= k) {
return;
}
float xi = x[i];
dst[i] = 0.5f*xi*y[i]*(1.0f + tanhf(SQRT_2_OVER_PI*xi*(1.0f + GELU_COEF_A*xi*xi)));
}

static __global__ void tanh_f32(const float * x, float * dst, int k) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;
if (i >= k) {
Expand Down Expand Up @@ -133,6 +163,21 @@ static void swiglu_f32_cuda(const float * x, float * dst, const int k, const int
swiglu_f32<<<num_blocks, CUDA_SILU_BLOCK_SIZE, 0, stream>>>(x, dst, k, ne0, nb1);
}

static void fused_mul_silu_f32_cuda(const float * x, const float * y, float * dst, const int k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_SILU_BLOCK_SIZE - 1) / CUDA_SILU_BLOCK_SIZE;
fused_mul_silu_f32<<<num_blocks, CUDA_SILU_BLOCK_SIZE, 0, stream>>>(x, y, dst, k);
}

static void fused_mul_relu_f32_cuda(const float * x, const float * y, float * dst, const int k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_RELU_BLOCK_SIZE - 1) / CUDA_RELU_BLOCK_SIZE;
fused_mul_relu_f32<<<num_blocks, CUDA_SILU_BLOCK_SIZE, 0, stream>>>(x, y, dst, k);
}

static void fused_mul_gelu_f32_cuda(const float * x, const float * y, float * dst, const int k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_GELU_BLOCK_SIZE - 1) / CUDA_GELU_BLOCK_SIZE;
fused_mul_gelu_f32<<<num_blocks, CUDA_SILU_BLOCK_SIZE, 0, stream>>>(x, y, dst, k);
}

static void tanh_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_TANH_BLOCK_SIZE - 1) / CUDA_TANH_BLOCK_SIZE;
tanh_f32<<<num_blocks, CUDA_TANH_BLOCK_SIZE, 0, stream>>>(x, dst, k);
Expand Down Expand Up @@ -216,6 +261,28 @@ void ggml_cuda_op_swiglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
swiglu_f32_cuda(src0_d, dst_d, ggml_nelements(dst), dst->ne[0], src0->nb[1]/sizeof(float), stream);
}

void ggml_cuda_op_fused_mul_unary(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(ggml_are_same_shape(src0, dst));
GGML_ASSERT(ggml_are_same_shape(src0, src1));

cudaStream_t stream = ctx.stream();
ggml_unary_op op = (ggml_unary_op)dst->op_params[0];

const float * src0_d = (const float *)src0->data;
const float * src1_d = (const float *)src1->data;
float * dst_d = (float *)dst->data;

switch (op) {
case GGML_UNARY_OP_SILU: fused_mul_silu_f32_cuda(src0_d, src1_d, dst_d, ggml_nelements(dst), stream); break;
case GGML_UNARY_OP_RELU: fused_mul_relu_f32_cuda(src0_d, src1_d, dst_d, ggml_nelements(dst), stream); break;
case GGML_UNARY_OP_GELU: fused_mul_gelu_f32_cuda(src0_d, src1_d, dst_d, ggml_nelements(dst), stream); break;
default: GGML_ASSERT(false);
}
}

void ggml_cuda_op_gelu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *)src0->data;
Expand Down
2 changes: 2 additions & 0 deletions ggml/src/ggml-cuda/unary.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,5 @@ void ggml_cuda_op_sqr(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_sqrt(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

void ggml_cuda_op_swiglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

void ggml_cuda_op_fused_mul_unary(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
32 changes: 32 additions & 0 deletions ggml/src/ggml-metal.m
Original file line number Diff line number Diff line change
Expand Up @@ -56,13 +56,18 @@
GGML_METAL_KERNEL_TYPE_CLAMP,
GGML_METAL_KERNEL_TYPE_TANH,
GGML_METAL_KERNEL_TYPE_RELU,
GGML_METAL_KERNEL_TYPE_MUL_RELU,
GGML_METAL_KERNEL_TYPE_SIGMOID,
GGML_METAL_KERNEL_TYPE_GELU,
GGML_METAL_KERNEL_TYPE_GELU_4,
GGML_METAL_KERNEL_TYPE_MUL_GELU,
GGML_METAL_KERNEL_TYPE_MUL_GELU_4,
GGML_METAL_KERNEL_TYPE_GELU_QUICK,
GGML_METAL_KERNEL_TYPE_GELU_QUICK_4,
GGML_METAL_KERNEL_TYPE_SILU,
GGML_METAL_KERNEL_TYPE_SILU_4,
GGML_METAL_KERNEL_TYPE_MUL_SILU,
GGML_METAL_KERNEL_TYPE_MUL_SILU_4,
GGML_METAL_KERNEL_TYPE_SWIGLU,
GGML_METAL_KERNEL_TYPE_SWIGLU_4,
GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16,
Expand Down Expand Up @@ -584,13 +589,18 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CLAMP, clamp, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TANH, tanh, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU, relu, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_RELU, mul_relu, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIGMOID, sigmoid, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU, gelu, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_4, gelu_4, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_GELU, mul_gelu, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_GELU_4, mul_gelu_4, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK, gelu_quick, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, gelu_quick_4, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU_4, silu_4, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_SILU, mul_silu, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_SILU_4, mul_silu_4, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SWIGLU, swiglu, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SWIGLU_4, swiglu_4, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16, soft_max_f16, ctx->support_simdgroup_reduction);
Expand Down Expand Up @@ -921,6 +931,8 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx
case GGML_OP_SQR:
case GGML_OP_SUM_ROWS:
return true;
case GGML_OP_FUSED_MUL_UNARY:
return ggml_is_contiguous(op->src[0]);
case GGML_OP_SOFTCAP:
case GGML_OP_SOFT_CAP_MAX:
return true; //ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op);
Expand Down Expand Up @@ -1648,6 +1660,26 @@ static enum ggml_status ggml_metal_graph_compute(
GGML_ABORT("fatal error");
}
} break;
case GGML_OP_FUSED_MUL_UNARY:
{
int64_t n = ggml_nelements(dst);
enum ggml_unary_op op = (enum ggml_unary_op)dst->op_params[0];
id<MTLComputePipelineState> pipeline = nil;
if (n % 4 == 0 && op != GGML_UNARY_OP_RELU) {
pipeline = op == GGML_UNARY_OP_GELU ? ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_GELU_4].pipeline
: ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_SILU_4].pipeline;
n /= 4;
} else {
pipeline = op == GGML_UNARY_OP_GELU ? ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_GELU].pipeline
: op == GGML_UNARY_OP_SILU ? ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_SILU].pipeline
: ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_RELU].pipeline;
}
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
case GGML_OP_SQR:
{
GGML_ASSERT(ggml_is_contiguous(src0));
Expand Down
50 changes: 50 additions & 0 deletions ggml/src/ggml-metal.metal
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,14 @@ kernel void kernel_relu(
dst[tpig] = max(0.0f, src0[tpig]);
}

kernel void kernel_mul_relu(
device const float * src0,
device const float * src1,
device float * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = max(0.0f, src0[tpig]) * src1[tpig];
}

kernel void kernel_sigmoid(
device const float * src0,
device float * dst,
Expand Down Expand Up @@ -364,6 +372,30 @@ kernel void kernel_gelu_4(
dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
}

kernel void kernel_mul_gelu(
device const float * src0,
device const float * src1,
device float * dst,
uint tpig[[thread_position_in_grid]]) {
device const float & x = src0[tpig];

dst[tpig] = 0.5f*x*src1[tpig]*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
}

kernel void kernel_mul_gelu_4(
device const float4 * src0,
device const float4 * src1,
device float4 * dst,
uint tpig[[thread_position_in_grid]]) {
device const float4 & x = src0[tpig];

// BEWARE !!!
// Simply using "tanh" instead of "precise::tanh" will sometimes results in NaNs!
// This was observed with Falcon 7B and 40B models
//
dst[tpig] = 0.5f*x*src1[tpig]*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
}

kernel void kernel_gelu_quick(
device const float * src0,
device float * dst,
Expand Down Expand Up @@ -398,6 +430,24 @@ kernel void kernel_silu_4(
dst[tpig] = x / (1.0f + exp(-x));
}

kernel void kernel_mul_silu(
device const float * src0,
device const float * src1,
device float * dst,
uint tpig[[thread_position_in_grid]]) {
device const float & x = src0[tpig];
dst[tpig] = x * src1[tpig] / (1.0f + exp(-x));
}

kernel void kernel_mul_silu_4(
device const float4 * src0,
device const float4 * src1,
device float4 * dst,
uint tpig[[thread_position_in_grid]]) {
device const float4 & x = src0[tpig];
dst[tpig] = x * src1[tpig] / (1.0f + exp(-x));
}

kernel void kernel_swiglu(
device const float * src0,
device float * dst,
Expand Down
Loading

0 comments on commit 50b5e90

Please sign in to comment.