Partial fix phi3 device mapping #1002

Closed

9 changes: 6 additions & 3 deletions mistralrs-core/src/attention.rs
@@ -249,6 +249,9 @@ impl Sdpa {
         sdpa_params: &SdpaParams,
     ) -> Result<Tensor> {
         let (b_sz, n_attn_heads, seq_len, head_dim) = q.dims4()?;
+        // Move q to the same device as k,v,mask
+        let q = q.to_device(k.device())?;
+        let q = q.contiguous()?;
         if sdpa_params.use_flash_attn {
             // flash-attn expects (b_sz, seq_len, nheads, head_dim)
             let q = q.transpose(1, 2)?;
@@ … @@
 
         if q.device().is_metal() && seq_len == 1 {
             return candle_nn::ops::sdpa(
-                q,
+                &q,
                 k,
                 v,
                 sdpa_params.softmax_scale,
@@ -332,10 +335,10 @@ impl Sdpa {
                 }
             } else {
                 // Use the f16 kernels here if quantized (ISQ or GGML), and a large enough prompt
-                naive_sdpa(q, &k, &v, mask, head_dim, sdpa_params)
+                naive_sdpa(&q, &k, &v, mask, head_dim, sdpa_params)
             }
         } else {
-            naive_sdpa(q, &k, &v, mask, head_dim, sdpa_params)
+            naive_sdpa(&q, &k, &v, mask, head_dim, sdpa_params)
         }
     }
 }
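
For context on the hunk above: with device mapping, a layer's k/v tensors can live on a different device than the one that produced q, so q is moved to k's device and made contiguous before any attention kernel runs. Below is a minimal sketch of that pattern, assuming candle_core's Tensor API and candle_nn::ops::softmax_last_dim; the function name, shapes, and the CPU stand-in devices are illustrative only, not code from this PR.

use candle_core::{Device, Result, Tensor};

// Illustrative only: align `q` with the device holding `k`/`v`, then run plain
// softmax(q k^T * scale) v attention, mirroring the device move added in Sdpa above.
fn attention_cross_device(q: &Tensor, k: &Tensor, v: &Tensor, scale: f64) -> Result<Tensor> {
    // No-op if the devices already match; `contiguous` avoids strided-layout errors downstream.
    let q = q.to_device(k.device())?.contiguous()?;
    // (b, h, s_q, d) x (b, h, d, s_k) -> (b, h, s_q, s_k)
    let attn = q.matmul(&k.transpose(2, 3)?.contiguous()?)?.affine(scale, 0.)?;
    let attn = candle_nn::ops::softmax_last_dim(&attn)?;
    // (b, h, s_q, s_k) x (b, h, s_k, d) -> (b, h, s_q, d)
    attn.matmul(v)
}

fn main() -> Result<()> {
    // Two CPU devices stand in here for two GPUs in a device-mapped model.
    let (dev_q, dev_kv) = (Device::Cpu, Device::Cpu);
    let q = Tensor::randn(0f32, 1.0, (1, 8, 4, 64), &dev_q)?;
    let k = Tensor::randn(0f32, 1.0, (1, 8, 16, 64), &dev_kv)?;
    let v = Tensor::randn(0f32, 1.0, (1, 8, 16, 64), &dev_kv)?;
    let out = attention_cross_device(&q, &k, &v, 1.0 / (64f64).sqrt())?;
    println!("{:?}", out.dims()); // [1, 8, 4, 64]
    Ok(())
}
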
23 changes: 11 additions & 12 deletions mistralrs-core/src/pipeline/cache_manager.rs
@@ -129,12 +129,14 @@ impl SingleCache {
             }
             let mut shape = src.dims().to_vec();
             shape[self.dim] = self.capacity_seq_len;
-            let ad = Tensor::zeros(shape, src.dtype(), src.device())?;
+            let ad = Tensor::zeros(shape, src.dtype(), self.all_data.as_ref().unwrap().device())?;
             ad.slice_set(self.all_data.as_ref().unwrap(), self.dim, 0)?;
             self.all_data = Some(ad);
         }
         let ad = self.all_data.as_mut().unwrap();
-        ad.slice_set(src, self.dim, self.current_seq_len)?;
+        let src = src.to_device(ad.device())?;
+        let src = src.contiguous()?;
+        ad.slice_set(&src, self.dim, self.current_seq_len)?;
         self.current_seq_len += seq_len;
         Ok(())
     }
@@ -195,6 +197,7 @@ impl KvCache {
             if let Some(mut mask) = mask.cloned() {
                 let mask_len = mask.dim(1)?;
                 mask = mask.narrow(1, mask_len - (sliding_window - 1), sliding_window - 1)?;
+                mask = mask.to_device(k.device())?;
                 return Ok((k, v, Some(mask)));
             }
         }
@@ … @@
     }
 
     pub fn append(&mut self, k: &Tensor, v: &Tensor) -> Result<(Tensor, Tensor)> {
-        let k = k.contiguous()?;
-        let v = v.contiguous()?;
-        self.k.append(&k)?;
-        self.v.append(&v)?;
+        self.k.append(k)?;
+        self.v.append(v)?;
         let out_k = self.k.current_data()?;
         let out_v = self.v.current_data()?;
+
+        // out_k/v should always be Some because SingleCache::append has `if self.all_data.is_none()` logic to create a Tensor if it is empty
         let k = match out_k {
             None => {
-                let mut shape = k.dims().to_vec();
-                shape[self.k.dim] = 0;
-                Tensor::zeros(shape, k.dtype(), k.device())?
+                unreachable!()
             }
             Some(k) => k,
         };
         let v = match out_v {
             None => {
-                let mut shape = v.dims().to_vec();
-                shape[self.k.dim] = 0;
-                Tensor::zeros(shape, v.dtype(), v.device())?
+                unreachable!()
             }
             Some(v) => v,
         };
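
The cache_manager.rs hunks all follow the same rule: the KV cache owns its device, and anything appended (or any freshly grown backing buffer) is brought onto that device first instead of being allocated wherever the incoming tensor happens to live. Below is a simplified, self-contained sketch of that behavior, assuming candle_core's slice_set / to_device / narrow Tensor methods; GrowableKv, grow_by, and the shapes in main are illustrative names only, not types from mistralrs-core.

use candle_core::{DType, Device, Result, Tensor};

// Illustrative, simplified analogue of the patched SingleCache::append: the backing
// buffer keeps its own device, and incoming chunks are moved onto it before slice_set.
struct GrowableKv {
    all_data: Option<Tensor>,
    dim: usize,
    current_seq_len: usize,
    capacity_seq_len: usize,
    grow_by: usize,
}

impl GrowableKv {
    fn append(&mut self, src: &Tensor) -> Result<()> {
        let seq_len = src.dim(self.dim)?;
        if self.all_data.is_none() {
            // First insert: allocate the buffer on the device of the first chunk.
            let mut shape = src.dims().to_vec();
            shape[self.dim] = seq_len.max(self.grow_by);
            self.capacity_seq_len = shape[self.dim];
            self.all_data = Some(Tensor::zeros(shape, src.dtype(), src.device())?);
        } else if self.current_seq_len + seq_len > self.capacity_seq_len {
            // Grow on the *cache's* device (not src.device()), so a device-mapped
            // layer cannot silently migrate the cache to another GPU.
            let old = self.all_data.as_ref().unwrap();
            let mut shape = old.dims().to_vec();
            self.capacity_seq_len = self.current_seq_len + seq_len + self.grow_by;
            shape[self.dim] = self.capacity_seq_len;
            let ad = Tensor::zeros(shape, old.dtype(), old.device())?;
            ad.slice_set(old, self.dim, 0)?;
            self.all_data = Some(ad);
        }
        let ad = self.all_data.as_mut().unwrap();
        // Bring the incoming chunk onto the cache device before writing it in.
        let src = src.to_device(ad.device())?.contiguous()?;
        ad.slice_set(&src, self.dim, self.current_seq_len)?;
        self.current_seq_len += seq_len;
        Ok(())
    }

    fn current_data(&self) -> Result<Option<Tensor>> {
        // Only the filled prefix of the buffer is handed back to attention.
        match &self.all_data {
            None => Ok(None),
            Some(d) => Ok(Some(d.narrow(self.dim, 0, self.current_seq_len)?)),
        }
    }
}

fn main() -> Result<()> {
    let dev = Device::Cpu; // stand-in for the device the cache should stay on
    let mut cache = GrowableKv {
        all_data: None,
        dim: 2,
        current_seq_len: 0,
        capacity_seq_len: 0,
        grow_by: 8,
    };
    for _ in 0..3 {
        let chunk = Tensor::zeros((1, 4, 5, 64), DType::F32, &dev)?;
        cache.append(&chunk)?;
    }
    println!("{:?}", cache.current_data()?.unwrap().dims()); // [1, 4, 15, 64]
    Ok(())
}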