Support device mapping for Paged Attention (#1011)
* Move start_offsets_kernel to correct device

* Move start_offsets_kernel to correct device

* Move start_offsets_kernel to correct device

* Move start_offsets_kernel to correct device

* Move start_offsets_kernel to correct device

* Move start_offsets_kernel to correct device

* Move start_offsets_kernel to correct device

* Move start_offsets_kernel to correct device

* Move start_offsets_kernel to correct device

* Update starcoder2.rs

* Support device mapping

* Support device mapping

* Support device mapping

* Support device mapping

* Support device mapping

* format

* Support device mapping

* remove mut

* remove mut

* Add get_unique_devices method

* Move tensor for device mapping

* Add DeviceMapper

* Fix wrong RotaryEmbedding import

* Fix wrong RotaryEmbedding import

* Remove unnecessary tensor copies

* Add DeviceMapper

* Add DeviceMapper

* Add DeviceMapper

* Add device mapping

* Create tensor copies for each device for pa

* Add device mapper

* Add device mapper

* Add device mapper

* Add device mapper

* Add device mapper

* Add device mapper

* Add device mapper

* Add device mapper

* Add device mapper

* Add device mapper

* add device mapper

* Remove unnecessary tensor move

* Remove unnecessary tensor move

* Remove unnecessary tensor move

* Remove unnecessary tensor move

* Remove unnecessary tensor move

* Remove unnecessary tensor move

* Remove unnecessary tensor move

* Remove unnecessary tensor move

* format

* format

* format

* clippy

* format
cdoko authored Jan 1, 2025
1 parent 1880c0b commit c345954
Showing 24 changed files with 269 additions and 59 deletions.
12 changes: 12 additions & 0 deletions mistralrs-core/src/device_map.rs
@@ -148,6 +148,7 @@ pub trait DeviceMapper: Debug {
) -> VarBuilder<'a>;
/// If ISQ layer, then do not change the device (return None). *They will do it later in NormalModel::quantize*
fn device_for(&self, layer: usize, loading_isq: bool) -> Option<&Device>;
fn get_unique_devices(&self) -> Vec<Device>;
/// If ISQ layer, then do not change the device (return None). *They will do it later in NormalModel::quantize*
fn cast_nm_device(&self, x: &Tensor, loading_isq: bool) -> Result<Tensor>;
/// Set non mapped layer device. This is for ISQ + device mapping support
@@ -186,6 +187,14 @@ impl DeviceMapper for LayerDeviceMapper {
}
self.mappings.get(layer)
}
fn get_unique_devices(&self) -> Vec<Device> {
self.mappings.iter().fold(Vec::new(), |mut acc, device| {
if !acc.iter().any(|d| d.same_device(device)) {
acc.push(device.clone());
}
acc
})
}
fn cast_nm_device(&self, x: &Tensor, loading_isq: bool) -> Result<Tensor> {
if loading_isq {
x.to_device(&Device::Cpu)
@@ -234,6 +243,9 @@ impl DeviceMapper for DummyDeviceMapper {
}
None
}
fn get_unique_devices(&self) -> Vec<Device> {
vec![self.nm_device.clone()]
}
fn cast_nm_device(&self, x: &Tensor, loading_isq: bool) -> Result<Tensor> {
if loading_isq {
x.to_device(&Device::Cpu)
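The new `get_unique_devices` trait method is an order-preserving dedup over the per-layer device list. Below is a minimal standalone sketch of the same fold, assuming only a `candle_core` dependency; `unique_devices` is a hypothetical free-function stand-in for the trait method:

```rust
use candle_core::Device;

// Collect the distinct devices from a per-layer mapping, preserving
// first-seen order (the same fold used by LayerDeviceMapper).
fn unique_devices(mappings: &[Device]) -> Vec<Device> {
    mappings.iter().fold(Vec::new(), |mut acc, device| {
        // `same_device` compares physical identity, so each device is
        // pushed at most once regardless of how many layers use it.
        if !acc.iter().any(|d| d.same_device(device)) {
            acc.push(device.clone());
        }
        acc
    })
}

fn main() {
    // Two layers mapped to the CPU collapse to a single entry.
    let mappings = vec![Device::Cpu, Device::Cpu];
    assert_eq!(unique_devices(&mappings).len(), 1);
}
```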
2 changes: 2 additions & 0 deletions mistralrs-core/src/diffusion_models/processor.rs
@@ -6,6 +6,7 @@ use indexmap::IndexMap;
use tokenizers::Tokenizer;

use crate::{
device_map::DeviceMapper,
pipeline::{
text_models_inputs_processor::PagedAttentionMeta, InputProcessorOutput, InputsProcessor,
InputsProcessorType, MessagesAction, Processor,
@@ -69,6 +70,7 @@ impl InputsProcessor for DiffusionInputsProcessor {
_other_config: Option<Arc<dyn Any>>,
_paged_attn_metadata: Option<PagedAttentionMeta<'_>>,
prompt_batchsize: Option<NonZeroUsize>,
_mapper: Option<&dyn DeviceMapper>,
) -> Box<dyn Iterator<Item = Result<InputProcessorOutput>>> {
let mut make_value = if prompt_batchsize.is_some() {
return Box::new(std::iter::once(Err(anyhow::Error::msg(
1 change: 1 addition & 0 deletions mistralrs-core/src/dummy_paged_attention/cache_engine.rs
@@ -26,6 +26,7 @@ impl CacheEngine {
_cache_config: &CacheConfig,
_dtype: DType,
_device: &Device,
_layer_devices: Vec<Option<Device>>,
) -> Result<Self> {
Ok(Self {
dummy_cache: Arc::new(Mutex::new(Vec::new())),
10 changes: 8 additions & 2 deletions mistralrs-core/src/layers.rs
@@ -615,6 +615,9 @@ impl Llama3RotaryEmbedding {
k: &mut Tensor,
b_sz: usize,
) -> Result<()> {
// Needed for device mapping
let positions_kernel = positions_kernel.to_device(q.device())?;

match self {
Self::Llama3 { sin, cos, is_gptx } => {
let (b_sz_seq_len, h, n_embd) = q.dims3()?;
@@ -646,7 +649,7 @@
*k = Tensor::cat(&k_embeds, 0)?;
Ok(())
}
Self::Default(rope) => rope.forward(positions, positions_kernel, q, k, b_sz),
Self::Default(rope) => rope.forward(positions, &positions_kernel, q, k, b_sz),
}
}
}
@@ -935,7 +938,10 @@ impl RotaryEmbedding {
k: &mut Tensor,
b_sz: usize,
) -> Result<()> {
self.0.forward(positions, positions_kernel, q, k, b_sz)
// Needed for device mapping
let positions_kernel = positions_kernel.to_device(q.device())?;

self.0.forward(positions, &positions_kernel, q, k, b_sz)
}
}

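The pattern behind the layers.rs change: under device mapping, the positions tensor can arrive on a different device than the activations, so it is copied to the query's device before the rotary-embedding kernel runs. A minimal sketch, assuming `candle_core`; `rope_with_mapping` and `apply_rope` are hypothetical stand-ins for the real forward path and kernel:

```rust
use candle_core::{Result, Tensor};

// Sketch of the guard added to RotaryEmbedding::forward: move the
// positions tensor onto the query's device before the kernel runs.
fn rope_with_mapping(positions_kernel: &Tensor, q: &Tensor) -> Result<Tensor> {
    // `to_device` returns a shallow copy when the tensor already lives
    // there, so unmapped (single-device) models pay essentially nothing.
    let positions_kernel = positions_kernel.to_device(q.device())?;
    apply_rope(&positions_kernel, q)
}

// Placeholder for the real rope kernel, which requires all of its
// operands to share a device.
fn apply_rope(_positions: &Tensor, q: &Tensor) -> Result<Tensor> {
    Ok(q.clone())
}
```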
4 changes: 2 additions & 2 deletions mistralrs-core/src/models/quantized_llama.rs
@@ -6,13 +6,13 @@ use std::sync::Arc;
use candle_core::quantized::ggml_file;
use candle_core::quantized::QTensor;
use candle_core::{DType, Device, Result, Tensor};
use candle_nn::{Embedding, Module, RotaryEmbedding};
use candle_nn::{Embedding, Module};
use mistralrs_quant::{GgufMatMul, QuantMethod, QuantMethodConfig};

use crate::attention::SdpaParams;
use crate::device_map::DeviceMapper;
use crate::gguf::Content;
use crate::layers::{CausalMasker, MatMul, QRmsNorm, Sdpa};
use crate::layers::{CausalMasker, MatMul, QRmsNorm, RotaryEmbedding, Sdpa};
use crate::layers_masker::PastKvLenCache;
use crate::paged_attention::{AttentionImplementation, PagedAttention};
use crate::pipeline::extract_logits;
4 changes: 2 additions & 2 deletions mistralrs-core/src/models/quantized_qwen2.rs
@@ -4,13 +4,13 @@ use std::collections::HashMap;
use std::sync::Arc;

use candle_core::{DType, Device, Result, Tensor};
use candle_nn::{Embedding, Module, RotaryEmbedding};
use candle_nn::{Embedding, Module};
use mistralrs_quant::{GgufMatMul, QuantMethod, QuantMethodConfig};

use crate::attention::SdpaParams;
use crate::device_map::DeviceMapper;
use crate::gguf::Content;
use crate::layers::{CausalMasker, MatMul, QRmsNorm, Sdpa};
use crate::layers::{CausalMasker, MatMul, QRmsNorm, RotaryEmbedding, Sdpa};
use crate::layers_masker::PastKvLenCache;
use crate::paged_attention::{AttentionImplementation, PagedAttention};
use crate::pipeline::text_models_inputs_processor::PagedAttentionInputMetadata;
7 changes: 6 additions & 1 deletion mistralrs-core/src/paged_attention/cache_engine.rs
@@ -29,13 +29,15 @@ impl CacheEngine {
cache_config: &CacheConfig,
dtype: DType,
device: &Device,
layer_devices: Vec<Option<Device>>,
) -> Result<Self> {
Ok(Self {
gpu_cache: Arc::new(Mutex::new(Self::allocate_gpu_cache(
model_config,
cache_config,
dtype,
device,
layer_devices,
)?)),
cpu_cache: Self::allocate_cpu_cache(model_config, cache_config, dtype, device)?,
num_layers: model_config.num_layers(),
@@ -55,13 +57,16 @@
cache_config: &CacheConfig,
dtype: DType,
device: &Device,
layer_devices: Vec<Option<Device>>,
) -> Result<Vec<KVCache>> {
let key_block_shape =
Self::calculate_key_block_shape(model_config, dtype, cache_config.block_size);
let value_block_shape =
Self::calculate_value_block_shape(model_config, cache_config.block_size);
let mut gpu_cache = Vec::new();
for _ in 0..model_config.num_layers() {

for i in 0..model_config.num_layers() {
let device = layer_devices[i].as_ref().unwrap_or(device);
let key_blocks = Tensor::zeros(
(
cache_config.num_gpu_blocks,
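With `layer_devices` threaded through, each layer's KV cache blocks are allocated on the device that layer is mapped to, falling back to the pipeline's default device for unmapped layers. A simplified sketch of the allocation loop, assuming `candle_core`; the shapes and signature are condensed from the real `CacheEngine::allocate_gpu_cache`:

```rust
use candle_core::{DType, Device, Result, Tensor};

// Allocate per-layer KV cache blocks on each layer's mapped device.
fn allocate_gpu_cache(
    num_layers: usize,
    key_block_shape: (usize, usize, usize, usize),
    value_block_shape: (usize, usize, usize, usize),
    dtype: DType,
    default_device: &Device,
    layer_devices: &[Option<Device>],
) -> Result<Vec<(Tensor, Tensor)>> {
    let mut gpu_cache = Vec::with_capacity(num_layers);
    for i in 0..num_layers {
        // Mirrors `layer_devices[i].as_ref().unwrap_or(device)` above:
        // unmapped layers stay on the default device.
        let device = layer_devices[i].as_ref().unwrap_or(default_device);
        let key_blocks = Tensor::zeros(key_block_shape, dtype, device)?;
        let value_blocks = Tensor::zeros(value_block_shape, dtype, device)?;
        gpu_cache.push((key_blocks, value_blocks));
    }
    Ok(gpu_cache)
}
```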
52 changes: 35 additions & 17 deletions mistralrs-core/src/paged_attention/layers/paged_attention.rs
@@ -61,13 +61,34 @@ impl PagedAttention {
input_metadata: &mut PagedAttentionInputMetadata,
softcapping: Option<f64>,
) -> Result<Tensor> {
let dims = input_metadata.slot_mappings.dims();
let slot_mapping = input_metadata
.slot_mappings
.get(&query.device().location())
.unwrap();
let dims = slot_mapping.dims();
let slot_mapping = if dims.len() > 1 {
input_metadata
.slot_mappings
.flatten(0, input_metadata.slot_mappings.dims().len())?
&slot_mapping.flatten(0, dims.len())?
} else {
input_metadata.slot_mappings.clone()
slot_mapping
};

let block_tables = input_metadata
.block_tables
.as_ref()
.unwrap()
.get(&query.device().location())
.unwrap();
let context_lens = input_metadata
.context_lens
.as_ref()
.unwrap()
.get(&query.device().location())
.unwrap();

let alibi_slopes = if let Some(alibi_slopes) = self.alibi_slopes.as_ref() {
Some(alibi_slopes.to_device(query.device())?)
} else {
None
};

let (batch_size, attention_heads, seq_len, head_size) = query.shape().dims4()?;
@@ -80,7 +101,7 @@
query,
key,
value,
Some(mask),
Some(&mask),
None,
&SdpaParams {
n_kv_groups: self.n_kv_groups,
@@ -92,7 +113,7 @@
)?),
};

// // paged-attn expects [batch_size, num_tokens, num_heads, head_size]
// paged-attn expects [batch_size, num_tokens, num_heads, head_size]
let (query, key, value) = if seq_len > 1 {
let q = query
.transpose(1, 2)?
Expand All @@ -105,7 +126,7 @@ impl PagedAttention {
.reshape(((), key_value_heads, head_size))?;
(q, k, v)
} else {
//avoid unnecessary transpose for decoding
// avoid unnecessary transpose for decoding
let q = query.reshape(((), attention_heads, head_size))?;
let k = key.reshape(((), key_value_heads, head_size))?;
let v = value.reshape(((), key_value_heads, head_size))?;
@@ -123,15 +144,14 @@
&value,
key_cache.as_mut().unwrap(),
value_cache.as_mut().unwrap(),
&slot_mapping,
slot_mapping,
)?;
}

if let Some(att) = att {
// Return result in prefill
return Ok(att);
}

// Args:
// output: shape = [num_generation_tokens, num_heads, head_size]
//
@@ -147,18 +167,16 @@
//
// alibi_slopes: shape = [num_heads]
#[allow(clippy::cast_possible_truncation)]
let res = paged_attention(
paged_attention(
&query,
key_cache.as_ref().unwrap(),
value_cache.as_ref().unwrap(),
input_metadata.block_tables.as_ref().unwrap(),
input_metadata.context_lens.as_ref().unwrap(),
self.alibi_slopes.as_ref(),
block_tables,
context_lens,
alibi_slopes.as_ref(),
input_metadata.max_context_len.unwrap(),
self.scale,
softcapping.unwrap_or(1.0f64) as f32,
)?;

Ok(res)
)
}
}
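The core of this hunk: `slot_mappings`, `block_tables`, and `context_lens` are now per-device maps keyed by `Device::location()`, so each layer fetches the copy resident on its own device. A simplified sketch of the lookup, assuming `candle_core`; `PagedMeta` and `slot_mapping_for` are stripped-down stand-ins for `PagedAttentionInputMetadata` and the code above:

```rust
use std::collections::HashMap;

use candle_core::{DeviceLocation, Result, Tensor};

// Stand-in for the metadata struct: one tensor copy per device.
struct PagedMeta {
    slot_mappings: HashMap<DeviceLocation, Tensor>,
}

// Fetch the slot mapping that lives on the query's device and flatten
// it to the 1-D index list the cache kernels expect.
fn slot_mapping_for(meta: &PagedMeta, query: &Tensor) -> Result<Tensor> {
    let slot_mapping = meta
        .slot_mappings
        .get(&query.device().location())
        .expect("a copy was created for every mapped device");
    if slot_mapping.dims().len() > 1 {
        slot_mapping.flatten_all()
    } else {
        Ok(slot_mapping.clone())
    }
}
```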
5 changes: 5 additions & 0 deletions mistralrs-core/src/pipeline/amoe.rs
@@ -19,6 +19,7 @@ use tracing::{info, warn};

use crate::{
amoe::{AnyMoeConfig, AnyMoeTrainingInputRow, AnyMoeTrainingInputs, AnyMoeTrainingResult},
device_map::DeviceMapper,
get_mut_arcmutex,
prefix_cacher_v2::PrefixCacheManagerV2,
sampler::Sampler,
@@ -244,6 +245,9 @@ impl MetadataMixin for AnyMoePipeline {
fn tokenizer(&self) -> Option<Arc<tokenizers::Tokenizer>> {
get_mut_arcmutex!(self.target).tokenizer()
}
fn device_mapper(&self) -> Option<&dyn DeviceMapper> {
None
}
}

#[async_trait::async_trait]
@@ -469,6 +473,7 @@ impl AnyMoePipelineMixin for AnyMoePipeline {
input_processor_cfg.clone(),
None, // TODO: get block tables/handle it for PagedAttention
None, // TODO: prompt chunking doesn't work.
None,
)
.nth(0)
.unwrap();
4 changes: 4 additions & 0 deletions mistralrs-core/src/pipeline/diffusion.rs
@@ -5,6 +5,7 @@ use super::{
GeneralMetadata, IsqPipelineMixin, Loader, MetadataMixin, ModelCategory, ModelKind, ModelPaths,
PreProcessingMixin, Processor, TokenSource,
};
use crate::device_map::DeviceMapper;
use crate::diffusion_models::processor::{DiffusionProcessor, ModelInputs};
use crate::paged_attention::AttentionImplementation;
use crate::pipeline::ChatTemplate;
@@ -296,6 +297,9 @@ impl MetadataMixin for DiffusionPipeline {
fn tokenizer(&self) -> Option<Arc<Tokenizer>> {
None
}
fn device_mapper(&self) -> Option<&dyn DeviceMapper> {
None
}
}

#[async_trait::async_trait]
4 changes: 4 additions & 0 deletions mistralrs-core/src/pipeline/ggml.rs
@@ -9,6 +9,7 @@ use super::{
AdapterActivationMixin, AnyMoePipelineMixin, CacheManagerMixin, EitherCache,
ForwardInputsResult, IsqPipelineMixin, MetadataMixin, ModelCategory, PreProcessingMixin,
};
use crate::device_map::DeviceMapper;
use crate::lora::Ordering;
use crate::pipeline::chat_template::{calculate_eos_tokens, GenerationConfig};
use crate::pipeline::get_chat_template;
@@ -529,6 +530,9 @@ impl MetadataMixin for GGMLPipeline {
fn get_metadata(&self) -> Arc<GeneralMetadata> {
self.metadata.clone()
}
fn device_mapper(&self) -> Option<&dyn DeviceMapper> {
None
}
}

#[async_trait::async_trait]
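The pipeline files all follow one pattern: `MetadataMixin` gains a `device_mapper` accessor, and pipelines without device-mapping support (AnyMoe, diffusion, GGML, and so on) return `None` so callers fall back to single-device behavior. A condensed sketch with simplified stand-in traits, not the real definitions; note that the actual trait has no default body, so each pipeline implements the method explicitly:

```rust
use candle_core::Device;

// Minimal stand-in for crate::device_map::DeviceMapper.
trait DeviceMapper {
    fn get_unique_devices(&self) -> Vec<Device>;
}

// Minimal stand-in for the pipeline metadata trait.
trait MetadataMixin {
    // Pipelines that map layers across devices return Some(mapper);
    // everything else opts out with None.
    fn device_mapper(&self) -> Option<&dyn DeviceMapper>;
}

struct GgmlLikePipeline;

impl MetadataMixin for GgmlLikePipeline {
    fn device_mapper(&self) -> Option<&dyn DeviceMapper> {
        None // this pipeline does not support device mapping
    }
}
```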
(Diffs for the remaining changed files are not shown.)