feat: add GGML_UNARY_OP_ARGMAX Metal kernel #1019

Merged
merged 11 commits on Dec 2, 2024
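For context, a minimal sketch (not part of this diff) of how the new kernel is reached through the public ggml API — ggml_argmax() reduces each row of an F32 tensor to the index of its largest element, yielding a 1-D I32 tensor:

#include "ggml.h"

// build a graph node that the Metal backend lowers to kernel_argmax
struct ggml_init_params params = {
    /*.mem_size   =*/ 16*1024*1024,
    /*.mem_buffer =*/ NULL,
    /*.no_alloc   =*/ false,
};
struct ggml_context * ctx = ggml_init(params);

struct ggml_tensor * x   = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1024, 12); // 12 rows of 1024 floats
struct ggml_tensor * idx = ggml_argmax(ctx, x);                              // I32 tensor holding 12 row indices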
src/ggml-metal/ggml-metal.m (28 additions, 0 deletions)
@@ -348,6 +348,7 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
GGML_METAL_KERNEL_TYPE_SUM_ROWS,
GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,
GGML_METAL_KERNEL_TYPE_ARGMAX,

GGML_METAL_KERNEL_TYPE_COUNT
};
@@ -869,6 +870,7 @@ @implementation GGMLMetalClass
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIN, sin, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX, argmax, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, pool_2d_avg_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, pool_2d_max_f32, true);
}
@@ -996,6 +998,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
case GGML_OP_RMS_NORM:
case GGML_OP_GROUP_NORM:
return has_simdgroup_reduction;
case GGML_OP_ARGMAX:
case GGML_OP_NORM:
case GGML_OP_ROPE:
return true;
@@ -3469,6 +3472,31 @@ static void ggml_metal_encode_node(

[encoder dispatchThreadgroups:MTLSizeMake(n_tg, 1, 1) threadsPerThreadgroup:MTLSizeMake(n_threads, 1, 1)];
} break;
case GGML_OP_ARGMAX:
{
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(ggml_is_contiguous_1(src0));
GGML_ASSERT(nb00 == ggml_type_size(src0->type));

const int64_t nrows = ggml_nrows(src0);

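// heuristic: start at one SIMD group per row and double the threadgroup size
// while the row is not yet covered and the grid-wide thread count stays below 256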
int nth = 32; // SIMD width
while (nth < ne00 && nth*ne01*ne02*ne03 < 256) {
nth *= 2;
}

id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGMAX].pipeline;

[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
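// per-SIMD-group scratch for the second reduction stage (32 slots each)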
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
[encoder setThreadgroupMemoryLength:32*sizeof(int32_t) atIndex:1];

[encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
} break;
default:
{
GGML_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(dst->op));
src/ggml-metal/ggml-metal.metal (57 additions, 0 deletions)
@@ -1344,6 +1344,63 @@ kernel void kernel_ssm_scan_f32(
}
}

kernel void kernel_argmax(
device const void * x,
device int32_t * dst,
constant int64_t & ncols,
constant uint64_t & nb01,
threadgroup float * shared_maxval [[threadgroup(0)]],
threadgroup int32_t * shared_argmax [[threadgroup(1)]],
uint tgpig[[threadgroup_position_in_grid]],
uint tpitg[[thread_position_in_threadgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]],
uint tiisg[[thread_index_in_simdgroup]],
uint ntg[[threads_per_threadgroup]]) {
device const float * x_row = (device const float *) ((device const char *) x + tgpig * nb01);

float lmax = -INFINITY;
int32_t larg = -1;

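// each thread scans the row with stride ntg, tracking its local best value and index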
for (int i00 = tpitg; i00 < ncols; i00 += ntg) {
if (x_row[i00] > lmax) {
lmax = x_row[i00];
larg = i00;
}
}

// find the argmax within the SIMD group
float max_val = simd_max(lmax);
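// lanes that do not hold the max contribute -1, so simd_max() yields the
// winning lane's index (the largest index on ties)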
int32_t arg_val = simd_max(select(-1, larg, lmax == max_val));

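// more than one SIMD group in flight: stage per-group results in
// threadgroup memory, then run a second SIMD reduction over them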
if (ntg > N_SIMDWIDTH) {
if (sgitg == 0) {
shared_maxval[tiisg] = -INFINITY;
shared_argmax[tiisg] = -1;
}

threadgroup_barrier(mem_flags::mem_threadgroup);

if (tiisg == 0) {
shared_maxval[sgitg] = max_val;
shared_argmax[sgitg] = arg_val;
}

threadgroup_barrier(mem_flags::mem_threadgroup);

max_val = shared_maxval[tiisg];
arg_val = shared_argmax[tiisg];

float max_val_reduced = simd_max(max_val);
int32_t arg_val_reduced = simd_max(select(-1, arg_val, max_val == max_val_reduced));

dst[tgpig] = arg_val_reduced;

return;
}

dst[tgpig] = arg_val;
}

kernel void kernel_norm(
device const void * src0,
device float * dst,
tests/test-backend-ops.cpp (4 additions, 1 deletion)
@@ -3440,7 +3440,10 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {3,1,2,1}, 1, 0, 1));
test_cases.emplace_back(new test_conv_transpose_1d({2,1,1,1}, {3,1,1,1}, 1, 0, 1));

test_cases.emplace_back(new test_argmax());
test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, { 100, 10, 1, 1}));
test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {1024, 12, 1, 1}));
test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {5438, 3, 1, 1}));
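// these shapes exercise both the single-SIMD-group path and the
// threadgroup-memory reduction path of the Metal kernel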

test_cases.emplace_back(new test_count_equal());

for (int ne3 : {1, 3}) { // CUDA backward pass only supports ne3 == 1