Skip to content

Commit

Permalink
Add device mapping
Browse files (browse the repository at this point in the history)
  • Loading branch information
cdoko authored Dec 31, 2024
1 parent e935067 commit a833acf
Showing 1 changed file with 16 additions and 13 deletions.
29 changes: 16 additions & 13 deletions mistralrs-core/src/pipeline/gguf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use super::{
AdapterActivationMixin, AnyMoePipelineMixin, CacheManagerMixin, EitherCache,
ForwardInputsResult, IsqPipelineMixin, MetadataMixin, ModelCategory, PreProcessingMixin,
};
use crate::device_map::DeviceMapper;
use crate::gguf::{
get_gguf_chat_template, {convert_gguf_to_hf_tokenizer, GgufTokenizerConversion},
};
Expand Down Expand Up @@ -74,6 +75,7 @@ pub struct GGUFPipeline {
model_id: String,
non_granular_state: Option<NonGranularState>,
metadata: Arc<GeneralMetadata>,
mapper: Box<dyn DeviceMapper + Send + Sync>,
}

/// Loader for a GGUF model.
Expand Down Expand Up @@ -354,10 +356,6 @@ impl Loader for GGUFLoader {
device.device_pretty_repr()
);
}
// } else if paged_attn_config.is_some() {
// warn!("Device mapping or device topology and PagedAttention are incompatible, disabling PagedAttention.");
// paged_attn_config = None;
// }

let mut readers = Vec::new();
for filename in paths.get_weight_filenames() {
Expand Down Expand Up @@ -484,16 +482,8 @@ impl Loader for GGUFLoader {
model_config,
device,
)?;
let cache_config = calculate_cache_config(
paged_attn_config.mem_gpu,
paged_attn_config.mem_cpu,
paged_attn_config.block_size,
internal_dtype,
model_config,
device,
)?;
let cache_engine =
CacheEngine::new(model_config, &cache_config, internal_dtype, device)?;
CacheEngine::new(model_config, &cache_config, internal_dtype, device, layer_devices)?;
(Some(cache_config), Some(cache_engine))
} else {
(None, None)
Expand Down Expand Up @@ -524,6 +514,15 @@ impl Loader for GGUFLoader {
Model::Qwen2(ref p) => p.max_seq_len,
};
let tok_env = build_tok_env(tokenizer.clone());
let num_hidden_layers = match model {
Model::Llama(ref model) => model.cache.normal().0.len(),
Model::Phi2(ref model) => model.cache.normal().0.len(),
Model::XLoraLlama(ref model) => model.cache.full().lock().len(),
Model::Phi3(ref model) => model.cache.normal().0.len(),
Model::XLoraPhi3(ref model) => model.cache.full().lock().len(),
Model::Starcoder2(ref model) => model.cache.normal().0.len(),
Model::Qwen2(ref model) => model.cache.normal().0.len(),
};

if chat_template.bos_token.is_none() && bos.is_some() {
chat_template.bos_token = Some(BeginEndUnkTok(Either::Left(bos.unwrap())));
Expand Down Expand Up @@ -566,6 +565,7 @@ impl Loader for GGUFLoader {
prompt_batchsize: self.config.prompt_batchsize,
model_metadata: Some(Arc::new(model_config_metadata)),
}),
mapper,
})))
}

Expand Down Expand Up @@ -693,6 +693,9 @@ impl MetadataMixin for GGUFPipeline {
/// Returns a shared handle to this pipeline's general metadata.
///
/// The `metadata` field is an `Arc<GeneralMetadata>`, so this hands back another
/// reference-counted pointer to the same value rather than a copy of the metadata.
fn get_metadata(&self) -> Arc<GeneralMetadata> {
    Arc::clone(&self.metadata)
}
fn device_mapper(&self) -> Option<&dyn DeviceMapper> {

Check failure on line 696 in mistralrs-core/src/pipeline/gguf.rs

View workflow job for this annotation

GitHub Actions / Check (ubuntu-latest, stable)

method `device_mapper` is not a member of trait `MetadataMixin`

Check failure on line 696 in mistralrs-core/src/pipeline/gguf.rs

View workflow job for this annotation

GitHub Actions / Check (macOS-latest, stable)

method `device_mapper` is not a member of trait `MetadataMixin`

Check failure on line 696 in mistralrs-core/src/pipeline/gguf.rs

View workflow job for this annotation

GitHub Actions / Clippy

method `device_mapper` is not a member of trait `MetadataMixin`

Check failure on line 696 in mistralrs-core/src/pipeline/gguf.rs

View workflow job for this annotation

GitHub Actions / Test Suite (macOS-latest, stable)

method `device_mapper` is not a member of trait `MetadataMixin`

Check failure on line 696 in mistralrs-core/src/pipeline/gguf.rs

View workflow job for this annotation

GitHub Actions / Test Suite (ubuntu-latest, stable)

method `device_mapper` is not a member of trait `MetadataMixin`

Check failure on line 696 in mistralrs-core/src/pipeline/gguf.rs

View workflow job for this annotation

GitHub Actions / Docs

method `device_mapper` is not a member of trait `MetadataMixin`
Some(&*self.mapper)
}
}

#[async_trait::async_trait]
Expand Down

0 comments on commit a833acf

Please sign in to comment.