From 3adca5c0d8e6769c9fe59da6347ddd7b05817564 Mon Sep 17 00:00:00 2001
From: Eric Buehler
Date: Tue, 24 Dec 2024 13:07:25 +0100
Subject: [PATCH] =?UTF-8?q?=F0=9F=9A=80=20It=20works!?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../paged_attention/layers/paged_attention.rs | 11 ++---------
 .../src/metal/kernels/pagedattention.metal    | 18 ++++++++++++++----
 2 files changed, 16 insertions(+), 13 deletions(-)

diff --git a/mistralrs-core/src/paged_attention/layers/paged_attention.rs b/mistralrs-core/src/paged_attention/layers/paged_attention.rs
index c2a82e582..3c26d805d 100644
--- a/mistralrs-core/src/paged_attention/layers/paged_attention.rs
+++ b/mistralrs-core/src/paged_attention/layers/paged_attention.rs
@@ -1,4 +1,4 @@
-use candle_core::{DType, Device, Result, Tensor};
+use candle_core::{Device, Result, Tensor};
 
 use mistralrs_paged_attn::{paged_attention, reshape_and_cache};
 
@@ -132,11 +132,6 @@ impl PagedAttention {
             return Ok(att);
         }
 
-        #[cfg(feature = "cuda")]
-        let p = 'c';
-        #[cfg(feature = "metal")]
-        let p = 'm';
-        key_cache.as_ref().unwrap().to_dtype(DType::F32)?.write_npy(&format!("{p}-key-cache.npy"))?;
         // Args:
         //  output: shape = [num_generation_tokens, num_heads, head_size]
         //
@@ -164,8 +159,6 @@ impl PagedAttention {
             softcapping.unwrap_or(1.0f64) as f32,
         )?;
 
-        res.to_dtype(DType::F32)?.write_npy(&format!("{p}-res.npy"))?;
-
-        panic!();
+        Ok(res)
     }
 }
diff --git a/mistralrs-paged-attn/src/metal/kernels/pagedattention.metal b/mistralrs-paged-attn/src/metal/kernels/pagedattention.metal
index 43027e2d9..74c878e5e 100644
--- a/mistralrs-paged-attn/src/metal/kernels/pagedattention.metal
+++ b/mistralrs-paged-attn/src/metal/kernels/pagedattention.metal
@@ -261,7 +261,7 @@ typedef struct _MLX_BFloat16 bfloat16_t;
 
 // TODO(EricLBuehler): optimize with vectorization
 template <typename T, int N>
-inline float qk_dot(device T* q[N], device T* k[N]) {
+inline float qk_dot(threadgroup T q[N][VEC_SIZE], device T* k[N]) {
   // Compute the parallel products then sum for Q*K^T (treat vector lanes separately).
   float qk = 0;
 #pragma unroll
@@ -394,7 +394,7 @@ template
-    q_vecs[thread_group_offset][i] = reinterpret_cast<const device T*>(q_ptr) + vec_idx * VEC_SIZE;
+    const device T* q_vec_ptr = q_ptr + vec_idx * VEC_SIZE;
+    for (int vi = 0; vi < VEC_SIZE; ++vi) {
+      q_vecs[thread_group_offset][i][vi] = q_vec_ptr[vi];
+    }
   }
 
   threadgroup_barrier(mem_flags::mem_threadgroup);