Skip to content

Commit

Permalink
Fix
Browse files Browse the repository at this point in the history
  • Loading branch information
EricLBuehler committed Dec 23, 2024
1 parent afb8ce4 commit 90e86a1
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 7 deletions.
8 changes: 3 additions & 5 deletions mistralrs-paged-attn/src/metal/backend/paged_attention.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use candle_core::{
backend::BackendStorage, CpuStorage, DType, Device, Layout, MetalStorage, Result, Shape,
Storage, Tensor,
backend::BackendStorage, CpuStorage, DType, Layout, MetalStorage, Result, Shape, Storage,
Tensor,
};

use crate::metal::kernels::{self, PagedAttentionDType};
Expand Down Expand Up @@ -412,9 +412,7 @@ pub fn reshape_and_cache(
let key_stride = k_l.stride()[0] as i32;
let value_stride = v_l.stride()[0] as i32;

let Device::Metal(dev) = key.device() else {
panic!("Expected the key to be on a Metal device.")
};
let dev = key.device().as_metal_device()?;

let command_buffer = dev.command_buffer()?;
command_buffer.set_label("reshape-and-cache");
Expand Down
3 changes: 1 addition & 2 deletions mistralrs-paged-attn/src/metal/kernels/pagedattention.metal
Original file line number Diff line number Diff line change
Expand Up @@ -265,8 +265,7 @@ inline float qk_dot(device T* q[N], device T* k[N]) {
// Compute the parallel products then sum for Q*K^T (treat vector lanes separately).
float qk = 0;
#pragma unroll
for (int ii = 1; ii < N; ++ii) {
#pragma unroll
for (int ii = 0; ii < N; ++ii) {
for (int vi = 0; vi < VEC_SIZE; ++vi) {
qk = fma(float(q[ii][vi]), float(k[ii][vi]), qk);
}
Expand Down

0 comments on commit 90e86a1

Please sign in to comment.