feat: add GGML_UNARY_OP_ARGMAX Metal kernel #1019

Merged
11 commits, merged on Dec 2, 2024
21 changes: 21 additions & 0 deletions src/ggml-metal/ggml-metal.m
@@ -348,6 +348,7 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
    GGML_METAL_KERNEL_TYPE_SUM_ROWS,
    GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
    GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,
    GGML_METAL_KERNEL_TYPE_ARGMAX,

    GGML_METAL_KERNEL_TYPE_COUNT
};
@@ -869,6 +870,7 @@ @implementation GGMLMetalClass
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIN, sin, true);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX, argmax, true);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, pool_2d_avg_f32, true);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, pool_2d_max_f32, true);
    }
@@ -996,6 +998,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
        case GGML_OP_RMS_NORM:
        case GGML_OP_GROUP_NORM:
            return has_simdgroup_reduction;
        case GGML_OP_ARGMAX:
        case GGML_OP_NORM:
        case GGML_OP_ROPE:
            return true;
@@ -3469,6 +3472,24 @@ static void ggml_metal_encode_node(

                [encoder dispatchThreadgroups:MTLSizeMake(n_tg, 1, 1) threadsPerThreadgroup:MTLSizeMake(n_threads, 1, 1)];
            } break;
        case GGML_OP_ARGMAX:
            {
                GGML_ASSERT(src0->type == GGML_TYPE_F32);
                GGML_ASSERT(ggml_is_contiguous_1(src0));
                GGML_ASSERT(nb00 == ggml_type_size(src0->type));

                const int64_t nrows = ggml_nrows(src0);

                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGMAX].pipeline;

                [encoder setComputePipelineState:pipeline];
                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
                [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
                [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];

                [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
            } break;
        default:
            {
                GGML_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(dst->op));
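For context (not part of the diff): from the public ggml API, ggml_argmax takes an F32 matrix and returns a 1-D I32 tensor with one index per row; on Apple GPUs the Metal backend routes the resulting GGML_OP_ARGMAX node through the encoder case above. Below is a minimal CPU-side sketch of the op's semantics. The graph helpers it uses (ggml_new_graph, ggml_build_forward_expand, ggml_graph_compute_with_ctx) are taken from the ggml headers of this era and should be read as assumptions, not as code from this PR.

#include "ggml.h"
#include <stdio.h>
#include <string.h>

int main(void) {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024, // 16 MiB scratch is ample for this toy graph
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    // 4 columns x 2 rows; argmax reduces along dim 0 (the row contents)
    struct ggml_tensor * x = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 4, 2);
    const float data[8] = {
        0.1f, 0.9f, 0.3f, 0.2f, // row 0 -> argmax = 1
        5.0f, 1.0f, 2.0f, 7.0f, // row 1 -> argmax = 3
    };
    memcpy(x->data, data, sizeof(data));

    struct ggml_tensor * am = ggml_argmax(ctx, x); // I32 result, one index per row

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, am);
    ggml_graph_compute_with_ctx(ctx, gf, /*n_threads =*/ 1); // assumed helper name

    const int32_t * idx = (const int32_t *) am->data;
    printf("argmax: %d %d\n", idx[0], idx[1]); // expected: 1 3

    ggml_free(ctx);
    return 0;
}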
17 changes: 17 additions & 0 deletions src/ggml-metal/ggml-metal.metal
@@ -1344,6 +1344,23 @@ kernel void kernel_ssm_scan_f32(
    }
}

kernel void kernel_argmax(
        device const void     * x,
        device       int32_t  * dst,
        constant     int64_t  & ncols,
        constant     uint64_t & nb01,
        uint tgpig[[threadgroup_position_in_grid]]) {
    device const float * x_row = (device const float *) ((device const char *) x + tgpig*nb01);

    // one thread scans the whole row, keeping the running maximum and its
    // index in registers instead of re-reading dst from device memory on
    // every iteration; the result is written out once at the end
    float best_val = x_row[0];
    int   best_idx = 0;

    for (int i = 1; i < ncols; i++) {
        if (x_row[i] > best_val) {
            best_val = x_row[i];
            best_idx = i;
        }
    }

    dst[tgpig] = best_idx;
}

kernel void kernel_norm(
        device const void  * src0,
        device       float * dst,
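Note on the dispatch: the host code above launches this kernel as nrows threadgroups of a single thread each (MTLSizeMake(nrows, 1, 1) threadgroups with MTLSizeMake(1, 1, 1) threads), so every row is scanned serially, and a matrix with a few wide rows leaves the GPU mostly idle. Below is a hedged sketch of a threadgroup-parallel variant; kernel_argmax_tg and its threadgroup-memory arguments are illustrative assumptions rather than code from this PR, and the host would additionally have to size the two threadgroup buffers to ntg elements each via setThreadgroupMemoryLength:atIndex:.

kernel void kernel_argmax_tg(
        device const void     * x,
        device       int32_t  * dst,
        constant     int64_t  & ncols,
        constant     uint64_t & nb01,
        threadgroup  float    * shared_val [[threadgroup(0)]],
        threadgroup  int32_t  * shared_idx [[threadgroup(1)]],
        uint tgpig[[threadgroup_position_in_grid]],
        uint tpitg[[thread_position_in_threadgroup]],
        uint   ntg[[threads_per_threadgroup]]) {
    device const float * x_row = (device const float *) ((device const char *) x + tgpig*nb01);

    // each thread scans a strided slice of the row and keeps its local best
    float   best_val = -INFINITY;
    int32_t best_idx = 0;

    for (int64_t i = tpitg; i < ncols; i += ntg) {
        if (x_row[i] > best_val) {
            best_val = x_row[i];
            best_idx = (int32_t) i;
        }
    }

    shared_val[tpitg] = best_val;
    shared_idx[tpitg] = best_idx;

    threadgroup_barrier(mem_flags::mem_threadgroup);

    // thread 0 folds the per-thread candidates into the final answer
    if (tpitg == 0) {
        for (uint t = 1; t < ntg; t++) {
            if (shared_val[t] > best_val) {
                best_val = shared_val[t];
                best_idx = shared_idx[t];
            }
        }
        dst[tgpig] = best_idx;
    }
}

The grid shape stays MTLSizeMake(nrows, 1, 1); only threadsPerThreadgroup changes on the host side.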