diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m index 32b2e9c318475..3746bf5a44b2d 100644 --- a/ggml/src/ggml-metal.m +++ b/ggml/src/ggml-metal.m @@ -802,15 +802,6 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx if (op->src[0]->ne[0] == 256) { return false; } - { - float logit_softcap; - - memcpy(&logit_softcap, ((const float *) op->op_params) + 2, sizeof(logit_softcap)); - - if (logit_softcap != 0.0f) { - return false; - } - } return ctx->support_simdgroup_mm; // TODO: over-restricted for vec-kernels case GGML_OP_MUL_MAT: case GGML_OP_MUL_MAT_ID: @@ -2633,9 +2624,14 @@ static enum ggml_status ggml_metal_graph_compute( float scale; float max_bias; + float logit_softcap; + memcpy(&scale, ((int32_t *) dst->op_params) + 0, sizeof(scale)); + memcpy(&max_bias, ((int32_t *) dst->op_params) + 1, sizeof(max_bias)); + memcpy(&logit_softcap, ((int32_t *) dst->op_params) + 2, sizeof(logit_softcap)); - memcpy(&scale, ((int32_t *) dst->op_params) + 0, sizeof(scale)); - memcpy(&max_bias, ((int32_t *) dst->op_params) + 1, sizeof(max_bias)); + if (logit_softcap != 0.0f) { + scale /= logit_softcap; + } const uint32_t n_head = src0->ne[2]; const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head)); @@ -2686,30 +2682,31 @@ static enum ggml_status ggml_metal_graph_compute( } else { [encoder setBuffer:id_src0 offset:offs_src0 atIndex:3]; } - [encoder setBuffer:id_dst offset:offs_dst atIndex:4]; - [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:5]; - [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:6]; - [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:7]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:8]; - [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:9]; - [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:10]; - [encoder setBytes:&ne11 length:sizeof( int64_t) atIndex:11]; - [encoder setBytes:&ne12 length:sizeof( int64_t) atIndex:12]; - [encoder setBytes:&ne13 length:sizeof( int64_t) atIndex:13]; - [encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:14]; - [encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:15]; - [encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:16]; - [encoder setBytes:&nb21 length:sizeof(uint64_t) atIndex:17]; - [encoder setBytes:&nb22 length:sizeof(uint64_t) atIndex:18]; - [encoder setBytes:&nb23 length:sizeof(uint64_t) atIndex:19]; - [encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:20]; - [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:21]; - [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:22]; - [encoder setBytes:&scale length:sizeof( float) atIndex:23]; - [encoder setBytes:&max_bias length:sizeof( float) atIndex:24]; - [encoder setBytes:&m0 length:sizeof(m0) atIndex:25]; - [encoder setBytes:&m1 length:sizeof(m1) atIndex:26]; - [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:27]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:4]; + [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:5]; + [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:6]; + [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:7]; + [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:8]; + [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:9]; + [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:10]; + [encoder setBytes:&ne11 length:sizeof( int64_t) atIndex:11]; + [encoder setBytes:&ne12 length:sizeof( int64_t) atIndex:12]; + [encoder setBytes:&ne13 length:sizeof( int64_t) atIndex:13]; + [encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:14]; + [encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:15]; + [encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:16]; + [encoder setBytes:&nb21 length:sizeof(uint64_t) atIndex:17]; + [encoder setBytes:&nb22 length:sizeof(uint64_t) atIndex:18]; + [encoder setBytes:&nb23 length:sizeof(uint64_t) atIndex:19]; + [encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:20]; + [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:21]; + [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:22]; + [encoder setBytes:&scale length:sizeof( float) atIndex:23]; + [encoder setBytes:&max_bias length:sizeof( float) atIndex:24]; + [encoder setBytes:&m0 length:sizeof(m0) atIndex:25]; + [encoder setBytes:&m1 length:sizeof(m1) atIndex:26]; + [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:27]; + [encoder setBytes:&logit_softcap length:sizeof(logit_softcap) atIndex:28]; if (!use_vec_kernel) { // half8x8 kernel diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index 3bb37d32aced0..aba0b9a0300b8 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -1976,6 +1976,7 @@ typedef void (flash_attn_ext_f16_t)( constant float & m0, constant float & m1, constant uint32_t & n_head_log2, + constant float & logit_softcap, threadgroup half * shared, uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]], @@ -2014,6 +2015,7 @@ kernel void kernel_flash_attn_ext_f16( constant float & m0, constant float & m1, constant uint32_t & n_head_log2, + constant float & logit_softcap, threadgroup half * shared [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]], @@ -2142,14 +2144,19 @@ kernel void kernel_flash_attn_ext_f16( const short tx = tiisg%4; const short ty = tiisg/4; + // mqk = mqk*scale + ss[8*cc + ty*TF + 2*tx + 0] *= scale; + ss[8*cc + ty*TF + 2*tx + 1] *= scale; + + if (logit_softcap != 0.0f) { + ss[8*cc + ty*TF + 2*tx + 0] = logit_softcap*precise::tanh(ss[8*cc + ty*TF + 2*tx + 0]); + ss[8*cc + ty*TF + 2*tx + 1] = logit_softcap*precise::tanh(ss[8*cc + ty*TF + 2*tx + 1]); + } + if (mask != q) { - // mqk = mqk*scale + mask*slope - ss[8*cc + ty*TF + 2*tx + 0] = scale*ss[8*cc + ty*TF + 2*tx + 0] + slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 0]; - ss[8*cc + ty*TF + 2*tx + 1] = scale*ss[8*cc + ty*TF + 2*tx + 1] + slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 1]; - } else { - // mqk = mqk*scale - ss[8*cc + ty*TF + 2*tx + 0] *= scale; - ss[8*cc + ty*TF + 2*tx + 1] *= scale; + // mqk = mqk + mask*slope + ss[8*cc + ty*TF + 2*tx + 0] += slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 0]; + ss[8*cc + ty*TF + 2*tx + 1] += slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 1]; } } } @@ -2345,6 +2352,7 @@ kernel void kernel_flash_attn_ext_vec_f16( constant float & m0, constant float & m1, constant uint32_t & n_head_log2, + constant float & logit_softcap, threadgroup half * shared [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]], @@ -2479,7 +2487,13 @@ kernel void kernel_flash_attn_ext_vec_f16( // mqk = mqk*scale + mask*slope if (tiisg == 0) { - mqk = mqk*scale + ((mask != q) ? ((float4) mp4[ic/4 + cc])*slope : (float4) 0.0f); + mqk *= scale; + + if (logit_softcap != 0.0f) { + mqk = logit_softcap*precise::tanh(mqk); + } + + mqk += (mask != q) ? ((float4) mp4[ic/4 + cc])*slope : (float4) 0.0f; ss4[cc] = mqk; } diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 351b1d567c7e7..eab5f027f498b 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -2487,7 +2487,6 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op } GGML_ABORT("fatal error"); - return false; } static void usage(char ** argv) {