
Commit dc3a4c3

Fix for vllama

EricLBuehler committed Jan 2, 2025
1 parent ba379c5 commit dc3a4c3
Showing 2 changed files with 6 additions and 13 deletions.
7 changes: 2 additions & 5 deletions mistralrs-core/src/pipeline/speculative.rs
@@ -418,8 +418,7 @@ impl Pipeline for SpeculativePipeline
         for i in 0..self.gamma {
             let is_xlora = get_mut_arcmutex!(self.draft).get_metadata().is_xlora;
             let device = get_mut_arcmutex!(self.draft).device();
-            let no_kv_cache =
-                get_mut_arcmutex!(self.draft).get_metadata().no_kv_cache;
+            let no_kv_cache = get_mut_arcmutex!(self.draft).get_metadata().no_kv_cache;
             let inputs = self
                 .get_processor()
                 .inputs_processor()
@@ -492,9 +491,7 @@ impl Pipeline for SpeculativePipeline
         // ========= Run the model ============
         let is_xlora = get_mut_arcmutex!(self.target).get_metadata().is_xlora;
         let device = get_mut_arcmutex!(self.target).device();
-        let no_kv_cache = get_mut_arcmutex!(self.target)
-            .get_metadata()
-            .no_kv_cache;
+        let no_kv_cache = get_mut_arcmutex!(self.target).get_metadata().no_kv_cache;
         let inputs = self
             .get_processor()
             .inputs_processor()
12 changes: 4 additions & 8 deletions mistralrs-core/src/vision_models/mllama/text.rs
@@ -344,7 +344,6 @@ impl MLlamaTextCrossAttention {
         hidden_states: &Tensor,
         cross_attn_states: Option<&Tensor>,
         attention_mask: Option<&Tensor>,
-        kv_cache: &mut KvCache,
     ) -> Result<Tensor> {
         let (bs, q_len, _) = hidden_states.dims3()?;

@@ -385,10 +384,9 @@
                 .reshape((bs, (), self.num_kv_heads, self.head_dim))?
                 .transpose(1, 2)?;

-            (k, v) = kv_cache.append(&k, &v)?;
             (k, v)
         } else {
-            candle_core::bail!("Cross attn cannot find cross attn hidden states!")
+            candle_core::bail!("Cross attn cannot find k,v cache or cross attn hidden states!")
         };

         let mut attn_output = Sdpa
@@ -473,15 +471,14 @@ impl MLlamaCrossAttentionDecoderLayer {
         cross_attn_states: Option<&Tensor>,
         attention_mask: Option<&Tensor>,
         full_text_row_masked_out_mask: Option<&Tensor>,
-        kv_cache: &mut KvCache,
     ) -> Result<Tensor> {
         let residual = hidden_states;

         let mut hidden_states = self.input_layernorm.forward(hidden_states)?;

-        hidden_states =
-            self.attn
-                .forward(&hidden_states, cross_attn_states, attention_mask, kv_cache)?;
+        hidden_states = self
+            .attn
+            .forward(&hidden_states, cross_attn_states, attention_mask)?;
         hidden_states = (residual + hidden_states.broadcast_mul(&self.attn_gate.tanh()?)?)?;

         let residual = &hidden_states;
@@ -674,7 +671,6 @@ impl MLlamaTextModel {
                     cross_attn_states,
                     cross_attention_mask,
                     full_text_row_masked_out_mask,
-                    &mut cache[i],
                 )?;
             }
         }
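
The net effect of the text.rs change is that the mllama cross-attention layers no longer take a per-layer KvCache: K and V are derived from the cross-attention (vision) states whenever they are supplied, and the forward call errors out otherwise. The sketch below is a minimal, self-contained illustration of that control flow, not the mistral.rs API: Tensor is a stand-in alias, the K/V projection and reshape work is elided, and a plain Result<_, String> stands in for candle_core::bail!.

// Minimal sketch (assumptions noted above): no KvCache parameter, K/V always
// recomputed from the cross-attention states when present, error otherwise.
type Tensor = Vec<f32>;

fn cross_attn_forward(
    hidden_states: &Tensor,
    cross_attn_states: Option<&Tensor>,
) -> Result<Tensor, String> {
    let (_k, _v) = if let Some(states) = cross_attn_states {
        // Real code: linear K/V projections, reshape to (bs, (), num_kv_heads, head_dim), transpose.
        (states.clone(), states.clone())
    } else {
        return Err("Cross attn cannot find k,v cache or cross attn hidden states!".to_string());
    };
    // Real code: SDPA over (q, k, v) followed by the output projection.
    Ok(hidden_states.clone())
}

fn main() {
    let hidden = vec![0.0_f32; 4];
    let vision = vec![1.0_f32; 4];
    assert!(cross_attn_forward(&hidden, Some(&vision)).is_ok());
    assert!(cross_attn_forward(&hidden, None).is_err());
}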
