🚀 It works!
EricLBuehler committed Dec 24, 2024
1 parent 599ff2d commit 3adca5c
Showing 2 changed files with 16 additions and 13 deletions.
mistralrs-core/src/paged_attention/layers/paged_attention.rs (11 changes: 2 additions, 9 deletions)
@@ -1,4 +1,4 @@
-use candle_core::{DType, Device, Result, Tensor};
+use candle_core::{Device, Result, Tensor};

use mistralrs_paged_attn::{paged_attention, reshape_and_cache};

@@ -132,11 +132,6 @@ impl PagedAttention {
return Ok(att);
}

#[cfg(feature = "cuda")]
let p = 'c';
#[cfg(feature = "metal")]
let p = 'm';
key_cache.as_ref().unwrap().to_dtype(DType::F32)?.write_npy(&format!("{p}-key-cache.npy"))?;
// Args:
// output: shape = [num_generation_tokens, num_heads, head_size]
//
@@ -164,8 +159,6 @@ impl PagedAttention {
softcapping.unwrap_or(1.0f64) as f32,
)?;

-res.to_dtype(DType::F32)?.write_npy(&format!("{p}-res.npy"))?;

-panic!();
Ok(res)
}
}
mistralrs-paged-attn/src/metal/kernels/pagedattention.metal (18 changes: 14 additions, 4 deletions)
@@ -261,7 +261,7 @@ typedef struct _MLX_BFloat16 bfloat16_t;

// TODO(EricLBuehler): optimize with vectorization
template<int THREAD_GROUP_SIZE, int VEC_SIZE, typename T, int N>
-inline float qk_dot(device T* q[N], device T* k[N]) {
+inline float qk_dot(threadgroup T q[N][VEC_SIZE], device T* k[N]) {
// Compute the parallel products then sum for Q*K^T (treat vector lanes separately).
float qk = 0;
#pragma unroll
@@ -394,7 +394,7 @@ template <typename T, int HEAD_SIZE, int BLOCK_SIZE, int NUM_THREADS, int NUM_SI
const int num_heads = threadgroups_per_grid.x;
const int num_queries_per_kv = num_heads / num_kv_heads;
const int kv_head_idx = head_idx / num_queries_per_kv;
-const float alibi_slope = use_alibi ? 0.f : alibi_slopes[head_idx];
+const float alibi_slope = !use_alibi ? 0.f : alibi_slopes[head_idx];

// A vector type to store a part of a key or a query.
// The vector size is configured in such a way that the threads in a thread group
@@ -415,11 +415,21 @@ template <typename T, int HEAD_SIZE, int BLOCK_SIZE, int NUM_THREADS, int NUM_SI
// has 0, 4, 8, ... th vectors of the query, and the second thread has 1, 5, 9, ...
// th vectors of the query, and so on.
const device T* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
-device T* q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD];
+threadgroup T q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD][VEC_SIZE];
+// for (int i = 0; i < THREAD_GROUP_SIZE; ++i) {
+// for (int j = 0; j < NUM_VECS_PER_THREAD; ++j) {
+// for (int k = 0; k < VEC_SIZE; ++k) {
+// q_vecs[i][j][k] = T(-1000.f);
+// }
+// }
+// }
#pragma unroll
for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD; i += NUM_THREAD_GROUPS) {
const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
-q_vecs[thread_group_offset][i] = const_cast<device T*>(q_ptr) + vec_idx * VEC_SIZE;
+const device T* q_vec_ptr = q_ptr + vec_idx * VEC_SIZE;
+for (int vi = 0; vi < VEC_SIZE; ++vi) {
+q_vecs[thread_group_offset][i][vi] = q_vec_ptr[vi];
+}
}
threadgroup_barrier(mem_flags::mem_threadgroup);

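Note on the kernel change above: the old code stored device pointers into the query buffer in q_vecs and dereferenced them later inside qk_dot, while the new code copies the query values themselves into a threadgroup (shared-memory) array before the barrier, so every thread in the group reads values already resident in threadgroup memory. The following is a minimal, standalone Metal sketch of that staging pattern, not the mistral.rs kernel; the kernel name, buffer layout, the VEC_SIZE/THREADS constants, and the simd_sum reduction are illustrative assumptions.

#include <metal_stdlib>
using namespace metal;

constant int VEC_SIZE = 4;
constant int THREADS = 32; // assumes a threadgroup of exactly 32 threads (one SIMD group)

kernel void staged_dot(device const float* q   [[buffer(0)]],
                       device const float* k   [[buffer(1)]],
                       device float*       out [[buffer(2)]],
                       uint tid [[thread_index_in_threadgroup]]) {
    // Stage this thread's slice of the query into threadgroup memory, so later
    // reads do not go through per-thread device pointers.
    threadgroup float q_vecs[THREADS][VEC_SIZE];
    for (int v = 0; v < VEC_SIZE; ++v) {
        q_vecs[tid][v] = q[tid * VEC_SIZE + v];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Each thread computes a partial dot product from the shared query copy.
    float acc = 0.f;
    for (int v = 0; v < VEC_SIZE; ++v) {
        acc += q_vecs[tid][v] * k[tid * VEC_SIZE + v];
    }

    // Reduce across the SIMD group; lane 0 writes the final Q.K value.
    acc = simd_sum(acc);
    if (tid == 0) {
        out[0] = acc;
    }
}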
