From 32c1cf3d72d7b32ae19b073b9f201a2b8c973bb2 Mon Sep 17 00:00:00 2001 From: Eric Buehler Date: Sun, 29 Dec 2024 17:16:59 -0500 Subject: [PATCH] Cleaner creation of dummy pa input metadata --- mistralrs-core/src/models/gemma.rs | 11 +++++------ mistralrs-core/src/models/gemma2.rs | 11 +++++------ mistralrs-core/src/models/llama.rs | 11 +++++------ mistralrs-core/src/models/mistral.rs | 11 +++++------ mistralrs-core/src/models/mixtral.rs | 11 +++++------ mistralrs-core/src/models/phi2.rs | 11 +++++------ mistralrs-core/src/models/phi3.rs | 11 +++++------ mistralrs-core/src/models/phi3_5_moe.rs | 11 +++++------ mistralrs-core/src/models/qwen2.rs | 11 +++++------ mistralrs-core/src/models/starcoder2.rs | 11 +++++------ mistralrs-core/src/pipeline/inputs_processor.rs | 13 +++++++++++++ .../src/vision_models/llava/llava_llm/llama.rs | 11 +++++------ .../src/vision_models/llava/llava_llm/mistral.rs | 11 +++++------ mistralrs-core/src/vision_models/phi3/mod.rs | 11 +++++------ 14 files changed, 78 insertions(+), 78 deletions(-) diff --git a/mistralrs-core/src/models/gemma.rs b/mistralrs-core/src/models/gemma.rs index 044d2e428..e10fb29b8 100644 --- a/mistralrs-core/src/models/gemma.rs +++ b/mistralrs-core/src/models/gemma.rs @@ -319,12 +319,11 @@ impl Attention { None, )?, None => { - let mut input_metadata = PagedAttentionInputMetadata { - block_tables: None, - context_lens: None, - max_context_len: None, - slot_mappings: Tensor::new(&[0f32], q.device())?, - }; + // If we don't have metadata, we are most likely generating an imatrix so we don't want to populate that. + // Generating the dummy metadata with the assumption that we are not generating text (only processing prompts). + let mut input_metadata = PagedAttentionInputMetadata::dummy(q.device())?; + // Sanity check. + assert!(attention_mask.is_some()); paged_attn.forward( &q, &k, diff --git a/mistralrs-core/src/models/gemma2.rs b/mistralrs-core/src/models/gemma2.rs index ad1e97c0b..ad78cbcf6 100644 --- a/mistralrs-core/src/models/gemma2.rs +++ b/mistralrs-core/src/models/gemma2.rs @@ -333,12 +333,11 @@ impl Attention { self.attn_logit_softcapping, )?, None => { - let mut input_metadata = PagedAttentionInputMetadata { - block_tables: None, - context_lens: None, - max_context_len: None, - slot_mappings: Tensor::new(&[0f32], q.device())?, - }; + // If we don't have metadata, we are most likely generating an imatrix so we don't want to populate that. + // Generating the dummy metadata with the assumption that we are not generating text (only processing prompts). + let mut input_metadata = PagedAttentionInputMetadata::dummy(q.device())?; + // Sanity check. + assert!(attention_mask.is_some()); paged_attn.forward( &q, &k, diff --git a/mistralrs-core/src/models/llama.rs b/mistralrs-core/src/models/llama.rs index 2f298c9fc..4949f83bf 100644 --- a/mistralrs-core/src/models/llama.rs +++ b/mistralrs-core/src/models/llama.rs @@ -135,12 +135,11 @@ impl CausalSelfAttention { None, )?, None => { - let mut input_metadata = PagedAttentionInputMetadata { - block_tables: None, - context_lens: None, - max_context_len: None, - slot_mappings: Tensor::new(&[0f32], q.device())?, - }; + // If we don't have metadata, we are most likely generating an imatrix so we don't want to populate that. + // Generating the dummy metadata with the assumption that we are not generating text (only processing prompts). + let mut input_metadata = PagedAttentionInputMetadata::dummy(q.device())?; + // Sanity check. + assert!(attention_mask.is_some()); paged_attn.forward( &q, &k, diff --git a/mistralrs-core/src/models/mistral.rs b/mistralrs-core/src/models/mistral.rs index 968364ab1..28b3b890e 100644 --- a/mistralrs-core/src/models/mistral.rs +++ b/mistralrs-core/src/models/mistral.rs @@ -296,12 +296,11 @@ impl Attention { None, )?, None => { - let mut input_metadata = PagedAttentionInputMetadata { - block_tables: None, - context_lens: None, - max_context_len: None, - slot_mappings: Tensor::new(&[0f32], q.device())?, - }; + // If we don't have metadata, we are most likely generating an imatrix so we don't want to populate that. + // Generating the dummy metadata with the assumption that we are not generating text (only processing prompts). + let mut input_metadata = PagedAttentionInputMetadata::dummy(q.device())?; + // Sanity check. + assert!(attention_mask.is_some()); paged_attn.forward( &q, &k, diff --git a/mistralrs-core/src/models/mixtral.rs b/mistralrs-core/src/models/mixtral.rs index 732615341..5d21dc4e0 100644 --- a/mistralrs-core/src/models/mixtral.rs +++ b/mistralrs-core/src/models/mixtral.rs @@ -191,12 +191,11 @@ impl Attention { None, )?, None => { - let mut input_metadata = PagedAttentionInputMetadata { - block_tables: None, - context_lens: None, - max_context_len: None, - slot_mappings: Tensor::new(&[0f32], q.device())?, - }; + // If we don't have metadata, we are most likely generating an imatrix so we don't want to populate that. + // Generating the dummy metadata with the assumption that we are not generating text (only processing prompts). + let mut input_metadata = PagedAttentionInputMetadata::dummy(q.device())?; + // Sanity check. + assert!(attention_mask.is_some()); paged_attn.forward( &q, &k, diff --git a/mistralrs-core/src/models/phi2.rs b/mistralrs-core/src/models/phi2.rs index 4c7b3f58d..d32f20cfb 100644 --- a/mistralrs-core/src/models/phi2.rs +++ b/mistralrs-core/src/models/phi2.rs @@ -315,12 +315,11 @@ impl Attention { None, )?, None => { - let mut input_metadata = PagedAttentionInputMetadata { - block_tables: None, - context_lens: None, - max_context_len: None, - slot_mappings: Tensor::new(&[0f32], q.device())?, - }; + // If we don't have metadata, we are most likely generating an imatrix so we don't want to populate that. + // Generating the dummy metadata with the assumption that we are not generating text (only processing prompts). + let mut input_metadata = PagedAttentionInputMetadata::dummy(q.device())?; + // Sanity check. + assert!(mask.is_some()); paged_attn.forward(&q, &k, &v, mask, None, None, &mut input_metadata, None)? } }, diff --git a/mistralrs-core/src/models/phi3.rs b/mistralrs-core/src/models/phi3.rs index 179b8a215..9853d8111 100644 --- a/mistralrs-core/src/models/phi3.rs +++ b/mistralrs-core/src/models/phi3.rs @@ -197,12 +197,11 @@ impl Attention { None, )?, None => { - let mut input_metadata = PagedAttentionInputMetadata { - block_tables: None, - context_lens: None, - max_context_len: None, - slot_mappings: Tensor::new(&[0f32], q.device())?, - }; + // If we don't have metadata, we are most likely generating an imatrix so we don't want to populate that. + // Generating the dummy metadata with the assumption that we are not generating text (only processing prompts). + let mut input_metadata = PagedAttentionInputMetadata::dummy(q.device())?; + // Sanity check. + assert!(attention_mask.is_some()); paged_attn.forward( &q, &k, diff --git a/mistralrs-core/src/models/phi3_5_moe.rs b/mistralrs-core/src/models/phi3_5_moe.rs index d2d197f9c..c805cc35a 100644 --- a/mistralrs-core/src/models/phi3_5_moe.rs +++ b/mistralrs-core/src/models/phi3_5_moe.rs @@ -210,12 +210,11 @@ impl Attention { None, )?, None => { - let mut input_metadata = PagedAttentionInputMetadata { - block_tables: None, - context_lens: None, - max_context_len: None, - slot_mappings: Tensor::new(&[0f32], q.device())?, - }; + // If we don't have metadata, we are most likely generating an imatrix so we don't want to populate that. + // Generating the dummy metadata with the assumption that we are not generating text (only processing prompts). + let mut input_metadata = PagedAttentionInputMetadata::dummy(q.device())?; + // Sanity check. + assert!(attention_mask.is_some()); paged_attn.forward( &q, &k, diff --git a/mistralrs-core/src/models/qwen2.rs b/mistralrs-core/src/models/qwen2.rs index c694420f5..306e5c6bb 100644 --- a/mistralrs-core/src/models/qwen2.rs +++ b/mistralrs-core/src/models/qwen2.rs @@ -288,12 +288,11 @@ impl Attention { None, )?, None => { - let mut input_metadata = PagedAttentionInputMetadata { - block_tables: None, - context_lens: None, - max_context_len: None, - slot_mappings: Tensor::new(&[0f32], q.device())?, - }; + // If we don't have metadata, we are most likely generating an imatrix so we don't want to populate that. + // Generating the dummy metadata with the assumption that we are not generating text (only processing prompts). + let mut input_metadata = PagedAttentionInputMetadata::dummy(q.device())?; + // Sanity check. + assert!(attention_mask.is_some()); paged_attn.forward( &q, &k, diff --git a/mistralrs-core/src/models/starcoder2.rs b/mistralrs-core/src/models/starcoder2.rs index 8262fcccf..b6679b782 100644 --- a/mistralrs-core/src/models/starcoder2.rs +++ b/mistralrs-core/src/models/starcoder2.rs @@ -281,12 +281,11 @@ impl Attention { None, )?, None => { - let mut input_metadata = PagedAttentionInputMetadata { - block_tables: None, - context_lens: None, - max_context_len: None, - slot_mappings: Tensor::new(&[0f32], q.device())?, - }; + // If we don't have metadata, we are most likely generating an imatrix so we don't want to populate that. + // Generating the dummy metadata with the assumption that we are not generating text (only processing prompts). + let mut input_metadata = PagedAttentionInputMetadata::dummy(q.device())?; + // Sanity check. + assert!(attention_mask.is_some()); paged_attn.forward( &q, &k, diff --git a/mistralrs-core/src/pipeline/inputs_processor.rs b/mistralrs-core/src/pipeline/inputs_processor.rs index 9a6d5903c..3af61ccd5 100644 --- a/mistralrs-core/src/pipeline/inputs_processor.rs +++ b/mistralrs-core/src/pipeline/inputs_processor.rs @@ -95,6 +95,19 @@ pub mod text_models_inputs_processor { pub max_context_len: Option, } + impl PagedAttentionInputMetadata { + /// Create a dummy input metadata, assuming that this will NOT be used for decoding. + /// This is used for the case of imatrix generation. + pub fn dummy(dev: &Device) -> candle_core::Result { + Ok(PagedAttentionInputMetadata { + block_tables: None, + context_lens: None, + max_context_len: None, + slot_mappings: Tensor::new(&[0f32], dev)?, + }) + } + } + #[derive(Clone, Debug)] pub struct FlashParams { pub max_q: u32, diff --git a/mistralrs-core/src/vision_models/llava/llava_llm/llama.rs b/mistralrs-core/src/vision_models/llava/llava_llm/llama.rs index 64e4a7cfe..ce3be175f 100644 --- a/mistralrs-core/src/vision_models/llava/llava_llm/llama.rs +++ b/mistralrs-core/src/vision_models/llava/llava_llm/llama.rs @@ -100,12 +100,11 @@ impl CausalSelfAttention { None, )?, None => { - let mut input_metadata = PagedAttentionInputMetadata { - block_tables: None, - context_lens: None, - max_context_len: None, - slot_mappings: Tensor::new(&[0f32], q.device())?, - }; + // If we don't have metadata, we are most likely generating an imatrix so we don't want to populate that. + // Generating the dummy metadata with the assumption that we are not generating text (only processing prompts). + let mut input_metadata = PagedAttentionInputMetadata::dummy(q.device())?; + // Sanity check. + assert!(attention_mask.is_some()); paged_attn.forward( &q, &k, diff --git a/mistralrs-core/src/vision_models/llava/llava_llm/mistral.rs b/mistralrs-core/src/vision_models/llava/llava_llm/mistral.rs index f79c24b15..304693d85 100644 --- a/mistralrs-core/src/vision_models/llava/llava_llm/mistral.rs +++ b/mistralrs-core/src/vision_models/llava/llava_llm/mistral.rs @@ -247,12 +247,11 @@ impl Attention { None, )?, None => { - let mut input_metadata = PagedAttentionInputMetadata { - block_tables: None, - context_lens: None, - max_context_len: None, - slot_mappings: Tensor::new(&[0f32], q.device())?, - }; + // If we don't have metadata, we are most likely generating an imatrix so we don't want to populate that. + // Generating the dummy metadata with the assumption that we are not generating text (only processing prompts). + let mut input_metadata = PagedAttentionInputMetadata::dummy(q.device())?; + // Sanity check. + assert!(attention_mask.is_some()); paged_attn.forward( &q, &k, diff --git a/mistralrs-core/src/vision_models/phi3/mod.rs b/mistralrs-core/src/vision_models/phi3/mod.rs index 9845cb18a..628ebb7ec 100644 --- a/mistralrs-core/src/vision_models/phi3/mod.rs +++ b/mistralrs-core/src/vision_models/phi3/mod.rs @@ -277,12 +277,11 @@ impl Attention { None, )?, None => { - let mut input_metadata = PagedAttentionInputMetadata { - block_tables: None, - context_lens: None, - max_context_len: None, - slot_mappings: Tensor::new(&[0f32], q.device())?, - }; + // If we don't have metadata, we are most likely generating an imatrix so we don't want to populate that. + // Generating the dummy metadata with the assumption that we are not generating text (only processing prompts). + let mut input_metadata = PagedAttentionInputMetadata::dummy(q.device())?; + // Sanity check. + assert!(attention_mask.is_some()); paged_attn.forward( &q, &k,