feat: add GGML_UNARY_OP_ARGMAX Metal kernel (ggerganov#1019)
* implemented argmax kernel

* tpig -> tgpig

* change to strides

* contiguous assertions

* kernel working and tested

* argmax simd parallel implementation

* added 2 new tests for argmax in test-backend-ops

* cosmit

* added 3 test cases for perf eval

* add test_argmax in make_test_cases_perf

* Update test-backend-ops.cpp

Co-authored-by: Diego Devesa <[email protected]>

---------

Co-authored-by: Diego Devesa <[email protected]>
2 people authored and ypapadop-amd committed Dec 2, 2024
1 parent 103ec8b commit e5206db
Showing 3 changed files with 89 additions and 1 deletion.
28 changes: 28 additions & 0 deletions src/ggml-metal/ggml-metal.m
@@ -352,6 +352,7 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
    GGML_METAL_KERNEL_TYPE_SUM_ROWS,
    GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
    GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,
    GGML_METAL_KERNEL_TYPE_ARGMAX,

    GGML_METAL_KERNEL_TYPE_COUNT
};
@@ -876,6 +877,7 @@ @implementation GGMLMetalClass
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIN,             sin,             true);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS,             cos,             true);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS,        sum_rows,        true);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX,          argmax,          true);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, pool_2d_avg_f32, true);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, pool_2d_max_f32, true);
    }
@@ -1005,6 +1007,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
        case GGML_OP_RMS_NORM:
        case GGML_OP_GROUP_NORM:
            return has_simdgroup_reduction;
        case GGML_OP_ARGMAX:
        case GGML_OP_NORM:
        case GGML_OP_ROPE:
            return true;
@@ -3615,6 +3618,31 @@ static void ggml_metal_encode_node(

                [encoder dispatchThreadgroups:MTLSizeMake(n_tg, 1, 1) threadsPerThreadgroup:MTLSizeMake(n_threads, 1, 1)];
            } break;
        case GGML_OP_ARGMAX:
            {
                GGML_ASSERT(src0->type == GGML_TYPE_F32);
                GGML_ASSERT(ggml_is_contiguous_1(src0));
                GGML_ASSERT(nb00 == ggml_type_size(src0->type));

                const int64_t nrows = ggml_nrows(src0);

                int nth = 32; // SIMD width
                while (nth < ne00 && nth*ne01*ne02*ne03 < 256) {
                    nth *= 2;
                }

                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGMAX].pipeline;

                [encoder setComputePipelineState:pipeline];
                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
                [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
                [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
                [encoder setThreadgroupMemoryLength:32*sizeof(float)   atIndex:0];
                [encoder setThreadgroupMemoryLength:32*sizeof(int32_t) atIndex:1];

                [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
            } break;
        default:
            {
                GGML_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(dst->op));
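Note on the dispatch above: the encoder launches one threadgroup per row (nrows groups) and sizes each group with a small heuristic: start at the SIMD width of 32 and double while the row is longer than the current width and the total thread count across all rows stays below 256. In effect a single long row gets a wide threadgroup (up to 256 threads), while a tall matrix keeps 32 threads per row and relies on row-level parallelism instead. A standalone C++ sketch of that heuristic (the helper name and printed shapes are illustrative, not part of the commit):

    #include <cstdint>
    #include <cstdio>

    // hypothetical standalone copy of the threadgroup-sizing heuristic above
    static int argmax_threads_per_group(int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne03) {
        int nth = 32; // SIMD width
        // widen the group only while the row is longer than the current width
        // and the total thread count across all rows stays below 256
        while (nth < ne00 && nth*ne01*ne02*ne03 < 256) {
            nth *= 2;
        }
        return nth;
    }

    int main(void) {
        std::printf("%d\n", argmax_threads_per_group(1024, 12, 1, 1)); // 32  - many rows
        std::printf("%d\n", argmax_threads_per_group(5438,  3, 1, 1)); // 128 - few long rows
        std::printf("%d\n", argmax_threads_per_group(5438,  1, 1, 1)); // 256 - single row
        return 0;
    }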
57 changes: 57 additions & 0 deletions src/ggml-metal/ggml-metal.metal
@@ -1248,6 +1248,63 @@ kernel void kernel_ssm_scan_f32(
    }
}

kernel void kernel_argmax(
        device   const void * x,
        device      int32_t * dst,
        constant    int64_t & ncols,
        constant   uint64_t & nb01,
        threadgroup   float * shared_maxval [[threadgroup(0)]],
        threadgroup int32_t * shared_argmax [[threadgroup(1)]],
        uint  tgpig[[threadgroup_position_in_grid]],
        uint  tpitg[[thread_position_in_threadgroup]],
        uint  sgitg[[simdgroup_index_in_threadgroup]],
        uint  tiisg[[thread_index_in_simdgroup]],
        uint    ntg[[threads_per_threadgroup]]) {
    device const float * x_row = (device const float *) ((device const char *) x + tgpig * nb01);

    float lmax = -INFINITY;
    int32_t larg = -1;

    for (int i00 = tpitg; i00 < ncols; i00 += ntg) {
        if (x_row[i00] > lmax) {
            lmax = x_row[i00];
            larg = i00;
        }
    }

    // find the argmax value in the block
    float max_val = simd_max(lmax);
    int32_t arg_val = simd_max(select(-1, larg, lmax == max_val));

    if (ntg > N_SIMDWIDTH) {
        if (sgitg == 0) {
            shared_maxval[tiisg] = -INFINITY;
            shared_argmax[tiisg] = -1;
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        if (tiisg == 0) {
            shared_maxval[sgitg] = max_val;
            shared_argmax[sgitg] = arg_val;
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        max_val = shared_maxval[tiisg];
        arg_val = shared_argmax[tiisg];

        float max_val_reduced = simd_max(max_val);
        int32_t arg_val_reduced = simd_max(select(-1, arg_val, max_val == max_val_reduced));

        dst[tgpig] = arg_val_reduced;

        return;
    }

    dst[tgpig] = arg_val;
}

kernel void kernel_norm(
        constant ggml_metal_kargs_norm & args,
        device const char * src0,
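How the kernel reduces: each thread scans a strided slice of its row and keeps a local best value/index pair (lmax, larg). simd_max(lmax) then broadcasts the simdgroup-wide maximum, and simd_max(select(-1, larg, lmax == max_val)) keeps an index only on lanes that hold that maximum (all other lanes contribute -1), so the second simd_max yields a winning index. For threadgroups wider than one simdgroup, per-simdgroup winners are staged in shared_maxval/shared_argmax and reduced once more. A plain CPU reference of the per-row result, useful as a sketch for checking outputs (on exact ties it returns the lowest index, whereas the SIMD reduction may settle on a different tied index):

    #include <cstdint>
    #include <limits>
    #include <vector>

    // reference row-wise argmax: one int32 per row, mirroring what
    // kernel_argmax computes on the GPU (modulo tie-breaking order)
    static std::vector<int32_t> argmax_rows(const std::vector<float> & x, int64_t ncols, int64_t nrows) {
        std::vector<int32_t> dst(nrows, -1);
        for (int64_t r = 0; r < nrows; ++r) {
            float best = -std::numeric_limits<float>::infinity();
            for (int64_t c = 0; c < ncols; ++c) {
                const float v = x[r*ncols + c];
                if (v > best) { // strict '>' keeps the first (lowest) index on ties
                    best   = v;
                    dst[r] = (int32_t) c;
                }
            }
        }
        return dst;
    }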
5 changes: 4 additions & 1 deletion tests/test-backend-ops.cpp
@@ -3439,7 +3439,10 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
    test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {3,1,2,1}, 1, 0, 1));
    test_cases.emplace_back(new test_conv_transpose_1d({2,1,1,1}, {3,1,1,1}, 1, 0, 1));

    test_cases.emplace_back(new test_argmax());
    test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, { 100, 10, 1, 1}));
    test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {1024, 12, 1, 1}));
    test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {5438, 3, 1, 1}));

    test_cases.emplace_back(new test_count_equal());

    for (int ne3 : {1, 3}) { // CUDA backward pass only supports ne3 == 1
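For reference, the op these tests exercise is exposed through the public API as ggml_argmax(ctx, a), which takes an F32 matrix and yields one I32 index per row. A minimal CPU-side sketch of driving it (the scratch size, the fill step, and the ggml_graph_compute_with_ctx call reflect the usual ggml flow of roughly this vintage and are assumptions, not code from this commit):

    #include "ggml.h"
    #include "ggml-cpu.h" // assumed home of ggml_graph_compute_with_ctx in recent trees

    int main(void) {
        struct ggml_init_params params = {
            /*.mem_size   =*/ 16*1024*1024, // illustrative scratch size
            /*.mem_buffer =*/ NULL,
            /*.no_alloc   =*/ false,
        };
        struct ggml_context * ctx = ggml_init(params);

        // 100 columns x 10 rows -- the shape of one of the new test cases
        struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 100, 10);
        // ... fill a's F32 data here ...

        struct ggml_tensor * am = ggml_argmax(ctx, a); // I32, one entry per row

        struct ggml_cgraph * gf = ggml_new_graph(ctx);
        ggml_build_forward_expand(gf, am);
        ggml_graph_compute_with_ctx(ctx, gf, /*n_threads=*/ 4);

        // ((int32_t *) am->data)[r] is the argmax of row r
        ggml_free(ctx);
        return 0;
    }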
