Cleaner creation of dummy pa input metadata (#1014)
EricLBuehler authored Dec 29, 2024
1 parent d28ddf9 commit d8fa819
Showing 14 changed files with 78 additions and 78 deletions.
11 changes: 5 additions & 6 deletions mistralrs-core/src/models/gemma.rs
@@ -319,12 +319,11 @@ impl Attention
                 None,
             )?,
             None => {
-                let mut input_metadata = PagedAttentionInputMetadata {
-                    block_tables: None,
-                    context_lens: None,
-                    max_context_len: None,
-                    slot_mappings: Tensor::new(&[0f32], q.device())?,
-                };
+                // If we don't have metadata, we are most likely generating an imatrix so we don't want to populate that.
+                // Generating the dummy metadata with the assumption that we are not generating text (only processing prompts).
+                let mut input_metadata = PagedAttentionInputMetadata::dummy(q.device())?;
+                // Sanity check.
+                assert!(attention_mask.is_some());
                 paged_attn.forward(
                     &q,
                     &k,
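
The same two-arm fallback recurs in every model file below. As a reading aid, here is a minimal, self-contained sketch of its shape; `InputMetadata` and `forward_fallback` are illustrative stand-ins rather than mistral.rs APIs, and only candle_core is assumed. When no PagedAttention metadata is supplied, the code builds throwaway metadata and asserts that an attention mask exists, since this path only processes prompts (e.g. imatrix collection) and never decodes.

```rust
// Hedged sketch of the fallback pattern standardized by this commit;
// the stub names here are hypothetical, not crate APIs.
use candle_core::{DType, Device, Result, Tensor};

struct InputMetadata {
    slot_mappings: Tensor, // placeholder tensor on the dummy path
}

impl InputMetadata {
    fn dummy(dev: &Device) -> Result<Self> {
        Ok(Self {
            slot_mappings: Tensor::new(&[0f32], dev)?,
        })
    }
}

fn forward_fallback(
    metadata: Option<InputMetadata>,
    attention_mask: Option<&Tensor>,
    dev: &Device,
) -> Result<InputMetadata> {
    match metadata {
        // Normal serving path: real paged-attention metadata was provided.
        Some(meta) => Ok(meta),
        None => {
            // Prompt-only path: a mask must exist, mirroring the
            // `assert!(attention_mask.is_some())` in the hunks.
            assert!(attention_mask.is_some());
            InputMetadata::dummy(dev)
        }
    }
}

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let mask = Tensor::zeros((1, 1), DType::F32, &dev)?;
    let meta = forward_fallback(None, Some(&mask), &dev)?;
    println!("dummy slot_mappings: {:?}", meta.slot_mappings.shape());
    Ok(())
}
```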
11 changes: 5 additions & 6 deletions mistralrs-core/src/models/gemma2.rs
@@ -333,12 +333,11 @@ impl Attention
                 self.attn_logit_softcapping,
             )?,
             None => {
-                let mut input_metadata = PagedAttentionInputMetadata {
-                    block_tables: None,
-                    context_lens: None,
-                    max_context_len: None,
-                    slot_mappings: Tensor::new(&[0f32], q.device())?,
-                };
+                // If we don't have metadata, we are most likely generating an imatrix so we don't want to populate that.
+                // Generating the dummy metadata with the assumption that we are not generating text (only processing prompts).
+                let mut input_metadata = PagedAttentionInputMetadata::dummy(q.device())?;
+                // Sanity check.
+                assert!(attention_mask.is_some());
                 paged_attn.forward(
                     &q,
                     &k,
11 changes: 5 additions & 6 deletions mistralrs-core/src/models/llama.rs
@@ -135,12 +135,11 @@ impl CausalSelfAttention
                 None,
             )?,
             None => {
-                let mut input_metadata = PagedAttentionInputMetadata {
-                    block_tables: None,
-                    context_lens: None,
-                    max_context_len: None,
-                    slot_mappings: Tensor::new(&[0f32], q.device())?,
-                };
+                // If we don't have metadata, we are most likely generating an imatrix so we don't want to populate that.
+                // Generating the dummy metadata with the assumption that we are not generating text (only processing prompts).
+                let mut input_metadata = PagedAttentionInputMetadata::dummy(q.device())?;
+                // Sanity check.
+                assert!(attention_mask.is_some());
                 paged_attn.forward(
                     &q,
                     &k,
11 changes: 5 additions & 6 deletions mistralrs-core/src/models/mistral.rs
@@ -296,12 +296,11 @@ impl Attention
                 None,
             )?,
             None => {
-                let mut input_metadata = PagedAttentionInputMetadata {
-                    block_tables: None,
-                    context_lens: None,
-                    max_context_len: None,
-                    slot_mappings: Tensor::new(&[0f32], q.device())?,
-                };
+                // If we don't have metadata, we are most likely generating an imatrix so we don't want to populate that.
+                // Generating the dummy metadata with the assumption that we are not generating text (only processing prompts).
+                let mut input_metadata = PagedAttentionInputMetadata::dummy(q.device())?;
+                // Sanity check.
+                assert!(attention_mask.is_some());
                 paged_attn.forward(
                     &q,
                     &k,
11 changes: 5 additions & 6 deletions mistralrs-core/src/models/mixtral.rs
@@ -191,12 +191,11 @@ impl Attention
                 None,
             )?,
             None => {
-                let mut input_metadata = PagedAttentionInputMetadata {
-                    block_tables: None,
-                    context_lens: None,
-                    max_context_len: None,
-                    slot_mappings: Tensor::new(&[0f32], q.device())?,
-                };
+                // If we don't have metadata, we are most likely generating an imatrix so we don't want to populate that.
+                // Generating the dummy metadata with the assumption that we are not generating text (only processing prompts).
+                let mut input_metadata = PagedAttentionInputMetadata::dummy(q.device())?;
+                // Sanity check.
+                assert!(attention_mask.is_some());
                 paged_attn.forward(
                     &q,
                     &k,
11 changes: 5 additions & 6 deletions mistralrs-core/src/models/phi2.rs
@@ -315,12 +315,11 @@ impl Attention
                 None,
             )?,
             None => {
-                let mut input_metadata = PagedAttentionInputMetadata {
-                    block_tables: None,
-                    context_lens: None,
-                    max_context_len: None,
-                    slot_mappings: Tensor::new(&[0f32], q.device())?,
-                };
+                // If we don't have metadata, we are most likely generating an imatrix so we don't want to populate that.
+                // Generating the dummy metadata with the assumption that we are not generating text (only processing prompts).
+                let mut input_metadata = PagedAttentionInputMetadata::dummy(q.device())?;
+                // Sanity check.
+                assert!(mask.is_some());
                 paged_attn.forward(&q, &k, &v, mask, None, None, &mut input_metadata, None)?
             }
         },
11 changes: 5 additions & 6 deletions mistralrs-core/src/models/phi3.rs
@@ -197,12 +197,11 @@ impl Attention
                 None,
             )?,
             None => {
-                let mut input_metadata = PagedAttentionInputMetadata {
-                    block_tables: None,
-                    context_lens: None,
-                    max_context_len: None,
-                    slot_mappings: Tensor::new(&[0f32], q.device())?,
-                };
+                // If we don't have metadata, we are most likely generating an imatrix so we don't want to populate that.
+                // Generating the dummy metadata with the assumption that we are not generating text (only processing prompts).
+                let mut input_metadata = PagedAttentionInputMetadata::dummy(q.device())?;
+                // Sanity check.
+                assert!(attention_mask.is_some());
                 paged_attn.forward(
                     &q,
                     &k,
11 changes: 5 additions & 6 deletions mistralrs-core/src/models/phi3_5_moe.rs
@@ -210,12 +210,11 @@ impl Attention
                 None,
             )?,
             None => {
-                let mut input_metadata = PagedAttentionInputMetadata {
-                    block_tables: None,
-                    context_lens: None,
-                    max_context_len: None,
-                    slot_mappings: Tensor::new(&[0f32], q.device())?,
-                };
+                // If we don't have metadata, we are most likely generating an imatrix so we don't want to populate that.
+                // Generating the dummy metadata with the assumption that we are not generating text (only processing prompts).
+                let mut input_metadata = PagedAttentionInputMetadata::dummy(q.device())?;
+                // Sanity check.
+                assert!(attention_mask.is_some());
                 paged_attn.forward(
                     &q,
                     &k,
11 changes: 5 additions & 6 deletions mistralrs-core/src/models/qwen2.rs
@@ -288,12 +288,11 @@ impl Attention
                 None,
             )?,
             None => {
-                let mut input_metadata = PagedAttentionInputMetadata {
-                    block_tables: None,
-                    context_lens: None,
-                    max_context_len: None,
-                    slot_mappings: Tensor::new(&[0f32], q.device())?,
-                };
+                // If we don't have metadata, we are most likely generating an imatrix so we don't want to populate that.
+                // Generating the dummy metadata with the assumption that we are not generating text (only processing prompts).
+                let mut input_metadata = PagedAttentionInputMetadata::dummy(q.device())?;
+                // Sanity check.
+                assert!(attention_mask.is_some());
                 paged_attn.forward(
                     &q,
                     &k,
11 changes: 5 additions & 6 deletions mistralrs-core/src/models/starcoder2.rs
@@ -281,12 +281,11 @@ impl Attention
                 None,
             )?,
             None => {
-                let mut input_metadata = PagedAttentionInputMetadata {
-                    block_tables: None,
-                    context_lens: None,
-                    max_context_len: None,
-                    slot_mappings: Tensor::new(&[0f32], q.device())?,
-                };
+                // If we don't have metadata, we are most likely generating an imatrix so we don't want to populate that.
+                // Generating the dummy metadata with the assumption that we are not generating text (only processing prompts).
+                let mut input_metadata = PagedAttentionInputMetadata::dummy(q.device())?;
+                // Sanity check.
+                assert!(attention_mask.is_some());
                 paged_attn.forward(
                     &q,
                     &k,
13 changes: 13 additions & 0 deletions mistralrs-core/src/pipeline/inputs_processor.rs
@@ -95,6 +95,19 @@ pub mod text_models_inputs_processor
         pub max_context_len: Option<usize>,
     }
 
+    impl PagedAttentionInputMetadata {
+        /// Create a dummy input metadata, assuming that this will NOT be used for decoding.
+        /// This is used for the case of imatrix generation.
+        pub fn dummy(dev: &Device) -> candle_core::Result<Self> {
+            Ok(PagedAttentionInputMetadata {
+                block_tables: None,
+                context_lens: None,
+                max_context_len: None,
+                slot_mappings: Tensor::new(&[0f32], dev)?,
+            })
+        }
+    }
+
     #[derive(Clone, Debug)]
     pub struct FlashParams {
         pub max_q: u32,
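
For readers skimming the call sites, here is a compact, self-contained sketch of how the new constructor replaces the old struct literal. The struct is reproduced under the assumption that its four fields are exactly those initialized in the hunk above; in the crate it lives in `text_models_inputs_processor`, and the `main` below is purely illustrative.

```rust
// Minimal sketch of the constructor this commit introduces, assuming
// only candle_core; field set taken from the dummy() body above.
use candle_core::{Device, Result, Tensor};

pub struct PagedAttentionInputMetadata {
    pub slot_mappings: Tensor,
    pub block_tables: Option<Tensor>,
    pub context_lens: Option<Tensor>,
    pub max_context_len: Option<usize>,
}

impl PagedAttentionInputMetadata {
    /// Dummy metadata for prompt-only passes (e.g. imatrix generation);
    /// never valid for decoding, which needs real block tables.
    pub fn dummy(dev: &Device) -> Result<Self> {
        Ok(Self {
            block_tables: None,
            context_lens: None,
            max_context_len: None,
            // Placeholder tensor; the prompt-only path never reads it.
            slot_mappings: Tensor::new(&[0f32], dev)?,
        })
    }
}

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Before this commit, each model spelled out the struct literal;
    // after it, the fallback arm is a single call:
    let meta = PagedAttentionInputMetadata::dummy(&dev)?;
    assert!(meta.block_tables.is_none() && meta.max_context_len.is_none());
    Ok(())
}
```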
11 changes: 5 additions & 6 deletions mistralrs-core/src/vision_models/llava/llava_llm/llama.rs
@@ -100,12 +100,11 @@ impl CausalSelfAttention
                 None,
             )?,
             None => {
-                let mut input_metadata = PagedAttentionInputMetadata {
-                    block_tables: None,
-                    context_lens: None,
-                    max_context_len: None,
-                    slot_mappings: Tensor::new(&[0f32], q.device())?,
-                };
+                // If we don't have metadata, we are most likely generating an imatrix so we don't want to populate that.
+                // Generating the dummy metadata with the assumption that we are not generating text (only processing prompts).
+                let mut input_metadata = PagedAttentionInputMetadata::dummy(q.device())?;
+                // Sanity check.
+                assert!(attention_mask.is_some());
                 paged_attn.forward(
                     &q,
                     &k,
11 changes: 5 additions & 6 deletions mistralrs-core/src/vision_models/llava/llava_llm/mistral.rs
@@ -247,12 +247,11 @@ impl Attention
                 None,
             )?,
             None => {
-                let mut input_metadata = PagedAttentionInputMetadata {
-                    block_tables: None,
-                    context_lens: None,
-                    max_context_len: None,
-                    slot_mappings: Tensor::new(&[0f32], q.device())?,
-                };
+                // If we don't have metadata, we are most likely generating an imatrix so we don't want to populate that.
+                // Generating the dummy metadata with the assumption that we are not generating text (only processing prompts).
+                let mut input_metadata = PagedAttentionInputMetadata::dummy(q.device())?;
+                // Sanity check.
+                assert!(attention_mask.is_some());
                 paged_attn.forward(
                     &q,
                     &k,
11 changes: 5 additions & 6 deletions mistralrs-core/src/vision_models/phi3/mod.rs
@@ -277,12 +277,11 @@ impl Attention
                 None,
             )?,
             None => {
-                let mut input_metadata = PagedAttentionInputMetadata {
-                    block_tables: None,
-                    context_lens: None,
-                    max_context_len: None,
-                    slot_mappings: Tensor::new(&[0f32], q.device())?,
-                };
+                // If we don't have metadata, we are most likely generating an imatrix so we don't want to populate that.
+                // Generating the dummy metadata with the assumption that we are not generating text (only processing prompts).
+                let mut input_metadata = PagedAttentionInputMetadata::dummy(q.device())?;
+                // Sanity check.
+                assert!(attention_mask.is_some());
                 paged_attn.forward(
                     &q,
                     &k,
