diff --git a/mistralrs-core/src/device_map.rs b/mistralrs-core/src/device_map.rs
index 885ed6ec8..a19c3d930 100644
--- a/mistralrs-core/src/device_map.rs
+++ b/mistralrs-core/src/device_map.rs
@@ -148,6 +148,7 @@ pub trait DeviceMapper: Debug {
     ) -> VarBuilder<'a>;
     /// If ISQ layer, then do not change the device (return None). *They will do it later in NormalModel::quantize*
     fn device_for(&self, layer: usize, loading_isq: bool) -> Option<&Device>;
+    fn get_unique_devices(&self) -> Vec<Device>;
     /// If ISQ layer, then do not change the device (return None). *They will do it later in NormalModel::quantize*
     fn cast_nm_device(&self, x: &Tensor, loading_isq: bool) -> Result<Tensor>;
     /// Set non mapped layer device. This is for ISQ + device mapping support
@@ -186,6 +187,14 @@ impl DeviceMapper for LayerDeviceMapper {
         }
         self.mappings.get(layer)
     }
+    fn get_unique_devices(&self) -> Vec<Device> {
+        self.mappings.iter().fold(Vec::new(), |mut acc, device| {
+            if !acc.iter().any(|d| d.same_device(device)) {
+                acc.push(device.clone());
+            }
+            acc
+        })
+    }
     fn cast_nm_device(&self, x: &Tensor, loading_isq: bool) -> Result<Tensor> {
         if loading_isq {
             x.to_device(&Device::Cpu)
@@ -234,6 +243,9 @@ impl DeviceMapper for DummyDeviceMapper {
         }
         None
     }
+    fn get_unique_devices(&self) -> Vec<Device> {
+        vec![self.nm_device.clone()]
+    }
     fn cast_nm_device(&self, x: &Tensor, loading_isq: bool) -> Result<Tensor> {
         if loading_isq {
             x.to_device(&Device::Cpu)
diff --git a/mistralrs-core/src/diffusion_models/processor.rs b/mistralrs-core/src/diffusion_models/processor.rs
index 14713580e..872212703 100644
--- a/mistralrs-core/src/diffusion_models/processor.rs
+++ b/mistralrs-core/src/diffusion_models/processor.rs
@@ -6,6 +6,7 @@ use indexmap::IndexMap;
 use tokenizers::Tokenizer;
 use crate::{
+    device_map::DeviceMapper,
     pipeline::{
         text_models_inputs_processor::PagedAttentionMeta, InputProcessorOutput, InputsProcessor,
         InputsProcessorType, MessagesAction, Processor,
@@ -69,6 +70,7 @@ impl InputsProcessor for DiffusionInputsProcessor {
         _other_config: Option<Arc<dyn Any>>,
         _paged_attn_metadata: Option<PagedAttentionMeta<'_>>,
         prompt_batchsize: Option<NonZeroUsize>,
+        _mapper: Option<&dyn DeviceMapper>,
     ) -> Box<dyn Iterator<Item = anyhow::Result<InputProcessorOutput>>> {
         let mut make_value = if prompt_batchsize.is_some() {
             return Box::new(std::iter::once(Err(anyhow::Error::msg(
diff --git a/mistralrs-core/src/dummy_paged_attention/cache_engine.rs b/mistralrs-core/src/dummy_paged_attention/cache_engine.rs
index 8dab1636b..c3a064286 100644
--- a/mistralrs-core/src/dummy_paged_attention/cache_engine.rs
+++ b/mistralrs-core/src/dummy_paged_attention/cache_engine.rs
@@ -26,6 +26,7 @@ impl CacheEngine {
         _cache_config: &CacheConfig,
         _dtype: DType,
         _device: &Device,
+        _layer_devices: Vec<Option<Device>>,
     ) -> Result<Self> {
         Ok(Self {
             dummy_cache: Arc::new(Mutex::new(Vec::new())),
diff --git a/mistralrs-core/src/layers.rs b/mistralrs-core/src/layers.rs
index 5b0cd9811..8567278d5 100644
--- a/mistralrs-core/src/layers.rs
+++ b/mistralrs-core/src/layers.rs
@@ -615,6 +615,9 @@ impl Llama3RotaryEmbedding {
         k: &mut Tensor,
         b_sz: usize,
     ) -> Result<()> {
+        // Needed for device mapping
+        let positions_kernel = positions_kernel.to_device(q.device())?;
+
         match self {
             Self::Llama3 { sin, cos, is_gptx } => {
                 let (b_sz_seq_len, h, n_embd) = q.dims3()?;
@@ -646,7 +649,7 @@ impl Llama3RotaryEmbedding {
                 *k = Tensor::cat(&k_embeds, 0)?;
                 Ok(())
             }
-            Self::Default(rope) => rope.forward(positions, positions_kernel, q, k, b_sz),
+            Self::Default(rope) => rope.forward(positions, &positions_kernel, q, k, b_sz),
         }
     }
 }
@@ -935,7 +938,10 @@ impl RotaryEmbedding {
         k: &mut Tensor,
         b_sz: usize,
     ) -> Result<()> {
-        self.0.forward(positions,
positions_kernel, q, k, b_sz) + // Needed for device mapping + let positions_kernel = positions_kernel.to_device(q.device())?; + + self.0.forward(positions, &positions_kernel, q, k, b_sz) } } diff --git a/mistralrs-core/src/models/quantized_llama.rs b/mistralrs-core/src/models/quantized_llama.rs index 62f03b2bc..95d390927 100644 --- a/mistralrs-core/src/models/quantized_llama.rs +++ b/mistralrs-core/src/models/quantized_llama.rs @@ -6,13 +6,13 @@ use std::sync::Arc; use candle_core::quantized::ggml_file; use candle_core::quantized::QTensor; use candle_core::{DType, Device, Result, Tensor}; -use candle_nn::{Embedding, Module, RotaryEmbedding}; +use candle_nn::{Embedding, Module}; use mistralrs_quant::{GgufMatMul, QuantMethod, QuantMethodConfig}; use crate::attention::SdpaParams; use crate::device_map::DeviceMapper; use crate::gguf::Content; -use crate::layers::{CausalMasker, MatMul, QRmsNorm, Sdpa}; +use crate::layers::{CausalMasker, MatMul, QRmsNorm, RotaryEmbedding, Sdpa}; use crate::layers_masker::PastKvLenCache; use crate::paged_attention::{AttentionImplementation, PagedAttention}; use crate::pipeline::extract_logits; diff --git a/mistralrs-core/src/models/quantized_qwen2.rs b/mistralrs-core/src/models/quantized_qwen2.rs index f14aa592b..17f4f5ada 100644 --- a/mistralrs-core/src/models/quantized_qwen2.rs +++ b/mistralrs-core/src/models/quantized_qwen2.rs @@ -4,13 +4,13 @@ use std::collections::HashMap; use std::sync::Arc; use candle_core::{DType, Device, Result, Tensor}; -use candle_nn::{Embedding, Module, RotaryEmbedding}; +use candle_nn::{Embedding, Module}; use mistralrs_quant::{GgufMatMul, QuantMethod, QuantMethodConfig}; use crate::attention::SdpaParams; use crate::device_map::DeviceMapper; use crate::gguf::Content; -use crate::layers::{CausalMasker, MatMul, QRmsNorm, Sdpa}; +use crate::layers::{CausalMasker, MatMul, QRmsNorm, RotaryEmbedding, Sdpa}; use crate::layers_masker::PastKvLenCache; use crate::paged_attention::{AttentionImplementation, PagedAttention}; use crate::pipeline::text_models_inputs_processor::PagedAttentionInputMetadata; diff --git a/mistralrs-core/src/paged_attention/cache_engine.rs b/mistralrs-core/src/paged_attention/cache_engine.rs index 3faeb2481..d75a266ab 100644 --- a/mistralrs-core/src/paged_attention/cache_engine.rs +++ b/mistralrs-core/src/paged_attention/cache_engine.rs @@ -29,6 +29,7 @@ impl CacheEngine { cache_config: &CacheConfig, dtype: DType, device: &Device, + layer_devices: Vec>, ) -> Result { Ok(Self { gpu_cache: Arc::new(Mutex::new(Self::allocate_gpu_cache( @@ -36,6 +37,7 @@ impl CacheEngine { cache_config, dtype, device, + layer_devices, )?)), cpu_cache: Self::allocate_cpu_cache(model_config, cache_config, dtype, device)?, num_layers: model_config.num_layers(), @@ -55,13 +57,16 @@ impl CacheEngine { cache_config: &CacheConfig, dtype: DType, device: &Device, + layer_devices: Vec>, ) -> Result> { let key_block_shape = Self::calculate_key_block_shape(model_config, dtype, cache_config.block_size); let value_block_shape = Self::calculate_value_block_shape(model_config, cache_config.block_size); let mut gpu_cache = Vec::new(); - for _ in 0..model_config.num_layers() { + + for i in 0..model_config.num_layers() { + let device = layer_devices[i].as_ref().unwrap_or(device); let key_blocks = Tensor::zeros( ( cache_config.num_gpu_blocks, diff --git a/mistralrs-core/src/paged_attention/layers/paged_attention.rs b/mistralrs-core/src/paged_attention/layers/paged_attention.rs index 3c26d805d..1180e4077 100644 --- 
a/mistralrs-core/src/paged_attention/layers/paged_attention.rs +++ b/mistralrs-core/src/paged_attention/layers/paged_attention.rs @@ -61,13 +61,34 @@ impl PagedAttention { input_metadata: &mut PagedAttentionInputMetadata, softcapping: Option, ) -> Result { - let dims = input_metadata.slot_mappings.dims(); + let slot_mapping = input_metadata + .slot_mappings + .get(&query.device().location()) + .unwrap(); + let dims = slot_mapping.dims(); let slot_mapping = if dims.len() > 1 { - input_metadata - .slot_mappings - .flatten(0, input_metadata.slot_mappings.dims().len())? + &slot_mapping.flatten(0, dims.len())? } else { - input_metadata.slot_mappings.clone() + slot_mapping + }; + + let block_tables = input_metadata + .block_tables + .as_ref() + .unwrap() + .get(&query.device().location()) + .unwrap(); + let context_lens = input_metadata + .context_lens + .as_ref() + .unwrap() + .get(&query.device().location()) + .unwrap(); + + let alibi_slopes = if let Some(alibi_slopes) = self.alibi_slopes.as_ref() { + Some(alibi_slopes.to_device(query.device())?) + } else { + None }; let (batch_size, attention_heads, seq_len, head_size) = query.shape().dims4()?; @@ -80,7 +101,7 @@ impl PagedAttention { query, key, value, - Some(mask), + Some(&mask), None, &SdpaParams { n_kv_groups: self.n_kv_groups, @@ -92,7 +113,7 @@ impl PagedAttention { )?), }; - // // paged-attn expects [batch_size, num_tokens, num_heads, head_size] + // paged-attn expects [batch_size, num_tokens, num_heads, head_size] let (query, key, value) = if seq_len > 1 { let q = query .transpose(1, 2)? @@ -105,7 +126,7 @@ impl PagedAttention { .reshape(((), key_value_heads, head_size))?; (q, k, v) } else { - //avoid unnecessary transpose for decoding + // avoid unnecessary transpose for decoding let q = query.reshape(((), attention_heads, head_size))?; let k = key.reshape(((), key_value_heads, head_size))?; let v = value.reshape(((), key_value_heads, head_size))?; @@ -123,7 +144,7 @@ impl PagedAttention { &value, key_cache.as_mut().unwrap(), value_cache.as_mut().unwrap(), - &slot_mapping, + slot_mapping, )?; } @@ -131,7 +152,6 @@ impl PagedAttention { // Return result in prefill return Ok(att); } - // Args: // output: shape = [num_generation_tokens, num_heads, head_size] // @@ -147,18 +167,16 @@ impl PagedAttention { // // alibi_slopes: shape = [num_heads] #[allow(clippy::cast_possible_truncation)] - let res = paged_attention( + paged_attention( &query, key_cache.as_ref().unwrap(), value_cache.as_ref().unwrap(), - input_metadata.block_tables.as_ref().unwrap(), - input_metadata.context_lens.as_ref().unwrap(), - self.alibi_slopes.as_ref(), + block_tables, + context_lens, + alibi_slopes.as_ref(), input_metadata.max_context_len.unwrap(), self.scale, softcapping.unwrap_or(1.0f64) as f32, - )?; - - Ok(res) + ) } } diff --git a/mistralrs-core/src/pipeline/amoe.rs b/mistralrs-core/src/pipeline/amoe.rs index 4ba7961d9..65a379d20 100644 --- a/mistralrs-core/src/pipeline/amoe.rs +++ b/mistralrs-core/src/pipeline/amoe.rs @@ -19,6 +19,7 @@ use tracing::{info, warn}; use crate::{ amoe::{AnyMoeConfig, AnyMoeTrainingInputRow, AnyMoeTrainingInputs, AnyMoeTrainingResult}, + device_map::DeviceMapper, get_mut_arcmutex, prefix_cacher_v2::PrefixCacheManagerV2, sampler::Sampler, @@ -244,6 +245,9 @@ impl MetadataMixin for AnyMoePipeline { fn tokenizer(&self) -> Option> { get_mut_arcmutex!(self.target).tokenizer() } + fn device_mapper(&self) -> Option<&dyn DeviceMapper> { + None + } } #[async_trait::async_trait] @@ -469,6 +473,7 @@ impl AnyMoePipelineMixin for 
AnyMoePipeline { input_processor_cfg.clone(), None, // TODO: get block tables/handle it for PagedAttention None, // TODO: prompt chunking doesn't work. + None, ) .nth(0) .unwrap(); diff --git a/mistralrs-core/src/pipeline/diffusion.rs b/mistralrs-core/src/pipeline/diffusion.rs index d75bc7f87..3d7963759 100644 --- a/mistralrs-core/src/pipeline/diffusion.rs +++ b/mistralrs-core/src/pipeline/diffusion.rs @@ -5,6 +5,7 @@ use super::{ GeneralMetadata, IsqPipelineMixin, Loader, MetadataMixin, ModelCategory, ModelKind, ModelPaths, PreProcessingMixin, Processor, TokenSource, }; +use crate::device_map::DeviceMapper; use crate::diffusion_models::processor::{DiffusionProcessor, ModelInputs}; use crate::paged_attention::AttentionImplementation; use crate::pipeline::ChatTemplate; @@ -296,6 +297,9 @@ impl MetadataMixin for DiffusionPipeline { fn tokenizer(&self) -> Option> { None } + fn device_mapper(&self) -> Option<&dyn DeviceMapper> { + None + } } #[async_trait::async_trait] diff --git a/mistralrs-core/src/pipeline/ggml.rs b/mistralrs-core/src/pipeline/ggml.rs index 7e607761e..72efeaeff 100644 --- a/mistralrs-core/src/pipeline/ggml.rs +++ b/mistralrs-core/src/pipeline/ggml.rs @@ -9,6 +9,7 @@ use super::{ AdapterActivationMixin, AnyMoePipelineMixin, CacheManagerMixin, EitherCache, ForwardInputsResult, IsqPipelineMixin, MetadataMixin, ModelCategory, PreProcessingMixin, }; +use crate::device_map::DeviceMapper; use crate::lora::Ordering; use crate::pipeline::chat_template::{calculate_eos_tokens, GenerationConfig}; use crate::pipeline::get_chat_template; @@ -529,6 +530,9 @@ impl MetadataMixin for GGMLPipeline { fn get_metadata(&self) -> Arc { self.metadata.clone() } + fn device_mapper(&self) -> Option<&dyn DeviceMapper> { + None + } } #[async_trait::async_trait] diff --git a/mistralrs-core/src/pipeline/gguf.rs b/mistralrs-core/src/pipeline/gguf.rs index 279b795e4..169c5202f 100644 --- a/mistralrs-core/src/pipeline/gguf.rs +++ b/mistralrs-core/src/pipeline/gguf.rs @@ -9,6 +9,7 @@ use super::{ AdapterActivationMixin, AnyMoePipelineMixin, CacheManagerMixin, EitherCache, ForwardInputsResult, IsqPipelineMixin, MetadataMixin, ModelCategory, PreProcessingMixin, }; +use crate::device_map::DeviceMapper; use crate::gguf::{ get_gguf_chat_template, {convert_gguf_to_hf_tokenizer, GgufTokenizerConversion}, }; @@ -74,6 +75,7 @@ pub struct GGUFPipeline { model_id: String, non_granular_state: Option, metadata: Arc, + mapper: Box, } /// Loader for a GGUF model. 
@@ -332,7 +334,7 @@ impl Loader for GGUFLoader { silent: bool, mapper: DeviceMapMetadata, in_situ_quant: Option, - mut paged_attn_config: Option, + paged_attn_config: Option, ) -> Result>> { if in_situ_quant.is_some() { anyhow::bail!( @@ -353,9 +355,6 @@ impl Loader for GGUFLoader { self.get_id(), device.device_pretty_repr() ); - } else if paged_attn_config.is_some() { - warn!("Device mapping or device topology and PagedAttention are incompatible, disabling PagedAttention."); - paged_attn_config = None; } let mut readers = Vec::new(); @@ -409,7 +408,7 @@ impl Loader for GGUFLoader { // Base config (quantization only): let quant = ModelConfig::ParamsGGUF( model, - (device, mapper, self.config.topology.as_ref()).into(), + (device, mapper.clone(), self.config.topology.as_ref()).into(), if paged_attn_config.is_some() { AttentionImplementation::PagedAttention } else { @@ -455,6 +454,24 @@ impl Loader for GGUFLoader { _ => unreachable!(), }; + let num_hidden_layers = match model { + Model::Llama(ref model) => model.cache.normal().0.len(), + Model::Phi2(ref model) => model.cache.normal().0.len(), + Model::XLoraLlama(ref model) => model.cache.full().lock().len(), + Model::Phi3(ref model) => model.cache.normal().0.len(), + Model::XLoraPhi3(ref model) => model.cache.full().lock().len(), + Model::Starcoder2(ref model) => model.cache.normal().0.len(), + Model::Qwen2(ref model) => model.cache.normal().0.len(), + }; + + let mapper = + mapper.into_mapper(num_hidden_layers, device, self.config.topology.as_ref())?; + let mut layer_devices = Vec::new(); + for layer in 0..num_hidden_layers { + let device = mapper.device_for(layer, false).cloned(); + layer_devices.push(device); + } + let (cache_config, cache_engine) = if let Some(paged_attn_config) = paged_attn_config { let model_config: &dyn ModelConfigLike = &model_config_metadata; let cache_config = calculate_cache_config( @@ -465,8 +482,13 @@ impl Loader for GGUFLoader { model_config, device, )?; - let cache_engine = - CacheEngine::new(model_config, &cache_config, internal_dtype, device)?; + let cache_engine = CacheEngine::new( + model_config, + &cache_config, + internal_dtype, + device, + layer_devices, + )?; (Some(cache_config), Some(cache_engine)) } else { (None, None) @@ -548,6 +570,7 @@ impl Loader for GGUFLoader { prompt_batchsize: self.config.prompt_batchsize, model_metadata: Some(Arc::new(model_config_metadata)), }), + mapper, }))) } @@ -675,6 +698,9 @@ impl MetadataMixin for GGUFPipeline { fn get_metadata(&self) -> Arc { self.metadata.clone() } + fn device_mapper(&self) -> Option<&dyn DeviceMapper> { + Some(&*self.mapper) + } } #[async_trait::async_trait] diff --git a/mistralrs-core/src/pipeline/inputs_processor.rs b/mistralrs-core/src/pipeline/inputs_processor.rs index 3af61ccd5..aeb50146b 100644 --- a/mistralrs-core/src/pipeline/inputs_processor.rs +++ b/mistralrs-core/src/pipeline/inputs_processor.rs @@ -7,7 +7,7 @@ use candle_core::Device; use text_models_inputs_processor::PagedAttentionMeta; use tokenizers::Tokenizer; -use crate::sequence::Sequence; +use crate::{device_map::DeviceMapper, sequence::Sequence}; #[derive(PartialEq)] pub enum InputsProcessorType { @@ -40,6 +40,7 @@ pub trait InputsProcessor { other_config: Option>, paged_attn_metadata: Option>, prompt_batchsize: Option, + mapper: Option<&dyn DeviceMapper>, ) -> Box>>; fn get_type(&self) -> InputsProcessorType; @@ -48,13 +49,16 @@ pub trait InputsProcessor { // ========================= Test models input processor pub mod text_models_inputs_processor { - use std::{any::Any, 
fmt::Debug, iter::repeat, num::NonZeroUsize, sync::Arc}; + use std::{ + any::Any, collections::HashMap, fmt::Debug, iter::repeat, num::NonZeroUsize, sync::Arc, + }; use anyhow::Result; - use candle_core::{DType, Device, Tensor, WithDType}; + use candle_core::{DType, Device, DeviceLocation, Tensor, WithDType}; use tokenizers::Tokenizer; use crate::{ + device_map::DeviceMapper, layers::set_use_matmul_via_f16, paged_attention::{BlockEngine, _PAD_SLOT_ID}, sequence::Sequence, @@ -89,9 +93,9 @@ pub mod text_models_inputs_processor { #[derive(Clone, Debug)] #[allow(dead_code)] pub struct PagedAttentionInputMetadata { - pub block_tables: Option, - pub context_lens: Option, - pub slot_mappings: Tensor, + pub block_tables: Option>, + pub context_lens: Option>, + pub slot_mappings: HashMap, pub max_context_len: Option, } @@ -103,7 +107,7 @@ pub mod text_models_inputs_processor { block_tables: None, context_lens: None, max_context_len: None, - slot_mappings: Tensor::new(&[0f32], dev)?, + slot_mappings: HashMap::from([(dev.location(), Tensor::new(&[0f32], dev)?)]), }) } } @@ -133,6 +137,7 @@ pub mod text_models_inputs_processor { // chunk_offset_toks is the number of tokens by which the tokens are offset, // chunk_offset_toks / prompt_batchsize = number of batches + #[allow(clippy::too_many_arguments)] pub fn make_prompt_chunk( chunk_offset_toks: usize, toks: Vec>, @@ -141,6 +146,7 @@ pub mod text_models_inputs_processor { last_n_context_len: Option<(usize, usize)>, return_raw_logits: bool, mut paged_attn_metadata: Option<&mut PagedAttentionMeta<'_>>, + mapper: Option<&dyn DeviceMapper>, ) -> Result { let max_len = toks .iter() @@ -311,10 +317,25 @@ pub mod text_models_inputs_processor { )? .reshape(((),))?; + // For device mapping, make a copy of each tensor for each device + let devices = mapper.unwrap().get_unique_devices(); + let mut slot_mappings_map = HashMap::new(); + let mut block_tables_map = HashMap::new(); + let mut context_lens_map = HashMap::new(); + + for device in devices { + slot_mappings_map + .insert(device.location(), slot_mappings.clone().to_device(&device)?); + block_tables_map + .insert(device.location(), block_tables.clone().to_device(&device)?); + context_lens_map + .insert(device.location(), context_lens.clone().to_device(&device)?); + } + Some(PagedAttentionInputMetadata { - slot_mappings, - block_tables: Some(block_tables), - context_lens: Some(context_lens), + slot_mappings: slot_mappings_map, + block_tables: Some(block_tables_map), + context_lens: Some(context_lens_map), max_context_len: Some(max_context_len), }) } else { @@ -342,6 +363,7 @@ pub mod text_models_inputs_processor { input_seqs: &[&mut Sequence], device: &Device, mut paged_attn_metadata: Option<&mut PagedAttentionMeta<'_>>, + mapper: Option<&dyn DeviceMapper>, ) -> Result { // Pad each sequence by the padding token to the max len. 
let mut seqs_tensors = Vec::new(); @@ -459,10 +481,25 @@ pub mod text_models_inputs_processor { device, )?; + // For device mapping, make a copy of each tensor for each device + let devices = mapper.unwrap().get_unique_devices(); + let mut slot_mappings_map = HashMap::new(); + let mut block_tables_map = HashMap::new(); + let mut context_lens_map = HashMap::new(); + + for device in devices { + slot_mappings_map + .insert(device.location(), slot_mappings.clone().to_device(&device)?); + block_tables_map + .insert(device.location(), block_tables.clone().to_device(&device)?); + context_lens_map + .insert(device.location(), context_lens.clone().to_device(&device)?); + } + Some(PagedAttentionInputMetadata { - slot_mappings, - block_tables: Some(block_tables), - context_lens: Some(context_lens), + slot_mappings: slot_mappings_map, + block_tables: Some(block_tables_map), + context_lens: Some(context_lens_map), max_context_len: Some(*max_context_len), }) } else { @@ -485,6 +522,7 @@ pub mod text_models_inputs_processor { }) } + #[allow(clippy::too_many_arguments)] pub(crate) fn get_prompt_input( toks: Vec>, input_seqs: &[&mut Sequence], @@ -493,6 +531,7 @@ pub mod text_models_inputs_processor { return_raw_logits: bool, mut paged_attn_metadata: Option<&mut PagedAttentionMeta<'_>>, prompt_batchsize: Option, + mapper: Option<&dyn DeviceMapper>, ) -> Box>> { if let (Some(prompt_batchsize), true) = (prompt_batchsize, paged_attn_metadata.is_none()) { let mut seq_chunks = Vec::new(); @@ -531,6 +570,7 @@ pub mod text_models_inputs_processor { last_n_context_len, return_raw_logits, paged_attn_metadata.as_deref_mut(), + mapper, ) .map(|inputs| InnerInputProcessorOutput { inputs, @@ -561,6 +601,7 @@ pub mod text_models_inputs_processor { last_n_context_len, return_raw_logits, paged_attn_metadata, + mapper, ) .map(|inputs| InnerInputProcessorOutput { inputs, @@ -580,6 +621,7 @@ pub mod text_models_inputs_processor { return_raw_logits: bool, paged_attn_metadata: Option<&mut PagedAttentionMeta<'_>>, prompt_batchsize: Option, + mapper: Option<&dyn DeviceMapper>, ) -> Box>> { if no_kv_cache { return get_prompt_input( @@ -590,16 +632,17 @@ pub mod text_models_inputs_processor { return_raw_logits, paged_attn_metadata, prompt_batchsize, + mapper, ); } Box::new(std::iter::once( - make_completion_chunk(toks, input_seqs, device, paged_attn_metadata).map(|inputs| { - InnerInputProcessorOutput { + make_completion_chunk(toks, input_seqs, device, paged_attn_metadata, mapper).map( + |inputs| InnerInputProcessorOutput { inputs, seq_indices: (0..input_seqs.len()).collect(), - } - }), + }, + ), )) } @@ -634,6 +677,7 @@ pub mod text_models_inputs_processor { _: Option>, mut paged_attn_metadata: Option>, prompt_batchsize: Option, + mapper: Option<&dyn DeviceMapper>, ) -> Box>> { if is_xlora && !is_prompt { Box::new( @@ -648,6 +692,7 @@ pub mod text_models_inputs_processor { return_raw_logits, paged_attn_metadata.as_mut(), prompt_batchsize, + mapper, ) .zip(get_completion_input( input_seqs @@ -661,6 +706,7 @@ pub mod text_models_inputs_processor { return_raw_logits, paged_attn_metadata.as_mut(), prompt_batchsize, + mapper, )) .map(|(prompt, completion)| { let InnerInputProcessorOutput { @@ -721,6 +767,7 @@ pub mod text_models_inputs_processor { return_raw_logits, paged_attn_metadata.as_mut(), prompt_batchsize, + mapper, ) .map(|metadata| { let InnerInputProcessorOutput { @@ -768,6 +815,7 @@ pub mod text_models_inputs_processor { return_raw_logits, paged_attn_metadata.as_mut(), prompt_batchsize, + mapper, ) .map(|metadata| { let 
InnerInputProcessorOutput { @@ -816,6 +864,7 @@ pub mod text_models_inputs_processor { return_raw_logits, paged_attn_metadata.as_mut(), prompt_batchsize, + mapper, ) .map(|metadata| { let InnerInputProcessorOutput { diff --git a/mistralrs-core/src/pipeline/mod.rs b/mistralrs-core/src/pipeline/mod.rs index 9bb037845..6f93f005c 100644 --- a/mistralrs-core/src/pipeline/mod.rs +++ b/mistralrs-core/src/pipeline/mod.rs @@ -19,6 +19,7 @@ mod vision; pub use super::diffusion_models::DiffusionGenerationParams; use crate::amoe::{AnyMoeConfig, AnyMoeExpertType, AnyMoeTrainingInputs, AnyMoeTrainingResult}; +use crate::device_map::DeviceMapper; use crate::paged_attention::{CacheConfig, CacheEngine, ModelConfigLike}; use crate::prefix_cacher_v2::PrefixCacheManagerV2; pub use amoe::{AnyMoeLoader, AnyMoePipeline}; @@ -148,6 +149,7 @@ pub trait MetadataMixin { fn name(&self) -> String; fn reset_non_granular_state(&self); fn get_metadata(&self) -> Arc; + fn device_mapper(&self) -> Option<&dyn DeviceMapper>; } /// Implemented by the base model of an AnyMoe. @@ -326,6 +328,7 @@ pub trait Pipeline: self.get_input_processor_config(), None, self.get_metadata().prompt_batchsize, + self.device_mapper(), ); let mut logits = vec![None; input_seqs.len()]; @@ -538,6 +541,7 @@ pub trait Pipeline: self.get_input_processor_config(), Some(metadata), self.get_metadata().prompt_batchsize, + self.device_mapper(), ); let mut logits = vec![None; input_seqs.len()]; diff --git a/mistralrs-core/src/pipeline/normal.rs b/mistralrs-core/src/pipeline/normal.rs index 3ab6630e1..c09618b69 100644 --- a/mistralrs-core/src/pipeline/normal.rs +++ b/mistralrs-core/src/pipeline/normal.rs @@ -16,6 +16,7 @@ use super::{ NormalLoaderType, Phi2Loader, Phi3Loader, Phi3_5MoELoader, Qwen2Loader, Starcoder2Loader, }; use crate::amoe::AnyMoeExpertType; +use crate::device_map::DeviceMapper; use crate::lora::Ordering; use crate::paged_attention::{calculate_cache_config, AttentionImplementation, CacheEngine}; use crate::pipeline::chat_template::{calculate_eos_tokens, GenerationConfig}; @@ -68,6 +69,7 @@ pub struct NormalPipeline { generation_config: Option, config: String, imatrix: Option, + mapper: Box, } /// A loader for a "normal" (non-quantized) model. @@ -271,7 +273,7 @@ impl Loader for NormalLoader { silent: bool, mapper: DeviceMapMetadata, in_situ_quant: Option, - mut paged_attn_config: Option, + paged_attn_config: Option, ) -> Result>> { let config = std::fs::read_to_string(paths.get_config_filename())?; // Otherwise, the device mapper will print it @@ -288,16 +290,22 @@ impl Loader for NormalLoader { self.get_id(), device.device_pretty_repr() ); - } else if paged_attn_config.is_some() { - warn!("Device mapping or device topology and PagedAttention are incompatible, disabling PagedAttention."); - paged_attn_config = None; } - + let pipeline_mapper = mapper.into_mapper( + self.inner.get_total_device_mapping_num_layers(&config)?, + device, + self.config.topology.as_ref(), + )?; let mapper = mapper.into_mapper( self.inner.get_total_device_mapping_num_layers(&config)?, device, self.config.topology.as_ref(), )?; + let mut layer_devices = Vec::new(); + for layer in 0..self.inner.get_total_device_mapping_num_layers(&config)? 
{ + let device = mapper.device_for(layer, false).cloned(); + layer_devices.push(device); + } let dtype = mapper.get_min_dtype(dtype)?; info!( @@ -431,8 +439,16 @@ impl Loader for NormalLoader { let chunk_len = chunk.len(); let start = Instant::now(); - let inputs = - make_prompt_chunk(0, vec![chunk], &[0], &load_device, None, false, None)?; + let inputs = make_prompt_chunk( + 0, + vec![chunk], + &[0], + &load_device, + None, + false, + None, + Some(pipeline_mapper.as_ref()), + )?; let _ = model.forward( &inputs.input, &inputs.positions, @@ -523,7 +539,8 @@ impl Loader for NormalLoader { model.config(), device, )?; - let cache_engine = CacheEngine::new(model.config(), &cache_config, dtype, device)?; + let cache_engine = + CacheEngine::new(model.config(), &cache_config, dtype, device, layer_devices)?; (Some(cache_config), Some(cache_engine)) } else { (None, None) @@ -572,6 +589,7 @@ impl Loader for NormalLoader { generation_config: paths.get_gen_conf_filename().cloned(), config, imatrix: self.config.imatrix.clone(), + mapper: pipeline_mapper, }))) } @@ -689,6 +707,9 @@ impl MetadataMixin for NormalPipeline { fn get_metadata(&self) -> Arc { self.metadata.clone() } + fn device_mapper(&self) -> Option<&dyn DeviceMapper> { + Some(&*self.mapper) + } } #[async_trait::async_trait] diff --git a/mistralrs-core/src/pipeline/speculative.rs b/mistralrs-core/src/pipeline/speculative.rs index 7d2ba79da..9a3aeb340 100644 --- a/mistralrs-core/src/pipeline/speculative.rs +++ b/mistralrs-core/src/pipeline/speculative.rs @@ -13,6 +13,7 @@ use tokenizers::Tokenizer; use tracing::warn; use crate::{ + device_map::DeviceMapper, get_mut_arcmutex, pipeline::{ sampling::{ @@ -313,6 +314,9 @@ impl MetadataMixin for SpeculativePipeline { fn get_metadata(&self) -> Arc { self.metadata.clone() } + fn device_mapper(&self) -> Option<&dyn DeviceMapper> { + None + } } #[async_trait::async_trait] @@ -431,6 +435,7 @@ impl Pipeline for SpeculativePipeline { None, None, // TODO: get block tables/handle it None, // TODO: do we support??? + None, // TODO: device mapping ) .nth(0) .unwrap() @@ -505,6 +510,7 @@ impl Pipeline for SpeculativePipeline { None, None, // TODO: get block tables/handle it None, // TODO: do we support??? + None, // TODO: device mapping ) .nth(0) .unwrap() diff --git a/mistralrs-core/src/pipeline/vision.rs b/mistralrs-core/src/pipeline/vision.rs index c1c85cc16..1b2d63e71 100644 --- a/mistralrs-core/src/pipeline/vision.rs +++ b/mistralrs-core/src/pipeline/vision.rs @@ -11,6 +11,7 @@ use super::{ use super::{ Idefics2Loader, Idefics3Loader, LLaVALoader, LLaVANextLoader, Phi3VLoader, VisionLoaderType, }; +use crate::device_map::DeviceMapper; use crate::paged_attention::{calculate_cache_config, AttentionImplementation, CacheEngine}; use crate::pipeline::chat_template::{calculate_eos_tokens, GenerationConfig}; use crate::pipeline::llg::build_tok_env; @@ -58,6 +59,7 @@ pub struct VisionPipeline { topology: Option, silent: bool, prefixer: Arc, + mapper: Box, // For full UQFF serialization template_filename: Option, @@ -229,12 +231,21 @@ impl Loader for VisionLoader { self.inner .get_config_repr(&config, self.config.use_flash_attn)? 
); - + let pipeline_mapper = mapper.into_mapper( + self.inner.get_total_device_mapping_num_layers(&config)?, + device, + self.config.topology.as_ref(), + )?; let mapper = mapper.into_mapper( self.inner.get_total_device_mapping_num_layers(&config)?, device, self.config.topology.as_ref(), )?; + let mut layer_devices = Vec::new(); + for layer in 0..self.inner.get_total_device_mapping_num_layers(&config)? { + let device = mapper.device_for(layer, false).cloned(); + layer_devices.push(device); + } let dtype = mapper.get_min_dtype(dtype)?; let mut loading_isq = in_situ_quant.is_some() || self.config.from_uqff.is_some(); @@ -349,7 +360,7 @@ impl Loader for VisionLoader { let start = Instant::now(); let inputs = - make_prompt_chunk(0, vec![chunk], &[0], &load_device, None, false, None)?; + make_prompt_chunk(0, vec![chunk], &[0], &load_device, None, false, None, None)?; let _ = model.forward( &inputs.input, None, // NOTE: We ONLY calibrate the text bits of these models!! @@ -435,7 +446,8 @@ impl Loader for VisionLoader { model.config(), device, )?; - let cache_engine = CacheEngine::new(model.config(), &cache_config, dtype, device)?; + let cache_engine = + CacheEngine::new(model.config(), &cache_config, dtype, device, layer_devices)?; (Some(cache_config), Some(cache_engine)) } else { (None, None) @@ -480,6 +492,7 @@ impl Loader for VisionLoader { config, processor_filename: paths.get_processor_config().clone(), preprocessor_filename: paths.get_preprocessor_config().clone(), + mapper: pipeline_mapper, }))) } @@ -590,6 +603,9 @@ impl MetadataMixin for VisionPipeline { fn tokenizer(&self) -> Option> { Some(self.tokenizer.clone()) } + fn device_mapper(&self) -> Option<&dyn DeviceMapper> { + Some(&*self.mapper) + } } #[async_trait::async_trait] diff --git a/mistralrs-core/src/vision_models/idefics2/idefics2_input_processor.rs b/mistralrs-core/src/vision_models/idefics2/idefics2_input_processor.rs index e18de7f5d..f016b443c 100644 --- a/mistralrs-core/src/vision_models/idefics2/idefics2_input_processor.rs +++ b/mistralrs-core/src/vision_models/idefics2/idefics2_input_processor.rs @@ -10,6 +10,7 @@ use tokenizers::Tokenizer; use tracing::warn; use crate::{ + device_map::DeviceMapper, pipeline::{ apply_chat_template, text_models_inputs_processor::{ @@ -141,6 +142,7 @@ impl InputsProcessor for Idefics2ImageProcessor { other_config: Option>, mut paged_attn_metadata: Option>, prompt_batchsize: Option, + _mapper: Option<&dyn DeviceMapper>, ) -> Box>> { if is_xlora { return Box::new(std::iter::once(Err(anyhow::Error::msg( @@ -181,6 +183,7 @@ impl InputsProcessor for Idefics2ImageProcessor { return_raw_logits, paged_attn_metadata.as_mut(), None, // TODO: evaluate if it is possible to batch this + None, ) .nth(0) .unwrap() @@ -198,6 +201,7 @@ impl InputsProcessor for Idefics2ImageProcessor { return_raw_logits, paged_attn_metadata.as_mut(), None, // TODO: evaluate if it is possible to batch this + None, ) .nth(0) .unwrap() diff --git a/mistralrs-core/src/vision_models/idefics3/inputs_processor.rs b/mistralrs-core/src/vision_models/idefics3/inputs_processor.rs index 129685063..d40ea4c7c 100644 --- a/mistralrs-core/src/vision_models/idefics3/inputs_processor.rs +++ b/mistralrs-core/src/vision_models/idefics3/inputs_processor.rs @@ -9,6 +9,7 @@ use tokenizers::Tokenizer; use tracing::warn; use crate::{ + device_map::DeviceMapper, pipeline::{ text_models_inputs_processor::{ self, get_completion_input, get_prompt_input, PagedAttentionMeta, @@ -112,6 +113,7 @@ impl InputsProcessor for Idefics3ImageProcessor { 
other_config: Option>, mut paged_attn_metadata: Option>, prompt_batchsize: Option, + _mapper: Option<&dyn DeviceMapper>, ) -> Box>> { if is_xlora { return Box::new(std::iter::once(Err(anyhow::Error::msg( @@ -157,6 +159,7 @@ impl InputsProcessor for Idefics3ImageProcessor { return_raw_logits, paged_attn_metadata.as_mut(), None, // TODO: evaluate if it is possible to batch this + None, ) .nth(0) .unwrap() @@ -174,6 +177,7 @@ impl InputsProcessor for Idefics3ImageProcessor { return_raw_logits, paged_attn_metadata.as_mut(), None, // TODO: evaluate if it is possible to batch this + None, ) .nth(0) .unwrap() diff --git a/mistralrs-core/src/vision_models/llava/llava_inputs_processor.rs b/mistralrs-core/src/vision_models/llava/llava_inputs_processor.rs index 5820ae736..e9a13bb5b 100644 --- a/mistralrs-core/src/vision_models/llava/llava_inputs_processor.rs +++ b/mistralrs-core/src/vision_models/llava/llava_inputs_processor.rs @@ -14,6 +14,7 @@ use tracing::warn; use super::llava15::LLaVAVisionSpecificArgs; use super::utils::{expand2square, LLaVAImageProcessor}; +use crate::device_map::DeviceMapper; use crate::pipeline::text_models_inputs_processor::{ get_completion_input, get_prompt_input, PagedAttentionMeta, }; @@ -87,6 +88,7 @@ impl InputsProcessor for LLaVAInputProcessor { other_config: Option>, mut paged_attn_metadata: Option>, prompt_batchsize: Option, + _mapper: Option<&dyn DeviceMapper>, ) -> Box>> { if is_xlora { return Box::new(std::iter::once(Err(anyhow::Error::msg( @@ -167,6 +169,7 @@ impl InputsProcessor for LLaVAInputProcessor { other_config, paged_attn_metadata, None, // TODO + None, ) .map(|metadata| { let InputProcessorOutput { @@ -283,6 +286,7 @@ impl InputsProcessor for LLaVAInputProcessor { return_raw_logits, paged_attn_metadata.as_mut(), None, // TODO: evaluate if it is possible to batch this + None, ) } else { get_completion_input( @@ -294,6 +298,7 @@ impl InputsProcessor for LLaVAInputProcessor { return_raw_logits, paged_attn_metadata.as_mut(), None, // TODO: evaluate if it is possible to batch this + None, ) }; diff --git a/mistralrs-core/src/vision_models/llava/llava_next_inputs_processor.rs b/mistralrs-core/src/vision_models/llava/llava_next_inputs_processor.rs index 833a928b5..34c906994 100644 --- a/mistralrs-core/src/vision_models/llava/llava_next_inputs_processor.rs +++ b/mistralrs-core/src/vision_models/llava/llava_next_inputs_processor.rs @@ -11,6 +11,7 @@ use regex_automata::meta::Regex; use tokenizers::Tokenizer; use tracing::warn; +use crate::device_map::DeviceMapper; use crate::pipeline::text_models_inputs_processor::{ get_completion_input, get_prompt_input, PagedAttentionMeta, }; @@ -94,6 +95,7 @@ impl InputsProcessor for LLaVANextInputProcessor { other_config: Option>, mut paged_attn_metadata: Option>, prompt_batchsize: Option, + _mapper: Option<&dyn DeviceMapper>, ) -> Box>> { if is_xlora { return Box::new(std::iter::once(Err(anyhow::Error::msg( @@ -193,6 +195,7 @@ impl InputsProcessor for LLaVANextInputProcessor { other_config, paged_attn_metadata, None, // TODO + None, ) .map(|metadata| { let InputProcessorOutput { @@ -327,6 +330,7 @@ impl InputsProcessor for LLaVANextInputProcessor { return_raw_logits, paged_attn_metadata.as_mut(), None, // TODO: evaluate if it is possible to batch this + None, ) } else { get_completion_input( @@ -338,6 +342,7 @@ impl InputsProcessor for LLaVANextInputProcessor { return_raw_logits, paged_attn_metadata.as_mut(), None, // TODO: evaluate if it is possible to batch this + None, ) }; diff --git 
a/mistralrs-core/src/vision_models/mllama/inputs_processor.rs b/mistralrs-core/src/vision_models/mllama/inputs_processor.rs index d97601ed2..af75f5a17 100644 --- a/mistralrs-core/src/vision_models/mllama/inputs_processor.rs +++ b/mistralrs-core/src/vision_models/mllama/inputs_processor.rs @@ -18,6 +18,7 @@ use tokenizers::Tokenizer; use tracing::warn; use crate::{ + device_map::DeviceMapper, pipeline::{ text_models_inputs_processor::{ self, get_completion_input, get_prompt_input, PagedAttentionMeta, @@ -181,6 +182,7 @@ impl InputsProcessor for MLlamaImageProcessor { other_config: Option>, mut paged_attn_metadata: Option>, prompt_batchsize: Option, + _mapper: Option<&dyn DeviceMapper>, ) -> Box>> { if is_xlora { return Box::new(std::iter::once(Err(anyhow::Error::msg( @@ -226,6 +228,7 @@ impl InputsProcessor for MLlamaImageProcessor { return_raw_logits, paged_attn_metadata.as_mut(), None, // TODO: evaluate if it is possible to batch this + None, ) .nth(0) .unwrap() @@ -243,6 +246,7 @@ impl InputsProcessor for MLlamaImageProcessor { return_raw_logits, paged_attn_metadata.as_mut(), None, // TODO: evaluate if it is possible to batch this + None, ) .nth(0) .unwrap() diff --git a/mistralrs-core/src/vision_models/phi3/phi3_inputs_processor.rs b/mistralrs-core/src/vision_models/phi3/phi3_inputs_processor.rs index e31c6c74b..d008e5f1d 100644 --- a/mistralrs-core/src/vision_models/phi3/phi3_inputs_processor.rs +++ b/mistralrs-core/src/vision_models/phi3/phi3_inputs_processor.rs @@ -11,6 +11,7 @@ use tokenizers::Tokenizer; use tracing::warn; use crate::{ + device_map::DeviceMapper, pipeline::{ text_models_inputs_processor::{ self, get_completion_input, get_prompt_input, PagedAttentionMeta, @@ -81,6 +82,7 @@ impl InputsProcessor for Phi3InputsProcessor { other_config: Option>, mut paged_attn_metadata: Option>, prompt_batchsize: Option, + _mapper: Option<&dyn DeviceMapper>, ) -> Box>> { if is_xlora { return Box::new(std::iter::once(Err(anyhow::Error::msg( @@ -169,6 +171,7 @@ impl InputsProcessor for Phi3InputsProcessor { other_config, paged_attn_metadata, None, // TODO + None, ) .map(|metadata| { let InputProcessorOutput { @@ -320,6 +323,7 @@ impl InputsProcessor for Phi3InputsProcessor { return_raw_logits, paged_attn_metadata.as_mut(), None, // TODO: evaluate if it is possible to batch this + None, ) } else { get_completion_input( @@ -331,6 +335,7 @@ impl InputsProcessor for Phi3InputsProcessor { return_raw_logits, paged_attn_metadata.as_mut(), None, // TODO: evaluate if it is possible to batch this + None, ) }; diff --git a/mistralrs-core/src/vision_models/qwen2vl/inputs_processor.rs b/mistralrs-core/src/vision_models/qwen2vl/inputs_processor.rs index 513a1c862..664ec9f18 100644 --- a/mistralrs-core/src/vision_models/qwen2vl/inputs_processor.rs +++ b/mistralrs-core/src/vision_models/qwen2vl/inputs_processor.rs @@ -14,6 +14,7 @@ use tokenizers::Tokenizer; use tracing::warn; use crate::{ + device_map::DeviceMapper, pipeline::{ text_models_inputs_processor::{ self, get_completion_input, get_prompt_input, PagedAttentionMeta, @@ -133,6 +134,7 @@ impl InputsProcessor for Qwen2VLImageProcessor { other_config: Option>, mut paged_attn_metadata: Option>, prompt_batchsize: Option, + _mapper: Option<&dyn DeviceMapper>, ) -> Box>> { if is_xlora { return Box::new(std::iter::once(Err(anyhow::Error::msg( @@ -183,6 +185,7 @@ impl InputsProcessor for Qwen2VLImageProcessor { return_raw_logits, paged_attn_metadata.as_mut(), None, // TODO: evaluate if it is possible to batch this + None, ) .nth(0) .unwrap() @@ -200,6 
+203,7 @@ impl InputsProcessor for Qwen2VLImageProcessor { return_raw_logits, paged_attn_metadata.as_mut(), None, // TODO: evaluate if it is possible to batch this + None, ) .nth(0) .unwrap()
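A few notes on the main changes above, each with a small standalone sketch. First, the new `DeviceMapper::get_unique_devices` in `device_map.rs` deduplicates the per-layer mapping while preserving first-seen order, comparing with `Device::same_device` rather than `Eq`. A minimal sketch of that fold, using a stand-in `Device` type instead of `candle_core::Device` so it runs on its own (names here are illustrative, not part of the crate):

```rust
#[derive(Clone, Debug, PartialEq)]
enum Device {
    Cpu,
    Cuda(usize), // stand-in for a CUDA ordinal
}

impl Device {
    // Stand-in for candle's `Device::same_device`.
    fn same_device(&self, other: &Device) -> bool {
        self == other
    }
}

// Same fold as `LayerDeviceMapper::get_unique_devices`: keep the first
// occurrence of each device, preserving layer order.
fn unique_devices(mappings: &[Device]) -> Vec<Device> {
    mappings.iter().fold(Vec::new(), |mut acc, device| {
        if !acc.iter().any(|d| d.same_device(device)) {
            acc.push(device.clone());
        }
        acc
    })
}

fn main() {
    let per_layer = [Device::Cuda(0), Device::Cuda(0), Device::Cuda(1), Device::Cpu];
    assert_eq!(
        unique_devices(&per_layer),
        [Device::Cuda(0), Device::Cuda(1), Device::Cpu]
    );
}
```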
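The `layers.rs` changes move `positions_kernel` onto the query's device before applying RoPE, so position indices built on the pipeline's main device still work for layers mapped to another GPU. A sketch of that pattern with plain `candle_core` calls, CPU-only so it runs anywhere; the function name is illustrative:

```rust
use candle_core::{DType, Device, Result, Tensor};

// Move the position indices to whatever device the activations live on
// before using them, mirroring the `to_device(q.device())` calls in the diff.
fn rope_positions_on_query_device(positions_kernel: &Tensor, q: &Tensor) -> Result<Tensor> {
    // `to_device` is effectively a no-op when the tensor is already there.
    positions_kernel.to_device(q.device())
}

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let q = Tensor::zeros((1, 8, 4, 16), DType::F32, &dev)?;
    let positions = Tensor::new(&[0u32, 1, 2, 3], &dev)?;
    let moved = rope_positions_on_query_device(&positions, &q)?;
    assert_eq!(moved.device().location(), q.device().location());
    Ok(())
}
```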
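In `CacheEngine::allocate_gpu_cache`, each layer's key/value blocks are now allocated on `layer_devices[i]`, falling back to the engine's default device when a layer has no explicit mapping. A small sketch of that per-layer fallback on CPU; the shapes and helper name are made up for illustration:

```rust
use candle_core::{DType, Device, Result, Tensor};

// One (key, value) block pair per layer, allocated on that layer's device.
fn allocate_layer_blocks(
    layer_devices: &[Option<Device>],
    default_device: &Device,
    block_shape: (usize, usize, usize),
) -> Result<Vec<(Tensor, Tensor)>> {
    let mut cache = Vec::with_capacity(layer_devices.len());
    for layer_device in layer_devices {
        // Same fallback as the diff: unmapped layers use the default device.
        let device = layer_device.as_ref().unwrap_or(default_device);
        let key_blocks = Tensor::zeros(block_shape, DType::F32, device)?;
        let value_blocks = Tensor::zeros(block_shape, DType::F32, device)?;
        cache.push((key_blocks, value_blocks));
    }
    Ok(cache)
}

fn main() -> Result<()> {
    let layer_devices = vec![Some(Device::Cpu), None, Some(Device::Cpu)];
    let cache = allocate_layer_blocks(&layer_devices, &Device::Cpu, (4, 16, 8))?;
    assert_eq!(cache.len(), 3);
    Ok(())
}
```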
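The `PagedAttentionInputMetadata` tensors (`slot_mappings`, `block_tables`, `context_lens`) are now keyed by `DeviceLocation` and copied once per mapped device, so each attention layer can fetch the copy that already lives on its own device. A hedged, CPU-only sketch of the build-then-lookup pattern; the helper names are illustrative:

```rust
use std::collections::HashMap;

use candle_core::{Device, DeviceLocation, Result, Tensor};

// Copy a metadata tensor (e.g. slot mappings) onto every unique mapped device.
fn broadcast_to_devices(
    t: &Tensor,
    devices: &[Device],
) -> Result<HashMap<DeviceLocation, Tensor>> {
    let mut map = HashMap::new();
    for device in devices {
        map.insert(device.location(), t.to_device(device)?);
    }
    Ok(map)
}

// At attention time, look up the copy on the query's device, as the new
// `PagedAttention::forward` does with `query.device().location()`.
fn slot_mapping_for<'a>(
    map: &'a HashMap<DeviceLocation, Tensor>,
    query: &Tensor,
) -> Option<&'a Tensor> {
    map.get(&query.device().location())
}

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let slot_mappings = Tensor::new(&[0i64, 1, 2, 3], &dev)?;
    let query = Tensor::new(&[0f32; 8], &dev)?;
    let map = broadcast_to_devices(&slot_mappings, &[dev])?;
    assert!(slot_mapping_for(&map, &query).is_some());
    Ok(())
}
```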
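The GGUF, normal, and vision loaders all follow the same recipe: build a `layer_devices: Vec<Option<Device>>` by asking the mapper for each layer's device, then pass it to `CacheEngine::new`. A trimmed sketch of that derivation; the trait here is cut down to the one method this step uses and is not the crate's full `DeviceMapper`:

```rust
use candle_core::Device;

// Cut-down stand-in for the single `DeviceMapper` method needed here.
trait DeviceForLayer {
    fn device_for(&self, layer: usize, loading_isq: bool) -> Option<&Device>;
}

// Mirror of the loader loops: one Option<Device> per hidden layer,
// later handed to `CacheEngine::new(.., layer_devices)`.
fn collect_layer_devices(
    mapper: &dyn DeviceForLayer,
    num_hidden_layers: usize,
) -> Vec<Option<Device>> {
    (0..num_hidden_layers)
        .map(|layer| mapper.device_for(layer, false).cloned())
        .collect()
}

// Trivial mapper that keeps everything on one device, for demonstration.
struct SingleDevice(Device);

impl DeviceForLayer for SingleDevice {
    fn device_for(&self, _layer: usize, _loading_isq: bool) -> Option<&Device> {
        Some(&self.0)
    }
}

fn main() {
    let mapper = SingleDevice(Device::Cpu);
    let layer_devices = collect_layer_devices(&mapper, 4);
    assert_eq!(layer_devices.len(), 4);
    assert!(layer_devices.iter().all(|d| d.is_some()));
}
```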
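Finally, pipelines that own a mapper expose it through the new `MetadataMixin::device_mapper` accessor, and the `Pipeline` call sites forward it into `process_inputs`; pipelines without device mapping (diffusion, GGML, speculative, AnyMoe) simply return `None`. A minimal sketch of that wiring with both traits trimmed down to empty markers; all names are illustrative:

```rust
// Cut-down versions of the trait objects involved; the real ones live in
// `device_map.rs` and `pipeline/mod.rs` and carry many more methods.
trait DeviceMapper {}

trait MetadataMixin {
    // New accessor from the diff: `None` means "no device mapping".
    fn device_mapper(&self) -> Option<&dyn DeviceMapper>;
}

struct Mapper;
impl DeviceMapper for Mapper {}

// A pipeline that owns its mapper, like `GGUFPipeline` or `NormalPipeline`.
struct MappedPipeline {
    mapper: Box<dyn DeviceMapper>,
}

impl MetadataMixin for MappedPipeline {
    fn device_mapper(&self) -> Option<&dyn DeviceMapper> {
        Some(&*self.mapper)
    }
}

// A pipeline without mapping, like `DiffusionPipeline` or `SpeculativePipeline`.
struct UnmappedPipeline;

impl MetadataMixin for UnmappedPipeline {
    fn device_mapper(&self) -> Option<&dyn DeviceMapper> {
        None
    }
}

// Stand-in for the `process_inputs(.., self.device_mapper())` call sites.
fn process_inputs(mapper: Option<&dyn DeviceMapper>) -> bool {
    mapper.is_some()
}

fn main() {
    let mapped = MappedPipeline { mapper: Box::new(Mapper) };
    assert!(process_inputs(mapped.device_mapper()));
    assert!(!process_inputs(UnmappedPipeline.device_mapper()));
}
```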