Support device mapping for Paged Attention #1011

Merged · 55 commits · Jan 1, 2025

Commits (55)
acea1fd · Move start_offsets_kernel to correct device (cdoko, Dec 28, 2024)
6032955 · Move start_offsets_kernel to correct device (cdoko, Dec 28, 2024)
946dfd9 · Move start_offsets_kernel to correct device (cdoko, Dec 28, 2024)
8a134b9 · Move start_offsets_kernel to correct device (cdoko, Dec 28, 2024)
6f02469 · Move start_offsets_kernel to correct device (cdoko, Dec 28, 2024)
86d026a · Move start_offsets_kernel to correct device (cdoko, Dec 28, 2024)
d54e767 · Move start_offsets_kernel to correct device (cdoko, Dec 28, 2024)
819278c · Move start_offsets_kernel to correct device (cdoko, Dec 28, 2024)
007f2db · Move start_offsets_kernel to correct device (cdoko, Dec 28, 2024)
e7d2d80 · Update starcoder2.rs (cdoko, Dec 28, 2024)
db0cdc5 · Support device mapping (cdoko, Dec 28, 2024)
05ed5fe · Support device mapping (cdoko, Dec 28, 2024)
937319b · Support device mapping (cdoko, Dec 28, 2024)
047ca07 · Support device mapping (cdoko, Dec 28, 2024)
882f4e7 · Support device mapping (cdoko, Dec 28, 2024)
a51bca6 · format (cdoko, Dec 28, 2024)
db78205 · Support device mapping (cdoko, Dec 28, 2024)
e6324b4 · remove mut (cdoko, Dec 28, 2024)
8fadbbc · remove mut (cdoko, Dec 28, 2024)
9d9918d · Merge branch 'master' into device-mapping-paged-attn (cdoko, Dec 31, 2024)
895d0a9 · Add get_unique_devices method (cdoko, Dec 31, 2024)
7d46900 · Move tensor for device mapping (cdoko, Dec 31, 2024)
aa90ef2 · Add DeviceMapper (cdoko, Dec 31, 2024)
8fc40fc · Fix wrong RotaryEmbedding import (cdoko, Dec 31, 2024)
ad66e29 · Fix wrong RotaryEmbedding import (cdoko, Dec 31, 2024)
e0719f9 · Remove unecessary tensor copies (cdoko, Dec 31, 2024)
cffeaaa · Add DeviceMapper (cdoko, Dec 31, 2024)
f0f3ac1 · Add DeviceMapper (cdoko, Dec 31, 2024)
e935067 · Add DeviceMapper (cdoko, Dec 31, 2024)
a833acf · Add device mapping (cdoko, Dec 31, 2024)
efbd6f4 · Create tensor copies for each device for pa (cdoko, Dec 31, 2024)
8a0177a · Add device mapper (cdoko, Dec 31, 2024)
b614be9 · Add device mapper (cdoko, Dec 31, 2024)
30618da · Add device mapper (cdoko, Dec 31, 2024)
0215e86 · Add device mapper (cdoko, Dec 31, 2024)
44e0559 · Add device mapper (cdoko, Dec 31, 2024)
095e28a · Add device mapper (cdoko, Dec 31, 2024)
80eb294 · Add device mapper (cdoko, Dec 31, 2024)
ef7ee66 · Add device mapper (cdoko, Dec 31, 2024)
f269c55 · Add device mapper (cdoko, Dec 31, 2024)
587b4f7 · Add device mapper (cdoko, Dec 31, 2024)
36d89c9 · add device mapper (cdoko, Dec 31, 2024)
17f8065 · Remove unecessary tensor move (cdoko, Dec 31, 2024)
3ca105a · Remove unecessary tensor move (cdoko, Dec 31, 2024)
40706f2 · Remove unecessary tensor move (cdoko, Dec 31, 2024)
62d2126 · Remove unecessary tensor move (cdoko, Dec 31, 2024)
78189c9 · Remove unecessary tensor move (cdoko, Dec 31, 2024)
6724ee1 · Remove unecessary tensor move (cdoko, Dec 31, 2024)
6ca0625 · Remove unecessary tensor move (cdoko, Dec 31, 2024)
d3b4dae · Remove unecessary tensor move (cdoko, Dec 31, 2024)
ae3f53e · format (cdoko, Dec 31, 2024)
3bf680d · format (cdoko, Dec 31, 2024)
7560df9 · format (cdoko, Dec 31, 2024)
83cf77d · clippy (cdoko, Dec 31, 2024)
45aad07 · format (cdoko, Dec 31, 2024)
Files changed
12 changes: 12 additions & 0 deletions mistralrs-core/src/device_map.rs
@@ -148,6 +148,7 @@ pub trait DeviceMapper: Debug {
) -> VarBuilder<'a>;
/// If ISQ layer, then do not change the device (return None). *They will do it later in NormalModel::quantize*
fn device_for(&self, layer: usize, loading_isq: bool) -> Option<&Device>;
fn get_unique_devices(&self) -> Vec<Device>;
/// If ISQ layer, then do not change the device (return None). *They will do it later in NormalModel::quantize*
fn cast_nm_device(&self, x: &Tensor, loading_isq: bool) -> Result<Tensor>;
/// Set non mapped layer device. This is for ISQ + device mapping support
@@ -186,6 +187,14 @@ impl DeviceMapper for LayerDeviceMapper {
}
self.mappings.get(layer)
}
fn get_unique_devices(&self) -> Vec<Device> {
self.mappings.iter().fold(Vec::new(), |mut acc, device| {
if !acc.iter().any(|d| d.same_device(device)) {
acc.push(device.clone());
}
acc
})
}
fn cast_nm_device(&self, x: &Tensor, loading_isq: bool) -> Result<Tensor> {
if loading_isq {
x.to_device(&Device::Cpu)
@@ -234,6 +243,9 @@ impl DeviceMapper for DummyDeviceMapper {
}
None
}
fn get_unique_devices(&self) -> Vec<Device> {
vec![self.nm_device.clone()]
}
fn cast_nm_device(&self, x: &Tensor, loading_isq: bool) -> Result<Tensor> {
if loading_isq {
x.to_device(&Device::Cpu)
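Note that the dedup in `get_unique_devices` compares with `Device::same_device` rather than `Eq` and preserves first-seen order, so a device listed for many consecutive layers collapses to one entry. A minimal sketch of that behavior (the two-GPU mapping below is hypothetical, not from this PR):

```rust
use candle_core::{Device, Result};

fn main() -> Result<()> {
    // Hypothetical mapping: layers 0-15 on cuda(0), layers 16-31 on cuda(1).
    let d0 = Device::new_cuda(0)?;
    let d1 = Device::new_cuda(1)?;
    let mappings: Vec<Device> = (0..32)
        .map(|layer| if layer < 16 { d0.clone() } else { d1.clone() })
        .collect();

    // Same fold as LayerDeviceMapper::get_unique_devices: 32 entries -> 2.
    let unique = mappings.iter().fold(Vec::<Device>::new(), |mut acc, device| {
        if !acc.iter().any(|d| d.same_device(device)) {
            acc.push(device.clone());
        }
        acc
    });
    assert_eq!(unique.len(), 2);
    Ok(())
}
```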
2 changes: 2 additions & 0 deletions mistralrs-core/src/diffusion_models/processor.rs
@@ -6,6 +6,7 @@ use indexmap::IndexMap;
use tokenizers::Tokenizer;

use crate::{
device_map::DeviceMapper,
pipeline::{
text_models_inputs_processor::PagedAttentionMeta, InputProcessorOutput, InputsProcessor,
InputsProcessorType, MessagesAction, Processor,
@@ -69,6 +70,7 @@ impl InputsProcessor for DiffusionInputsProcessor {
_other_config: Option<Arc<dyn Any>>,
_paged_attn_metadata: Option<PagedAttentionMeta<'_>>,
prompt_batchsize: Option<NonZeroUsize>,
_mapper: Option<&dyn DeviceMapper>,
) -> Box<dyn Iterator<Item = Result<InputProcessorOutput>>> {
let mut make_value = if prompt_batchsize.is_some() {
return Box::new(std::iter::once(Err(anyhow::Error::msg(
1 change: 1 addition & 0 deletions mistralrs-core/src/dummy_paged_attention/cache_engine.rs
@@ -26,6 +26,7 @@ impl CacheEngine {
_cache_config: &CacheConfig,
_dtype: DType,
_device: &Device,
_layer_devices: Vec<Option<Device>>,
) -> Result<Self> {
Ok(Self {
dummy_cache: Arc::new(Mutex::new(Vec::new())),
10 changes: 8 additions & 2 deletions mistralrs-core/src/layers.rs
@@ -615,6 +615,9 @@ impl Llama3RotaryEmbedding {
k: &mut Tensor,
b_sz: usize,
) -> Result<()> {
// Needed for device mapping
let positions_kernel = positions_kernel.to_device(q.device())?;

match self {
Self::Llama3 { sin, cos, is_gptx } => {
let (b_sz_seq_len, h, n_embd) = q.dims3()?;
@@ -646,7 +649,7 @@
*k = Tensor::cat(&k_embeds, 0)?;
Ok(())
}
Self::Default(rope) => rope.forward(positions, positions_kernel, q, k, b_sz),
Self::Default(rope) => rope.forward(positions, &positions_kernel, q, k, b_sz),
}
}
}
@@ -935,7 +938,10 @@ impl RotaryEmbedding {
k: &mut Tensor,
b_sz: usize,
) -> Result<()> {
self.0.forward(positions, positions_kernel, q, k, b_sz)
// Needed for device mapping
let positions_kernel = positions_kernel.to_device(q.device())?;

self.0.forward(positions, &positions_kernel, q, k, b_sz)
}
}

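Both rope paths now route `positions_kernel` through `to_device(q.device())` before use. In candle, `to_device` is a cheap shallow clone when the tensor already lives on the target device, so single-GPU runs pay nothing; only layers mapped to another device trigger a real transfer. The guard in isolation (a sketch, not the PR's exact helper):

```rust
use candle_core::{Result, Tensor};

// Sketch of the device-mapping guard used in the rope forwards above:
// make an auxiliary tensor follow the activations' device before use.
fn align_to_activations(q: &Tensor, aux: &Tensor) -> Result<Tensor> {
    // Shallow clone when already co-located; a copy only for remapped layers.
    aux.to_device(q.device())
}
```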
4 changes: 2 additions & 2 deletions mistralrs-core/src/models/quantized_llama.rs
@@ -6,13 +6,13 @@ use std::sync::Arc;
use candle_core::quantized::ggml_file;
use candle_core::quantized::QTensor;
use candle_core::{DType, Device, Result, Tensor};
use candle_nn::{Embedding, Module, RotaryEmbedding};
use candle_nn::{Embedding, Module};
use mistralrs_quant::{GgufMatMul, QuantMethod, QuantMethodConfig};

use crate::attention::SdpaParams;
use crate::device_map::DeviceMapper;
use crate::gguf::Content;
use crate::layers::{CausalMasker, MatMul, QRmsNorm, Sdpa};
use crate::layers::{CausalMasker, MatMul, QRmsNorm, RotaryEmbedding, Sdpa};
use crate::layers_masker::PastKvLenCache;
use crate::paged_attention::{AttentionImplementation, PagedAttention};
use crate::pipeline::extract_logits;
4 changes: 2 additions & 2 deletions mistralrs-core/src/models/quantized_qwen2.rs
@@ -4,13 +4,13 @@ use std::collections::HashMap;
use std::sync::Arc;

use candle_core::{DType, Device, Result, Tensor};
use candle_nn::{Embedding, Module, RotaryEmbedding};
use candle_nn::{Embedding, Module};
use mistralrs_quant::{GgufMatMul, QuantMethod, QuantMethodConfig};

use crate::attention::SdpaParams;
use crate::device_map::DeviceMapper;
use crate::gguf::Content;
use crate::layers::{CausalMasker, MatMul, QRmsNorm, Sdpa};
use crate::layers::{CausalMasker, MatMul, QRmsNorm, RotaryEmbedding, Sdpa};
use crate::layers_masker::PastKvLenCache;
use crate::paged_attention::{AttentionImplementation, PagedAttention};
use crate::pipeline::text_models_inputs_processor::PagedAttentionInputMetadata;
7 changes: 6 additions & 1 deletion mistralrs-core/src/paged_attention/cache_engine.rs
@@ -29,13 +29,15 @@ impl CacheEngine {
cache_config: &CacheConfig,
dtype: DType,
device: &Device,
layer_devices: Vec<Option<Device>>,
) -> Result<Self> {
Ok(Self {
gpu_cache: Arc::new(Mutex::new(Self::allocate_gpu_cache(
model_config,
cache_config,
dtype,
device,
layer_devices,
)?)),
cpu_cache: Self::allocate_cpu_cache(model_config, cache_config, dtype, device)?,
num_layers: model_config.num_layers(),
@@ -55,13 +57,16 @@
cache_config: &CacheConfig,
dtype: DType,
device: &Device,
layer_devices: Vec<Option<Device>>,
) -> Result<Vec<KVCache>> {
let key_block_shape =
Self::calculate_key_block_shape(model_config, dtype, cache_config.block_size);
let value_block_shape =
Self::calculate_value_block_shape(model_config, cache_config.block_size);
let mut gpu_cache = Vec::new();
for _ in 0..model_config.num_layers() {

for i in 0..model_config.num_layers() {
let device = layer_devices[i].as_ref().unwrap_or(device);
let key_blocks = Tensor::zeros(
(
cache_config.num_gpu_blocks,
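With `layer_devices[i]` overriding the default device per layer, each layer's KV blocks are allocated on the GPU that actually runs that layer. One plausible way a caller assembles that vector from the mapper's per-layer lookup (a sketch with `mapper`, `model_config`, `cache_config`, `dtype`, and `device` assumed in scope, not taken verbatim from the PR):

```rust
// Sketch: derive per-layer cache devices from DeviceMapper::device_for,
// then hand them to the cache engine so KV blocks co-locate with layers.
let layer_devices: Vec<Option<Device>> = (0..model_config.num_layers())
    .map(|layer| mapper.device_for(layer, false).cloned())
    .collect();
let cache_engine = CacheEngine::new(
    &model_config,
    &cache_config,
    dtype,
    &device,
    layer_devices,
)?;
```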
52 changes: 35 additions & 17 deletions mistralrs-core/src/paged_attention/layers/paged_attention.rs
@@ -61,13 +61,34 @@ impl PagedAttention {
input_metadata: &mut PagedAttentionInputMetadata,
softcapping: Option<f64>,
) -> Result<Tensor> {
let dims = input_metadata.slot_mappings.dims();
let slot_mapping = input_metadata
.slot_mappings
.get(&query.device().location())
.unwrap();
let dims = slot_mapping.dims();
let slot_mapping = if dims.len() > 1 {
input_metadata
.slot_mappings
.flatten(0, input_metadata.slot_mappings.dims().len())?
&slot_mapping.flatten(0, dims.len())?
} else {
input_metadata.slot_mappings.clone()
slot_mapping
};

let block_tables = input_metadata
.block_tables
.as_ref()
.unwrap()
.get(&query.device().location())
.unwrap();
let context_lens = input_metadata
.context_lens
.as_ref()
.unwrap()
.get(&query.device().location())
.unwrap();

let alibi_slopes = if let Some(alibi_slopes) = self.alibi_slopes.as_ref() {
Some(alibi_slopes.to_device(query.device())?)
} else {
None
};

let (batch_size, attention_heads, seq_len, head_size) = query.shape().dims4()?;
@@ -80,7 +101,7 @@
query,
key,
value,
Some(mask),
Some(&mask),
None,
&SdpaParams {
n_kv_groups: self.n_kv_groups,
@@ -92,7 +113,7 @@
)?),
};

// // paged-attn expects [batch_size, num_tokens, num_heads, head_size]
// paged-attn expects [batch_size, num_tokens, num_heads, head_size]
let (query, key, value) = if seq_len > 1 {
let q = query
.transpose(1, 2)?
@@ -105,7 +126,7 @@
.reshape(((), key_value_heads, head_size))?;
(q, k, v)
} else {
//avoid unnecessary transpose for decoding
// avoid unnecessary transpose for decoding
let q = query.reshape(((), attention_heads, head_size))?;
let k = key.reshape(((), key_value_heads, head_size))?;
let v = value.reshape(((), key_value_heads, head_size))?;
@@ -123,15 +144,14 @@
&value,
key_cache.as_mut().unwrap(),
value_cache.as_mut().unwrap(),
&slot_mapping,
slot_mapping,
)?;
}

if let Some(att) = att {
// Return result in prefill
return Ok(att);
}

// Args:
// output: shape = [num_generation_tokens, num_heads, head_size]
//
@@ -147,18 +167,16 @@
//
// alibi_slopes: shape = [num_heads]
#[allow(clippy::cast_possible_truncation)]
let res = paged_attention(
paged_attention(
&query,
key_cache.as_ref().unwrap(),
value_cache.as_ref().unwrap(),
input_metadata.block_tables.as_ref().unwrap(),
input_metadata.context_lens.as_ref().unwrap(),
self.alibi_slopes.as_ref(),
block_tables,
context_lens,
alibi_slopes.as_ref(),
input_metadata.max_context_len.unwrap(),
self.scale,
softcapping.unwrap_or(1.0f64) as f32,
)?;

Ok(res)
)
}
}
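The `.get(&query.device().location())` lookups above imply that `slot_mappings`, `block_tables`, and `context_lens` now hold one tensor per unique device, keyed by `DeviceLocation` (matching the "Create tensor copies for each device for pa" commit). A hedged sketch of building such a map up front, so forward passes never copy metadata across devices:

```rust
use std::collections::HashMap;

use candle_core::{Device, DeviceLocation, Result, Tensor};

// Sketch, not verbatim from the PR: replicate a metadata tensor onto every
// unique device so each layer fetches a local copy by
// query.device().location() at forward time.
fn replicate_per_device(
    t: &Tensor,
    devices: &[Device],
) -> Result<HashMap<DeviceLocation, Tensor>> {
    let mut per_device = HashMap::new();
    for device in devices {
        per_device.insert(device.location(), t.to_device(device)?);
    }
    Ok(per_device)
}
```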
5 changes: 5 additions & 0 deletions mistralrs-core/src/pipeline/amoe.rs
@@ -19,6 +19,7 @@ use tracing::{info, warn};

use crate::{
amoe::{AnyMoeConfig, AnyMoeTrainingInputRow, AnyMoeTrainingInputs, AnyMoeTrainingResult},
device_map::DeviceMapper,
get_mut_arcmutex,
prefix_cacher_v2::PrefixCacheManagerV2,
sampler::Sampler,
@@ -244,6 +245,9 @@ impl MetadataMixin for AnyMoePipeline {
fn tokenizer(&self) -> Option<Arc<tokenizers::Tokenizer>> {
get_mut_arcmutex!(self.target).tokenizer()
}
fn device_mapper(&self) -> Option<&dyn DeviceMapper> {
None
}
}

#[async_trait::async_trait]
@@ -469,6 +473,7 @@ impl AnyMoePipelineMixin for AnyMoePipeline {
input_processor_cfg.clone(),
None, // TODO: get block tables/handle it for PagedAttention
None, // TODO: prompt chunking doesn't work.
None,
)
.nth(0)
.unwrap();
4 changes: 4 additions & 0 deletions mistralrs-core/src/pipeline/diffusion.rs
@@ -5,6 +5,7 @@ use super::{
GeneralMetadata, IsqPipelineMixin, Loader, MetadataMixin, ModelCategory, ModelKind, ModelPaths,
PreProcessingMixin, Processor, TokenSource,
};
use crate::device_map::DeviceMapper;
use crate::diffusion_models::processor::{DiffusionProcessor, ModelInputs};
use crate::paged_attention::AttentionImplementation;
use crate::pipeline::ChatTemplate;
@@ -296,6 +297,9 @@ impl MetadataMixin for DiffusionPipeline {
fn tokenizer(&self) -> Option<Arc<Tokenizer>> {
None
}
fn device_mapper(&self) -> Option<&dyn DeviceMapper> {
None
}
}

#[async_trait::async_trait]
4 changes: 4 additions & 0 deletions mistralrs-core/src/pipeline/ggml.rs
@@ -9,6 +9,7 @@ use super::{
AdapterActivationMixin, AnyMoePipelineMixin, CacheManagerMixin, EitherCache,
ForwardInputsResult, IsqPipelineMixin, MetadataMixin, ModelCategory, PreProcessingMixin,
};
use crate::device_map::DeviceMapper;
use crate::lora::Ordering;
use crate::pipeline::chat_template::{calculate_eos_tokens, GenerationConfig};
use crate::pipeline::get_chat_template;
@@ -529,6 +530,9 @@ impl MetadataMixin for GGMLPipeline {
fn get_metadata(&self) -> Arc<GeneralMetadata> {
self.metadata.clone()
}
fn device_mapper(&self) -> Option<&dyn DeviceMapper> {
None
}
}

#[async_trait::async_trait]