Debugging
EricLBuehler committed Dec 23, 2024
commit 599ff2d (1 parent: 90e86a1)
Showing 1 changed file with 13 additions and 3 deletions.
mistralrs-core/src/paged_attention/layers/paged_attention.rs (13 additions, 3 deletions)
@@ -1,4 +1,4 @@
-use candle_core::{Device, Result, Tensor};
+use candle_core::{DType, Device, Result, Tensor};
 
 use mistralrs_paged_attn::{paged_attention, reshape_and_cache};
 
@@ -131,6 +131,12 @@ impl PagedAttention
             // Return result in prefill
             return Ok(att);
         }
+
+        #[cfg(feature = "cuda")]
+        let p = 'c';
+        #[cfg(feature = "metal")]
+        let p = 'm';
+        key_cache.as_ref().unwrap().to_dtype(DType::F32)?.write_npy(&format!("{p}-key-cache.npy"))?;
         // Args:
         // output: shape = [num_generation_tokens, num_heads, head_size]
         //
@@ -146,7 +152,7 @@ impl PagedAttention
         //
         // alibi_slopes: shape = [num_heads]
         #[allow(clippy::cast_possible_truncation)]
-        paged_attention(
+        let res = paged_attention(
            &query,
            key_cache.as_ref().unwrap(),
            value_cache.as_ref().unwrap(),
@@ -156,6 +162,10 @@
            input_metadata.max_context_len.unwrap(),
            self.scale,
            softcapping.unwrap_or(1.0f64) as f32,
-        )
+        )?;
+
+        res.to_dtype(DType::F32)?.write_npy(&format!("{p}-res.npy"))?;
+
+        panic!();
     }
 }
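
The added code dumps the key cache and the paged_attention output to NumPy files on each backend (c-key-cache.npy / c-res.npy under the "cuda" feature, m-key-cache.npy / m-res.npy under "metal") and then panics, so the two runs can be compared offline. Below is a minimal sketch of such a comparison; it assumes candle_core's Tensor::read_npy and that the dumps from both backends were copied into one directory, and the compare_dumps helper itself is hypothetical, not part of this commit.

// compare_dumps.rs (hypothetical helper, not part of the commit).
// Assumes c-key-cache.npy / c-res.npy (CUDA) and m-key-cache.npy / m-res.npy
// (Metal), produced by the instrumented code above, sit in the current directory.
use candle_core::{Result, Tensor};

fn max_abs_diff(a: &Tensor, b: &Tensor) -> Result<f32> {
    // Flatten both tensors and take the largest element-wise deviation.
    let a = a.flatten_all()?.to_vec1::<f32>()?;
    let b = b.flatten_all()?.to_vec1::<f32>()?;
    Ok(a.iter()
        .zip(b.iter())
        .map(|(x, y)| (x - y).abs())
        .fold(0f32, f32::max))
}

fn main() -> Result<()> {
    for name in ["key-cache", "res"] {
        // The dumps were written as F32 by the debugging code, so to_vec1::<f32> applies.
        let cuda = Tensor::read_npy(format!("c-{name}.npy"))?;
        let metal = Tensor::read_npy(format!("m-{name}.npy"))?;
        println!(
            "{name}: shapes {:?} vs {:?}, max |diff| = {}",
            cuda.shape(),
            metal.shape(),
            max_abs_diff(&cuda, &metal)?
        );
    }
    Ok(())
}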
