From 667d70d1704dfa6977505f5d01d4638669b90dce Mon Sep 17 00:00:00 2001 From: PAB Date: Thu, 28 Nov 2024 09:25:06 +0100 Subject: [PATCH] metal : add `GGML_OP_CONV_TRANSPOSE_1D` kernels (ggml/1026) * wip * wip implementation f32 * kernel conv transpose 1d f32 working * initial commit --- ggml/src/ggml-metal/ggml-metal.m | 48 ++++++++++++++++++ ggml/src/ggml-metal/ggml-metal.metal | 73 ++++++++++++++++++++++++++++ 2 files changed, 121 insertions(+) diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index c247b50c9e690..d374b65a4f080 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -306,6 +306,8 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte GGML_METAL_KERNEL_TYPE_IM2COL_F32, GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16, GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32, + GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F32_F32, + GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F16_F32, GGML_METAL_KERNEL_TYPE_UPSCALE_F32, GGML_METAL_KERNEL_TYPE_PAD_F32, GGML_METAL_KERNEL_TYPE_ARANGE_F32, @@ -870,6 +872,8 @@ @implementation GGMLMetalClass GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16, im2col_ext_f16, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32, im2col_ext_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F32_F32, conv_transpose_1d_f32_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F16_F32, conv_transpose_1d_f16_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32, timestep_embedding_f32, true); @@ -1069,6 +1073,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex case GGML_OP_REPEAT: case GGML_OP_SCALE: case GGML_OP_CLAMP: + case GGML_OP_CONV_TRANSPOSE_1D: return true; case GGML_OP_SQR: case GGML_OP_SQRT: @@ -3138,6 +3143,49 @@ static void ggml_metal_encode_node( [encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)]; } } break; + case GGML_OP_CONV_TRANSPOSE_1D: + { + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(ggml_is_contiguous(src1)); + GGML_ASSERT(src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_F32); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + const int32_t s0 = ((const int32_t *)(dst->op_params))[0]; + + const int32_t IC = src1->ne[1]; + const int32_t IL = src1->ne[0]; + + const int32_t K = src0->ne[0]; + + const int32_t OL = dst->ne[0]; + const int32_t OC = dst->ne[1]; + + id pipeline; + + switch (src0->type) { + case GGML_TYPE_F32: { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F32_F32].pipeline; + } break; + case GGML_TYPE_F16: { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F16_F32].pipeline; + } break; + default: GGML_ABORT("fatal error"); + }; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + [encoder setBytes:&IC length:sizeof( int32_t) atIndex:3]; + [encoder setBytes:&IL length:sizeof( int32_t) atIndex:4]; + [encoder setBytes:&K length:sizeof( int32_t) atIndex:5]; + [encoder setBytes:&s0 length:sizeof( int32_t) atIndex:6]; + [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:7]; + [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:8]; + + [encoder dispatchThreadgroups:MTLSizeMake(OL, OC, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; case GGML_OP_UPSCALE: { GGML_ASSERT(src0->type == GGML_TYPE_F32); diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 7567f326200fc..8cb9a3414974c 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -2671,6 +2671,79 @@ kernel void kernel_im2col_ext( template [[host_name("kernel_im2col_ext_f32")]] kernel im2col_ext_t kernel_im2col_ext; template [[host_name("kernel_im2col_ext_f16")]] kernel im2col_ext_t kernel_im2col_ext; +typedef void (conv_transpose_1d_t)( + device const float * src0, + device const float * src1, + device char * dst, + constant int32_t & IC, + constant int32_t & IL, + constant int32_t & K, + constant int32_t & s0, + constant uint64_t & nb0, + constant uint64_t & nb1, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tgpg[[threadgroups_per_grid]]); + +template +kernel void kernel_conv_transpose_1d( + device const T * src0, + device const float * src1, + device char * dst, + constant int32_t & IC, + constant int32_t & IL, + constant int32_t & K, + constant int32_t & s0, + constant uint64_t & nb0, + constant uint64_t & nb1, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tgpg[[threadgroups_per_grid]]) { + + float v = 0.0f; + + for (int64_t c = 0; c < IC; c++) { + const int32_t kernel_offset = c * tgpg[1] * K + K * tgpig[1]; + const int32_t input_offset = c * IL; + + for (int64_t i = 0; i < IL; i++) { + if (tgpig[0] >= i * s0 && tgpig[0] < i * s0 + K) { + v += src0[kernel_offset + tgpig[0] - i * s0] * src1[input_offset + i]; + } + } + } + + device float * dst_ptr = (device float *) (dst + tgpig[0] * nb0 + tgpig[1] * nb1); + + dst_ptr[0] = v; +} + +template [[host_name("kernel_conv_transpose_1d_f32_f32")]] +kernel void kernel_conv_transpose_1d( + device const float * src0, + device const float * src1, + device char * dst, + constant int32_t & IC, + constant int32_t & IL, + constant int32_t & K, + constant int32_t & s0, + constant uint64_t & nb0, + constant uint64_t & nb1, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tgpg[[threadgroups_per_grid]]); + +template [[host_name("kernel_conv_transpose_1d_f16_f32")]] +kernel void kernel_conv_transpose_1d( + device const half * src0, + device const float * src1, + device char * dst, + constant int32_t & IC, + constant int32_t & IL, + constant int32_t & K, + constant int32_t & s0, + constant uint64_t & nb0, + constant uint64_t & nb1, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tgpg[[threadgroups_per_grid]]); + kernel void kernel_upscale_f32( device const char * src0, device char * dst,