
Fused unary(x)*y #70
Credit: Iwan Kawrakow @ikawrakow
Nexesenex committed Dec 15, 2024
1 parent 78365a5 commit 4ac7ea8
Showing 9 changed files with 323 additions and 48 deletions.
13 changes: 13 additions & 0 deletions ggml/include/ggml.h
@@ -509,6 +509,7 @@ extern "C" {
GGML_OP_RMS_NORM_BACK,
GGML_OP_GROUP_NORM,
GGML_OP_FUSED_RMS_NORM,
GGML_OP_FUSED_MUL_UNARY,

GGML_OP_MUL_MAT,
GGML_OP_MUL_MAT_ID,
@@ -908,6 +909,18 @@ extern "C" {
struct ggml_tensor * a,
struct ggml_tensor * b);

GGML_API struct ggml_tensor * ggml_fused_mul_unary(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
enum ggml_unary_op op);

GGML_API struct ggml_tensor * ggml_fused_mul_unary_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
enum ggml_unary_op op);

GGML_API struct ggml_tensor * ggml_div(
struct ggml_context * ctx,
struct ggml_tensor * a,
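For context (not part of this diff): a minimal spot-check of the new API against the existing unfused silu+mul path. It is a sketch assuming a CPU-only build and the single-context compute helper from ggml-cpu.h; the tensor sizes and thread count are arbitrary.

// Hypothetical example, not from this commit: compare the fused op
// against the existing ggml_silu + ggml_mul pair on random data.
#include "ggml.h"
#include "ggml-cpu.h"
#include <math.h>
#include <stdio.h>
#include <stdlib.h>

int main(void) {
    struct ggml_init_params ip = {
        /*.mem_size   =*/ 64*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(ip);

    const int64_t n = 4096;
    struct ggml_tensor * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n);
    struct ggml_tensor * y = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n);
    for (int64_t i = 0; i < n; ++i) {
        ((float *) x->data)[i] = (float) rand()/RAND_MAX - 0.5f;
        ((float *) y->data)[i] = (float) rand()/RAND_MAX - 0.5f;
    }

    // unfused reference: silu(x) * y (two nodes, one intermediate tensor)
    struct ggml_tensor * ref   = ggml_mul(ctx, ggml_silu(ctx, x), y);
    // new fused op: one node, same result
    struct ggml_tensor * fused = ggml_fused_mul_unary(ctx, x, y, GGML_UNARY_OP_SILU);

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, ref);
    ggml_build_forward_expand(gf, fused);
    ggml_graph_compute_with_ctx(ctx, gf, /*n_threads =*/ 4);

    float max_err = 0.0f;
    for (int64_t i = 0; i < n; ++i) {
        float d = fabsf(((float *) ref->data)[i] - ((float *) fused->data)[i]);
        if (d > max_err) max_err = d;
    }
    printf("max abs diff (fused vs unfused): %g\n", max_err);

    ggml_free(ctx);
    return 0;
}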
1 change: 1 addition & 0 deletions ggml/src/ggml-alloc.c
@@ -46,6 +46,7 @@ static bool ggml_op_can_inplace(enum ggml_op op) {
case GGML_OP_ADD1:
case GGML_OP_SUB:
case GGML_OP_MUL:
case GGML_OP_FUSED_MUL_UNARY:
case GGML_OP_DIV:
case GGML_OP_SQR:
case GGML_OP_SQRT:
128 changes: 128 additions & 0 deletions ggml/src/ggml-cpu/ggml-cpu.c
@@ -2457,6 +2457,30 @@ static void ggml_vec_silu_f32(const int n, float * y, const float * x) {
}
}

static void ggml_vec_mul_silu_f32(const int n, float * z, const float * x, const float * y) {
int i = 0;
#if defined(__AVX512F__) && defined(__AVX512DQ__)
for (; i + 15 < n; i += 16) {
_mm512_storeu_ps(z + i, _mm512_mul_ps(ggml_v_silu(_mm512_loadu_ps(x + i)), _mm512_loadu_ps(y + i)));
}
#elif defined(__AVX2__) && defined(__FMA__)
for (; i + 7 < n; i += 8) {
_mm256_storeu_ps(z + i, _mm256_mul_ps(ggml_v_silu(_mm256_loadu_ps(x + i)), _mm256_loadu_ps(y + i)));
}
#elif defined(__SSE2__)
for (; i + 3 < n; i += 4) {
_mm_storeu_ps(z + i, _mm_mul_ps(ggml_v_silu(_mm_loadu_ps(x + i)), _mm_loadu_ps(y + i)));
}
#elif defined(__ARM_NEON) && defined(__aarch64__)
for (; i + 3 < n; i += 4) {
vst1q_f32(z + i, vmulq_f32(ggml_v_silu(vld1q_f32(x + i)), vld1q_f32(y + i)));
}
#endif
for (; i < n; ++i) {
z[i] = ggml_silu_f32(x[i]) * y[i];
}
}

static void ggml_vec_tanh_f32(const int n, float * y, const float * x) {
int i = 0;
#if defined(__AVX512F__) && defined(__AVX512DQ__)
@@ -2713,6 +2737,48 @@ inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) {
#endif
}

inline static void ggml_vec_mul_gelu_f32(const int n, float * z, const float * x, const float * y) {
int i = 0;
#if defined(__AVX512F__) && defined(__AVX512DQ__)
__m512 c1 = _mm512_set1_ps(GELU_COEF_A);
__m512 c2 = _mm512_set1_ps(2.f*SQRT_2_OVER_PI);
for (; i + 15 < n; i += 16) {
_mm512_storeu_ps(z + i, _mm512_mul_ps(ggml_v_gelu(_mm512_loadu_ps(x + i), c1, c2), _mm512_loadu_ps(y + i)));
}
#elif defined __AVX2__ && defined __FMA__
__m256 c1 = _mm256_set1_ps(GELU_COEF_A);
__m256 c2 = _mm256_set1_ps(2.f*SQRT_2_OVER_PI);
for (; i + 7 < n; i += 8) {
_mm256_storeu_ps(z + i, _mm256_mul_ps(ggml_v_gelu(_mm256_loadu_ps(x + i), c1, c2), _mm256_loadu_ps(y + i)));
}
#endif
#ifdef GGML_GELU_FP16
uint16_t t;
for (; i < n; ++i) {
if (x[i] <= -10.0f) {
z[i] = 0.0f;
} else if (x[i] >= 10.0f) {
z[i] = x[i]*y[i];
} else {
ggml_fp16_t fp16 = GGML_FP32_TO_FP16(x[i]);
memcpy(&t, &fp16, sizeof(uint16_t));
z[i] = GGML_FP16_TO_FP32(ggml_table_gelu_f16[t])*y[i];
}
}
#else
#if defined __ARM_NEON
float32x4_t c1 = vdupq_n_f32(GELU_COEF_A);
float32x4_t c2 = vdupq_n_f32(2.f*SQRT_2_OVER_PI);
for (; i + 3 < n; i += 4) {
vst1q_f32(z + i, vmulq_f32(ggml_v_gelu(vld1q_f32(x + i), c1, c2), vld1q_f32(y + i)));
}
#endif
for (; i < n; ++i) {
z[i] = ggml_gelu_f32(x[i])*y[i];
}
#endif
}

static ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x, float max) {
int i = 0;
ggml_float sum = 0;
@@ -7380,6 +7446,67 @@ static void ggml_compute_forward_silu(
}
}
}

// ggml_compute_forward_fused_mul_unary

static void ggml_compute_forward_fused_mul_unary_f32(
const struct ggml_compute_params * params,
struct ggml_tensor * dst) {

const struct ggml_tensor * src0 = dst->src[0];
const struct ggml_tensor * src1 = dst->src[1];
enum ggml_unary_op op = (enum ggml_unary_op)dst->op_params[0];

assert(ggml_is_contiguous_1(src0));
assert(ggml_are_same_shape(src0, dst));
assert(ggml_are_same_shape(src0, src1));
assert(op == GGML_UNARY_OP_GELU || op == GGML_UNARY_OP_RELU || op == GGML_UNARY_OP_SILU);

const int ith = params->ith;
const int nth = params->nth;

const int nc = dst->ne[0];
const int nr = ggml_nrows(src0);


// rows per thread
const int dr = (nr + nth - 1)/nth;

// row range for this thread
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);

for (int i1 = ir0; i1 < ir1; i1++) {
float * z = (float *) ((char *) dst->data + i1*( dst->nb[1]));
const float * x = (const float *) ((char *) src0->data + i1*(src0->nb[1]));
const float * y = (const float *) ((char *) src1->data + i1*(src1->nb[1]));
switch (op) {
case GGML_UNARY_OP_GELU: ggml_vec_gelu_f32(nc, z, x); ggml_vec_mul_f32(nc, z, z, y); break;
case GGML_UNARY_OP_RELU: ggml_vec_relu_f32(nc, z, x); ggml_vec_mul_f32(nc, z, z, y); break;
case GGML_UNARY_OP_SILU: ggml_vec_mul_silu_f32(nc, z, x, y); break;
default: GGML_ABORT("fatal error");
}
}
}

static void ggml_compute_forward_fused_mul_unary(
const struct ggml_compute_params * params,
struct ggml_tensor * dst) {

const struct ggml_tensor * src0 = dst->src[0];

switch (src0->type) {
case GGML_TYPE_F32:
{
ggml_compute_forward_fused_mul_unary_f32(params, dst);
} break;
default:
{
GGML_ABORT("fatal error");
}
}
}

// ggml_compute_forward_leaky_relu

static void ggml_compute_forward_leaky_relu_f32(
@@ -13788,6 +13915,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
break;
case GGML_OP_SILU_BACK:
case GGML_OP_MUL:
case GGML_OP_FUSED_MUL_UNARY:
case GGML_OP_DIV:
case GGML_OP_NORM:
case GGML_OP_RMS_NORM:
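The graph-building side of this op is not shown above (it presumably lives in ggml/src/ggml.c, one of the changed files not rendered here). The forward passes read the unary op back out of dst->op_params[0] (see the cast at the top of ggml_compute_forward_fused_mul_unary_f32), so the builder has to store it there. A presumed sketch, following the usual ggml.c pattern for binary ops; the _impl helper, the dup/view choice, and the use of ggml_set_op_params_i32 are assumptions, not taken from this diff:

// Presumed shape of the builder (not shown in this diff): records the unary
// op in op_params[0], which the CPU and CUDA forward paths read back.
static struct ggml_tensor * ggml_fused_mul_unary_impl(
        struct ggml_context * ctx,
        struct ggml_tensor  * a,
        struct ggml_tensor  * b,
        enum ggml_unary_op    op,
        bool                  inplace) {
    GGML_ASSERT(ggml_are_same_shape(a, b));
    GGML_ASSERT(op == GGML_UNARY_OP_GELU || op == GGML_UNARY_OP_RELU || op == GGML_UNARY_OP_SILU);

    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);

    ggml_set_op_params_i32(result, 0, (int32_t) op);

    result->op     = GGML_OP_FUSED_MUL_UNARY;
    result->src[0] = a;
    result->src[1] = b;

    return result;
}

struct ggml_tensor * ggml_fused_mul_unary(
        struct ggml_context * ctx,
        struct ggml_tensor  * a,
        struct ggml_tensor  * b,
        enum ggml_unary_op    op) {
    return ggml_fused_mul_unary_impl(ctx, a, b, op, false);
}

struct ggml_tensor * ggml_fused_mul_unary_inplace(
        struct ggml_context * ctx,
        struct ggml_tensor  * a,
        struct ggml_tensor  * b,
        enum ggml_unary_op    op) {
    return ggml_fused_mul_unary_impl(ctx, a, b, op, true);
}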
4 changes: 4 additions & 0 deletions ggml/src/ggml-cuda/ggml-cuda.cu
@@ -2257,6 +2257,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_MUL:
ggml_cuda_op_mul(ctx, dst);
break;
case GGML_OP_FUSED_MUL_UNARY:
ggml_cuda_op_fused_mul_unary(ctx, dst);
break;
case GGML_OP_DIV:
ggml_cuda_op_div(ctx, dst);
break;
@@ -3057,6 +3060,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
return false;
}
break;
case GGML_OP_FUSED_MUL_UNARY: return ggml_is_contiguous(op->src[0]);
case GGML_OP_MUL_MAT:
case GGML_OP_MUL_MAT_ID:
{
67 changes: 67 additions & 0 deletions ggml/src/ggml-cuda/unary.cu
@@ -51,6 +51,36 @@ static __global__ void silu_f32(const float * x, float * dst, const int k) {
dst[i] = x[i] / (1.0f + expf(-x[i]));
}

static __global__ void fused_mul_silu_f32(const float * x, const float * y, float * dst, const int k) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;

if (i >= k) {
return;
}
dst[i] = x[i] * y[i] / (1.0f + expf(-x[i]));
}

static __global__ void fused_mul_relu_f32(const float * x, const float * y, float * dst, const int k) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;

if (i >= k) {
return;
}
dst[i] = fmaxf(x[i], 0) * y[i];
}

static __global__ void fused_mul_gelu_f32(const float * x, const float * y, float * dst, const int k) {
constexpr float GELU_COEF_A = 0.044715f;
constexpr float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
const int i = blockDim.x*blockIdx.x + threadIdx.x;

if (i >= k) {
return;
}
float xi = x[i];
dst[i] = 0.5f*xi*y[i]*(1.0f + tanhf(SQRT_2_OVER_PI*xi*(1.0f + GELU_COEF_A*xi*xi)));
}

static __global__ void tanh_f32(const float * x, float * dst, int k) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;
if (i >= k) {
@@ -173,6 +203,21 @@ static void silu_f32_cuda(const float * x, float * dst, const int k, cudaStream_
silu_f32<<<num_blocks, CUDA_SILU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}

static void fused_mul_silu_f32_cuda(const float * x, const float * y, float * dst, const int k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_SILU_BLOCK_SIZE - 1) / CUDA_SILU_BLOCK_SIZE;
fused_mul_silu_f32<<<num_blocks, CUDA_SILU_BLOCK_SIZE, 0, stream>>>(x, y, dst, k);
}

static void fused_mul_relu_f32_cuda(const float * x, const float * y, float * dst, const int k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_RELU_BLOCK_SIZE - 1) / CUDA_RELU_BLOCK_SIZE;
fused_mul_relu_f32<<<num_blocks, CUDA_RELU_BLOCK_SIZE, 0, stream>>>(x, y, dst, k);
}

static void fused_mul_gelu_f32_cuda(const float * x, const float * y, float * dst, const int k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_GELU_BLOCK_SIZE - 1) / CUDA_GELU_BLOCK_SIZE;
fused_mul_gelu_f32<<<num_blocks, CUDA_GELU_BLOCK_SIZE, 0, stream>>>(x, y, dst, k);
}

static void tanh_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_TANH_BLOCK_SIZE - 1) / CUDA_TANH_BLOCK_SIZE;
tanh_f32<<<num_blocks, CUDA_TANH_BLOCK_SIZE, 0, stream>>>(x, dst, k);
@@ -284,6 +329,28 @@ void ggml_cuda_op_silu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
silu_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
}

void ggml_cuda_op_fused_mul_unary(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(ggml_are_same_shape(src0, dst));
GGML_ASSERT(ggml_are_same_shape(src0, src1));

cudaStream_t stream = ctx.stream();
ggml_unary_op op = (ggml_unary_op)dst->op_params[0];

const float * src0_d = (const float *)src0->data;
const float * src1_d = (const float *)src1->data;
float * dst_d = (float *)dst->data;

switch (op) {
case GGML_UNARY_OP_SILU: fused_mul_silu_f32_cuda(src0_d, src1_d, dst_d, ggml_nelements(dst), stream); break;
case GGML_UNARY_OP_RELU: fused_mul_relu_f32_cuda(src0_d, src1_d, dst_d, ggml_nelements(dst), stream); break;
case GGML_UNARY_OP_GELU: fused_mul_gelu_f32_cuda(src0_d, src1_d, dst_d, ggml_nelements(dst), stream); break;
default: GGML_ASSERT(false);
}
}

void ggml_cuda_op_gelu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *)src0->data;
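Not from this commit: a scalar reference for what fused_mul_gelu_f32 above computes per element, using the same tanh-based GELU approximation and constants as the kernel. A small helper like this can be handy for host-side spot checks of the kernel output; it only assumes <math.h> for tanhf.

#include <math.h>

// Per-element reference for the fused GELU kernel above:
// dst[i] = gelu(x[i]) * y[i], with the tanh approximation of GELU.
static inline float fused_mul_gelu_ref(float x, float y) {
    const float GELU_COEF_A    = 0.044715f;
    const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
    return 0.5f*x*y*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
}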
2 changes: 2 additions & 0 deletions ggml/src/ggml-cuda/unary.cuh
@@ -43,6 +43,8 @@ void ggml_cuda_op_sqr(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

void ggml_cuda_op_sqrt(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

void ggml_cuda_op_fused_mul_unary(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

void ggml_cuda_op_sin(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

void ggml_cuda_op_cos(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
3 changes: 1 addition & 2 deletions ggml/src/ggml-metal/ggml-metal.m
@@ -1017,12 +1017,11 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
case GGML_OP_SOFTCAP:
return true; //ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op);
case GGML_OP_SOFT_MAX:
case GGML_OP_RMS_NORM:
case GGML_OP_FUSED_RMS_NORM:
case GGML_OP_GROUP_NORM:
return has_simdgroup_reduction;
case GGML_OP_RMS_NORM:
return has_simdgroup_reduction && (op->ne[0] % 4 == 0);
case GGML_OP_FUSED_RMS_NORM:
case GGML_OP_NORM:
case GGML_OP_ROPE:
return true;