Cleaner pipeline no prefix cache setting (#1022)
* Fixes for prefix cache + llama vision

* Fix for vllama
EricLBuehler authored Jan 2, 2025
1 parent f1c3a36 commit 0875194
Showing 11 changed files with 32 additions and 37 deletions.
18 changes: 7 additions & 11 deletions mistralrs-core/src/engine/mod.rs
@@ -72,30 +72,26 @@ impl Engine {
         pipeline: Arc<Mutex<dyn Pipeline>>,
         config: SchedulerConfig,
         truncate_sequence: bool,
-        no_kv_cache: bool,
-        no_prefix_cache: bool,
+        mut no_kv_cache: bool,
+        mut no_prefix_cache: bool,
         prefix_cache_n: usize,
         disable_eos_stop: bool,
         throughput_logging_enabled: bool,
     ) -> Self {
         let device = get_mut_arcmutex!(pipeline).device().clone();
-        let has_no_kv_cache = get_mut_arcmutex!(pipeline).get_metadata().has_no_kv_cache;
-        if no_kv_cache {
-            // Diffusion models...
-            assert_eq!(has_no_kv_cache, no_kv_cache);
-        }
-        // Prefix caching is always disabled if using PagedAttention for now.
-        // TODO
+        no_kv_cache |= get_mut_arcmutex!(pipeline).get_metadata().no_kv_cache;
+        no_prefix_cache |= get_mut_arcmutex!(pipeline).get_metadata().no_prefix_cache;
+        // TODO: Prefix caching is always disabled if using PagedAttention for now.
         let no_prefix_cache = matches!(config, SchedulerConfig::PagedAttentionMeta { .. })
             || no_prefix_cache
-            || has_no_kv_cache;
+            || no_kv_cache;
         Self {
             rx,
             pipeline,
             scheduler: config.into_scheduler(),
             id: 0,
             truncate_sequence,
-            no_kv_cache: no_kv_cache & !has_no_kv_cache,
+            no_kv_cache,
             prefix_cacher: PrefixCacheManagerV2::new(device, prefix_cache_n, no_prefix_cache),
             is_debug: DEBUG.load(Ordering::Relaxed),
             disable_eos_stop,
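The net effect of the hunk above can be sketched as a free-standing helper (hypothetical name and signature, not part of the crate): the caller's flags can be forced on by pipeline metadata or by the scheduler choice, but never forced back off.

    // Hypothetical sketch of the flag resolution Engine::new now performs.
    fn resolve_cache_flags(
        mut no_kv_cache: bool,       // requested by the caller
        mut no_prefix_cache: bool,   // requested by the caller
        meta_no_kv_cache: bool,      // GeneralMetadata::no_kv_cache
        meta_no_prefix_cache: bool,  // GeneralMetadata::no_prefix_cache
        paged_attention: bool,       // SchedulerConfig::PagedAttentionMeta in use
    ) -> (bool, bool) {
        // Pipeline metadata can only disable a cache, never re-enable it.
        no_kv_cache |= meta_no_kv_cache;
        no_prefix_cache |= meta_no_prefix_cache;
        // Prefix caching is also off under PagedAttention (for now) and
        // whenever there is no KV cache at all (e.g. diffusion models).
        let no_prefix_cache = paged_attention || no_prefix_cache || no_kv_cache;
        (no_kv_cache, no_prefix_cache)
    }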
2 changes: 1 addition & 1 deletion mistralrs-core/src/pipeline/amoe.rs
@@ -467,7 +467,7 @@ impl AnyMoePipelineMixin for AnyMoePipeline {
                 true, // Always a prompt
                 metadata.is_xlora,
                 &device,
-                metadata.has_no_kv_cache,
+                metadata.no_kv_cache,
                 None,
                 false,
                 input_processor_cfg.clone(),
4 changes: 2 additions & 2 deletions mistralrs-core/src/pipeline/cache_manager.rs
@@ -676,7 +676,7 @@ impl<T: CacheManagerMixin + MetadataMixin + ?Sized> CacheManager<T> for FullCach
             seqs,
             SeqCache::Normal,
         );
-        if pipeline.get_metadata().is_xlora && !pipeline.get_metadata().has_no_kv_cache {
+        if pipeline.get_metadata().is_xlora && !pipeline.get_metadata().no_kv_cache {
             clone_in_cache(
                 pipeline.get_metadata().num_hidden_layers,
                 &mut pipeline.cache().full().xlora_lock(),
@@ -714,7 +714,7 @@ impl<T: CacheManagerMixin + MetadataMixin + ?Sized> CacheManager<T> for FullCach
             seqs,
             SeqCache::Normal,
         );
-        if pipeline.get_metadata().is_xlora && !pipeline.get_metadata().has_no_kv_cache {
+        if pipeline.get_metadata().is_xlora && !pipeline.get_metadata().no_kv_cache {
             clone_out_cache(
                 pipeline.get_metadata().num_hidden_layers,
                 &mut pipeline.cache().full().xlora_lock(),
3 changes: 2 additions & 1 deletion mistralrs-core/src/pipeline/diffusion.rs
@@ -219,10 +219,11 @@ impl Loader for DiffusionLoader {
                 max_seq_len,
                 tok_env: None,
                 is_xlora: false,
+                no_prefix_cache: false,
                 num_hidden_layers: 1, // FIXME(EricLBuehler): we know this is only for caching, so its OK.
                 eos_tok: vec![],
                 kind: self.kind.clone(),
-                has_no_kv_cache: true, // NOTE(EricLBuehler): no cache for these.
+                no_kv_cache: true, // NOTE(EricLBuehler): no cache for these.
                 activation_dtype: dtype,
                 sliding_window: None,
                 cache_config: None,
3 changes: 2 additions & 1 deletion mistralrs-core/src/pipeline/ggml.rs
@@ -388,7 +388,8 @@ impl Loader for GGMLLoader {
             metadata: Arc::new(GeneralMetadata {
                 max_seq_len,
                 tok_env: Some(tok_env),
-                has_no_kv_cache: self.no_kv_cache,
+                no_kv_cache: self.no_kv_cache,
+                no_prefix_cache: false,
                 num_hidden_layers,
                 eos_tok: eos,
                 kind: self.kind.clone(),
3 changes: 2 additions & 1 deletion mistralrs-core/src/pipeline/gguf.rs
@@ -558,7 +558,8 @@ impl Loader for GGUFLoader {
             metadata: Arc::new(GeneralMetadata {
                 max_seq_len,
                 tok_env: Some(tok_env),
-                has_no_kv_cache: self.no_kv_cache,
+                no_kv_cache: self.no_kv_cache,
+                no_prefix_cache: false,
                 num_hidden_layers,
                 eos_tok: eos,
                 kind: self.kind.clone(),
7 changes: 4 additions & 3 deletions mistralrs-core/src/pipeline/mod.rs
@@ -72,7 +72,8 @@ pub struct GeneralMetadata {
     pub max_seq_len: usize,
     /// Only None if it doesnt make sense for the model
     pub tok_env: Option<llguidance::toktrie::TokEnv>,
-    pub has_no_kv_cache: bool,
+    pub no_kv_cache: bool,
+    pub no_prefix_cache: bool,
     pub num_hidden_layers: usize,
     pub eos_tok: Vec<u32>,
     pub kind: ModelKind,
@@ -322,7 +323,7 @@ pub trait Pipeline:
             is_prompt,
             self.get_metadata().is_xlora,
             &self.device(),
-            self.get_metadata().has_no_kv_cache,
+            self.get_metadata().no_kv_cache,
             None,
             return_raw_logits,
             self.get_input_processor_config(),
@@ -535,7 +536,7 @@ pub trait Pipeline:
             is_prompt,
             self.get_metadata().is_xlora,
             &self.device(),
-            self.get_metadata().has_no_kv_cache,
+            self.get_metadata().no_kv_cache,
             None,
             return_raw_logits,
             self.get_input_processor_config(),
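For orientation, the two cache-related metadata flags renamed or added above can be summarized in a trimmed-down sketch (hypothetical struct name; the real GeneralMetadata carries many more fields), reflecting the defaults the loaders in this commit choose:

    // Hypothetical, trimmed sketch of the two cache flags in GeneralMetadata.
    pub struct CacheFlagsSketch {
        /// True when the pipeline has no KV cache at all
        /// (set by the diffusion loader above, for example).
        pub no_kv_cache: bool,
        /// True when the prefix cacher must not be used for this pipeline
        /// (X-LoRA in the normal loader and vision pipelines, per the diffs below).
        pub no_prefix_cache: bool,
    }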
3 changes: 2 additions & 1 deletion mistralrs-core/src/pipeline/normal.rs
@@ -570,7 +570,8 @@ impl Loader for NormalLoader {
             metadata: Arc::new(GeneralMetadata {
                 max_seq_len,
                 tok_env: Some(tok_env),
-                has_no_kv_cache: self.no_kv_cache,
+                no_kv_cache: self.no_kv_cache,
+                no_prefix_cache: is_xlora,
                 num_hidden_layers,
                 eos_tok: eos,
                 kind: self.kind.clone(),
11 changes: 4 additions & 7 deletions mistralrs-core/src/pipeline/speculative.rs
@@ -418,8 +418,7 @@ impl Pipeline for SpeculativePipeline {
         for i in 0..self.gamma {
             let is_xlora = get_mut_arcmutex!(self.draft).get_metadata().is_xlora;
             let device = get_mut_arcmutex!(self.draft).device();
-            let has_no_kv_cache =
-                get_mut_arcmutex!(self.draft).get_metadata().has_no_kv_cache;
+            let no_kv_cache = get_mut_arcmutex!(self.draft).get_metadata().no_kv_cache;
             let inputs = self
                 .get_processor()
                 .inputs_processor()
@@ -429,7 +428,7 @@
                     is_prompt && i == 0, // Only prompt (no kv cache) if first
                     is_xlora,
                     &device,
-                    has_no_kv_cache,
+                    no_kv_cache,
                     None,
                     false,
                     None,
@@ -492,9 +491,7 @@
         // ========= Run the model ============
         let is_xlora = get_mut_arcmutex!(self.target).get_metadata().is_xlora;
         let device = get_mut_arcmutex!(self.target).device();
-        let has_no_kv_cache = get_mut_arcmutex!(self.target)
-            .get_metadata()
-            .has_no_kv_cache;
+        let no_kv_cache = get_mut_arcmutex!(self.target).get_metadata().no_kv_cache;
         let inputs = self
             .get_processor()
             .inputs_processor()
@@ -504,7 +501,7 @@
                 true, // use the "prefill" tokens
                 is_xlora,
                 &device,
-                has_no_kv_cache,
+                no_kv_cache,
                 Some((self.gamma, initial_cache_len)), // Get the last gamma, see above
                 false,
                 None,
3 changes: 2 additions & 1 deletion mistralrs-core/src/pipeline/vision.rs
@@ -474,7 +474,8 @@ impl Loader for VisionLoader {
                 num_hidden_layers,
                 eos_tok: eos,
                 kind: self.kind.clone(),
-                has_no_kv_cache: false,
+                no_kv_cache: false,
+                no_prefix_cache: true, // TODO: evaluate. Do vision models need to not have prefix caching?
                 activation_dtype: dtype,
                 sliding_window,
                 cache_config,
12 changes: 4 additions & 8 deletions mistralrs-core/src/vision_models/mllama/text.rs
@@ -344,7 +344,6 @@ impl MLlamaTextCrossAttention {
         hidden_states: &Tensor,
         cross_attn_states: Option<&Tensor>,
         attention_mask: Option<&Tensor>,
-        kv_cache: &mut KvCache,
     ) -> Result<Tensor> {
         let (bs, q_len, _) = hidden_states.dims3()?;

@@ -385,10 +384,9 @@
                 .reshape((bs, (), self.num_kv_heads, self.head_dim))?
                 .transpose(1, 2)?;

-            (k, v) = kv_cache.append(&k, &v)?;
             (k, v)
         } else {
-            candle_core::bail!("Cross attn cannot find cross attn hidden states!")
+            candle_core::bail!("Cross attn cannot find k,v cache or cross attn hidden states!")
         };

         let mut attn_output = Sdpa
@@ -473,15 +471,14 @@
         cross_attn_states: Option<&Tensor>,
         attention_mask: Option<&Tensor>,
         full_text_row_masked_out_mask: Option<&Tensor>,
-        kv_cache: &mut KvCache,
     ) -> Result<Tensor> {
         let residual = hidden_states;

         let mut hidden_states = self.input_layernorm.forward(hidden_states)?;

-        hidden_states =
-            self.attn
-                .forward(&hidden_states, cross_attn_states, attention_mask, kv_cache)?;
+        hidden_states = self
+            .attn
+            .forward(&hidden_states, cross_attn_states, attention_mask)?;
         hidden_states = (residual + hidden_states.broadcast_mul(&self.attn_gate.tanh()?)?)?;

         let residual = &hidden_states;
@@ -674,7 +671,6 @@
                     cross_attn_states,
                     cross_attention_mask,
                     full_text_row_masked_out_mask,
-                    &mut cache[i],
                 )?;
             }
         }
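Read together, these hunks drop the per-layer KvCache from mllama's cross-attention path: K and V are now always projected from the supplied cross_attn_states, and the forward call fails when those states are absent. A free-standing sketch of the resulting control flow, with hypothetical names and a placeholder error string rather than the crate's exact message:

    // Hypothetical sketch: cross-attention K/V come only from the supplied states.
    fn cross_attn_kv<T>(
        cross_attn_states: Option<&T>,
        project_k: impl Fn(&T) -> T,
        project_v: impl Fn(&T) -> T,
    ) -> Result<(T, T), String> {
        match cross_attn_states {
            // No cache append step any more; the projections are recomputed per call.
            Some(states) => Ok((project_k(states), project_v(states))),
            None => Err("cross attention hidden states must be provided".to_string()),
        }
    }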
