Skip to content

Commit

Permalink
Prefix cacher v2 (#1000)
Browse files: browse the repository at this point in the history
* Work on prefix cacher

* It works

* Address Clippy lints

* Enable partial matches
  • Loading branch information
EricLBuehler authored Dec 20, 2024
1 parent 3a26a46 commit fc65371
Show file tree
Hide file tree
Showing 16 changed files with 271 additions and 42 deletions.
16 changes: 5 additions & 11 deletions mistralrs-core/src/engine/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ use crate::{
text_models_inputs_processor::PagedAttentionMeta,
AdapterInstruction, CacheBackendMetadata, CacheInstruction, EitherCache, NormalCache,
},
prefix_cacher_v2::PrefixCacheManagerV2,
request::{DetokenizationRequest, NormalRequest, TokenizationRequest},
response::CompletionChoice,
scheduler::{Scheduler, SchedulerOutput},
Expand All @@ -32,7 +33,6 @@ use tracing::{info, warn};
use crate::{
get_mut_arcmutex, handle_pipeline_forward_error, handle_seq_error,
pipeline::Pipeline,
prefix_cacher::PrefixCacheManager,
request::Request,
response::{ChatCompletionResponse, Choice, ResponseMessage},
sampler::Sampler,
Expand All @@ -59,7 +59,7 @@ pub struct Engine {
id: usize,
truncate_sequence: bool,
no_kv_cache: bool,
prefix_cacher: PrefixCacheManager,
prefix_cacher: PrefixCacheManagerV2,
is_debug: bool,
disable_eos_stop: bool,
throughput_logging_enabled: bool,
Expand All @@ -79,7 +79,6 @@ impl Engine {
throughput_logging_enabled: bool,
) -> Self {
let device = get_mut_arcmutex!(pipeline).device().clone();
let is_xlora = get_mut_arcmutex!(pipeline).get_metadata().is_xlora;
let has_no_kv_cache = get_mut_arcmutex!(pipeline).get_metadata().has_no_kv_cache;
if no_kv_cache {
// Diffusion models...
Expand All @@ -97,12 +96,7 @@ impl Engine {
id: 0,
truncate_sequence,
no_kv_cache: no_kv_cache & !has_no_kv_cache,
prefix_cacher: PrefixCacheManager::new(
device,
prefix_cache_n,
is_xlora,
no_prefix_cache,
),
prefix_cacher: PrefixCacheManagerV2::new(device, prefix_cache_n, no_prefix_cache),
is_debug: DEBUG.load(Ordering::Relaxed),
disable_eos_stop,
throughput_logging_enabled,
Expand Down Expand Up @@ -905,10 +899,10 @@ impl Engine {
request.return_raw_logits,
);
let seq = if let Some(prefill_cache) = prefill_cache.clone() {
seq.prefill(
seq.prefill_v2(
prefill_cache.normal,
prefill_cache.xlora,
prefill_cache.toks,
prefill_cache.offset,
)
} else {
seq
Expand Down
1 change: 1 addition & 0 deletions mistralrs-core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ mod attention;
mod diffusion_models;
mod pipeline;
mod prefix_cacher;
mod prefix_cacher_v2;
mod request;
mod response;
mod sampler;
Expand Down
4 changes: 2 additions & 2 deletions mistralrs-core/src/pipeline/amoe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use tracing::{info, warn};
use crate::{
amoe::{AnyMoeConfig, AnyMoeTrainingInputRow, AnyMoeTrainingInputs, AnyMoeTrainingResult},
get_mut_arcmutex,
prefix_cacher::PrefixCacheManager,
prefix_cacher_v2::PrefixCacheManagerV2,
sampler::Sampler,
sequence::{SeqStepType, Sequence, SequenceGroup, SequenceRecognizer},
utils::progress::NiceProgressBar,
Expand Down Expand Up @@ -260,7 +260,7 @@ impl Pipeline for AnyMoePipeline {
&self,
seqs: &mut [&mut Sequence],
logits: Vec<Tensor>,
prefix_cacher: &mut PrefixCacheManager,
prefix_cacher: &mut PrefixCacheManagerV2,
disable_eos_stop: bool,
rng: Arc<std::sync::Mutex<Isaac64Rng>>,
) -> Result<(), candle_core::Error> {
Expand Down
4 changes: 2 additions & 2 deletions mistralrs-core/src/pipeline/diffusion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use super::{
use crate::diffusion_models::processor::{DiffusionProcessor, ModelInputs};
use crate::paged_attention::AttentionImplementation;
use crate::pipeline::ChatTemplate;
use crate::prefix_cacher::PrefixCacheManager;
use crate::prefix_cacher_v2::PrefixCacheManagerV2;
use crate::sequence::Sequence;
use crate::utils::debug::DeviceRepr;
use crate::utils::{tokens::get_token, varbuilder_utils::from_mmaped_safetensors};
Expand Down Expand Up @@ -329,7 +329,7 @@ impl Pipeline for DiffusionPipeline {
&self,
_seqs: &mut [&mut Sequence],
_logits: Vec<Tensor>,
_prefix_cacher: &mut PrefixCacheManager,
_prefix_cacher: &mut PrefixCacheManagerV2,
_disable_eos_stop: bool,
_srng: Arc<std::sync::Mutex<Isaac64Rng>>,
) -> Result<(), candle_core::Error> {
Expand Down
4 changes: 2 additions & 2 deletions mistralrs-core/src/pipeline/ggml.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use crate::pipeline::chat_template::{calculate_eos_tokens, GenerationConfig};
use crate::pipeline::get_chat_template;
use crate::pipeline::sampling::sample_and_add_toks;
use crate::pipeline::{ChatTemplate, LocalModelPaths};
use crate::prefix_cacher::PrefixCacheManager;
use crate::prefix_cacher_v2::PrefixCacheManagerV2;
use crate::sequence::Sequence;
use crate::utils::debug::DeviceRepr;
use crate::utils::model_config as ModelConfig;
Expand Down Expand Up @@ -582,7 +582,7 @@ impl Pipeline for GGMLPipeline {
&self,
seqs: &mut [&mut Sequence],
logits: Vec<Tensor>,
prefix_cacher: &mut PrefixCacheManager,
prefix_cacher: &mut PrefixCacheManagerV2,
disable_eos_stop: bool,
rng: Arc<std::sync::Mutex<Isaac64Rng>>,
) -> Result<(), candle_core::Error> {
Expand Down
4 changes: 2 additions & 2 deletions mistralrs-core/src/pipeline/gguf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use crate::pipeline::chat_template::{calculate_eos_tokens, BeginEndUnkTok, Gener
use crate::pipeline::get_chat_template;
use crate::pipeline::sampling::sample_and_add_toks;
use crate::pipeline::ChatTemplate;
use crate::prefix_cacher::PrefixCacheManager;
use crate::prefix_cacher_v2::PrefixCacheManagerV2;
use crate::sequence::Sequence;
use crate::utils::debug::DeviceRepr;
use crate::utils::model_config as ModelConfig;
Expand Down Expand Up @@ -773,7 +773,7 @@ impl Pipeline for GGUFPipeline {
&self,
seqs: &mut [&mut Sequence],
logits: Vec<Tensor>,
prefix_cacher: &mut PrefixCacheManager,
prefix_cacher: &mut PrefixCacheManagerV2,
disable_eos_stop: bool,
rng: Arc<std::sync::Mutex<Isaac64Rng>>,
) -> Result<(), candle_core::Error> {
Expand Down
8 changes: 7 additions & 1 deletion mistralrs-core/src/pipeline/inputs_processor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -533,9 +533,15 @@ pub mod text_models_inputs_processor {
"PagedAttention does not yet support prompt batching.",
))));
}
let offset = input_seqs[0].token_offset();
if offset != 0 && paged_attn_metadata.is_some() {
return Box::new(std::iter::once(Err(anyhow::Error::msg(
"PagedAttention does not yet support sequences with an offset != 0.",
))));
}
Box::new(std::iter::once(
make_prompt_chunk(
0,
offset,
toks,
&input_seqs.iter().map(|s| *s.id()).collect::<Vec<_>>(),
device,
Expand Down
8 changes: 4 additions & 4 deletions mistralrs-core/src/pipeline/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ mod vision;
pub use super::diffusion_models::DiffusionGenerationParams;
use crate::amoe::{AnyMoeConfig, AnyMoeExpertType, AnyMoeTrainingInputs, AnyMoeTrainingResult};
use crate::paged_attention::{CacheConfig, CacheEngine, ModelConfigLike};
use crate::prefix_cacher::PrefixCacheManager;
use crate::prefix_cacher_v2::PrefixCacheManagerV2;
pub use amoe::{AnyMoeLoader, AnyMoePipeline};
use chat_template::ChatTemplate;
pub use diffusion::{DiffusionLoader, DiffusionLoaderBuilder, DiffusionSpecificConfig};
Expand Down Expand Up @@ -60,7 +60,7 @@ use candle_core::{DType, Device, IndexOp, Tensor, Var};
use crate::sequence::Sequence;

pub use self::cache_manager::{
Cache, CacheManager, EitherCache, KvCache, LayerCaches, NormalCache,
Cache, CacheManager, EitherCache, KvCache, LayerCaches, NormalCache, SingleCache,
};
pub use self::inputs_processor::{
text_models_inputs_processor, InputsProcessor, InputsProcessorType,
Expand Down Expand Up @@ -307,7 +307,7 @@ pub trait Pipeline:
input_seqs: &mut [&mut Sequence],
is_prompt: bool,
return_raw_logits: bool,
prefix_cacher: &mut PrefixCacheManager,
prefix_cacher: &mut PrefixCacheManagerV2,
disable_eos_stop: bool,
rng: Arc<std::sync::Mutex<Isaac64Rng>>,
backend_metadata: CacheBackendMetadata<'_>,
Expand Down Expand Up @@ -657,7 +657,7 @@ pub trait Pipeline:
&self,
seqs: &mut [&mut Sequence],
logits: Vec<Tensor>,
prefix_cacher: &mut PrefixCacheManager,
prefix_cacher: &mut PrefixCacheManagerV2,
disable_eos_stop: bool,
rng: Arc<std::sync::Mutex<Isaac64Rng>>,
) -> Result<(), candle_core::Error>;
Expand Down
4 changes: 2 additions & 2 deletions mistralrs-core/src/pipeline/normal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ use crate::pipeline::isq::UqffFullSer;
use crate::pipeline::sampling::sample_and_add_toks;
use crate::pipeline::text_models_inputs_processor::make_prompt_chunk;
use crate::pipeline::{ChatTemplate, LocalModelPaths};
use crate::prefix_cacher::PrefixCacheManager;
use crate::prefix_cacher_v2::PrefixCacheManagerV2;
use crate::sequence::Sequence;
use crate::utils::debug::DeviceRepr;
use crate::utils::tokenizer::get_tokenizer;
Expand Down Expand Up @@ -788,7 +788,7 @@ impl Pipeline for NormalPipeline {
&self,
seqs: &mut [&mut Sequence],
logits: Vec<Tensor>,
prefix_cacher: &mut PrefixCacheManager,
prefix_cacher: &mut PrefixCacheManagerV2,
disable_eos_stop: bool,
rng: Arc<std::sync::Mutex<Isaac64Rng>>,
) -> Result<(), candle_core::Error> {
Expand Down
6 changes: 3 additions & 3 deletions mistralrs-core/src/pipeline/sampling.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use candle_core::{DType, Device, Result, Tensor};
use rand_isaac::Isaac64Rng;

use crate::{
prefix_cacher::PrefixCacheManager,
prefix_cacher_v2::PrefixCacheManagerV2,
sampler::Logprobs,
sequence::{Sequence, SequenceRecognizer},
};
Expand All @@ -13,7 +13,7 @@ use super::Pipeline;

pub(crate) async fn finish_or_add_toks_to_seq(
this: &dyn Pipeline,
prefix_cacher: &mut PrefixCacheManager,
prefix_cacher: &mut PrefixCacheManagerV2,
seq: &mut Sequence,
logprobs: Logprobs,
eos_tok: Option<&[u32]>,
Expand Down Expand Up @@ -245,7 +245,7 @@ pub async fn sample_and_add_toks(
this: &dyn Pipeline,
seqs: &mut [&mut Sequence],
logits_seq: Vec<Tensor>,
prefix_cacher: &mut PrefixCacheManager,
prefix_cacher: &mut PrefixCacheManagerV2,
disable_eos_stop: bool,
rng: Arc<std::sync::Mutex<Isaac64Rng>>,
) -> Result<()> {
Expand Down
6 changes: 3 additions & 3 deletions mistralrs-core/src/pipeline/speculative.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use crate::{
},
AdapterInstruction,
},
prefix_cacher::PrefixCacheManager,
prefix_cacher_v2::PrefixCacheManagerV2,
sequence::{Sequence, SequenceRecognizer},
DeviceMapMetadata, Loader, ModelKind, PagedAttentionConfig, Pipeline, TokenSource,
TryIntoDType,
Expand Down Expand Up @@ -328,7 +328,7 @@ impl Pipeline for SpeculativePipeline {
&self,
_seqs: &mut [&mut Sequence],
_logits: Vec<Tensor>,
_prefix_cacher: &mut PrefixCacheManager,
_prefix_cacher: &mut PrefixCacheManagerV2,
_disable_eos_stop: bool,
_rng: Arc<std::sync::Mutex<Isaac64Rng>>,
) -> Result<()> {
Expand All @@ -339,7 +339,7 @@ impl Pipeline for SpeculativePipeline {
input_seqs: &mut [&mut Sequence],
is_prompt: bool,
_return_raw_logits: bool,
prefix_cacher: &mut PrefixCacheManager,
prefix_cacher: &mut PrefixCacheManagerV2,
disable_eos_stop: bool,
rng: Arc<Mutex<Isaac64Rng>>,
backend_metadata: CacheBackendMetadata<'_>,
Expand Down
4 changes: 2 additions & 2 deletions mistralrs-core/src/pipeline/vision.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use crate::pipeline::llg::build_tok_env;
use crate::pipeline::sampling::sample_and_add_toks;
use crate::pipeline::text_models_inputs_processor::make_prompt_chunk;
use crate::pipeline::{get_chat_template, ChatTemplate, IsqOrganization, LocalModelPaths};
use crate::prefix_cacher::PrefixCacheManager;
use crate::prefix_cacher_v2::PrefixCacheManagerV2;
use crate::sequence::Sequence;
use crate::utils::debug::DeviceRepr;
use crate::utils::tokenizer::get_tokenizer;
Expand Down Expand Up @@ -661,7 +661,7 @@ impl Pipeline for VisionPipeline {
&self,
seqs: &mut [&mut Sequence],
logits: Vec<Tensor>,
prefix_cacher: &mut PrefixCacheManager,
prefix_cacher: &mut PrefixCacheManagerV2,
disable_eos_stop: bool,
rng: Arc<std::sync::Mutex<Isaac64Rng>>,
) -> Result<(), candle_core::Error> {
Expand Down
3 changes: 3 additions & 0 deletions mistralrs-core/src/prefix_cacher.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
#![allow(dead_code, deprecated)]

use std::sync::{Arc, Mutex};

use candle_core::{Device, Result, Tensor};
Expand Down Expand Up @@ -25,6 +27,7 @@ impl From<Vec<u32>> for Tokens {

type EvictionCacheGroup = (Arc<Mutex<LayerCaches>>, Option<Arc<Mutex<LayerCaches>>>);

#[deprecated(note = "use PrefixCacheManagerV2 instead!")]
pub struct PrefixCacheManager {
caches: Trie<Tokens, Arc<Mutex<LayerCaches>>>,
xlora_caches: Option<Trie<Tokens, Arc<Mutex<LayerCaches>>>>,
Expand Down
Loading

0 comments on commit fc65371

Please sign in to comment.