Support streaming batched chat completion requests
Showing 25 changed files with 1,062 additions and 570 deletions.
@@ -0,0 +1,85 @@
import openai
import asyncio
from openai import Stream
from openai.types.chat import ChatCompletionChunk
from typing import List, Tuple

# Run: cargo run --release -- --port 2000 --model-id <MODEL_ID> <MODEL_TYPE> --repeat-last-n 64
# MODEL_ID is the huggingface model id or a local weight path
# MODEL_TYPE is one of ["llama", "llama3", "mistral", "phi2", "phi3", "qwen2", "gemma", "yi", "stable-lm"]
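# For example (hypothetical model id; any local weight path works the same way), a Mistral-style launch could be:
#   cargo run --release -- --port 2000 --model-id mistralai/Mistral-7B-Instruct-v0.2 mistral --repeat-last-n 64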

openai.api_key = "EMPTY"
openai.base_url = "http://localhost:2000/v1/"


async def chat_completion(model, max_tokens, prompt):
    completion = openai.chat.completions.create(
        model=model,
        messages=[
            {
                "role": "user",
                "content": prompt,
            },
        ],
        max_tokens=max_tokens,
        stream=True,
    )
    return completion


async def stream_response(response_idx, stream: Stream[ChatCompletionChunk]):
    result = ""
    for chunk in stream:
        content = chunk.choices[0].delta.content
        if content is not None:
            result += content
    return (response_idx, result)


async def benchmark():
    model = "mistral7b"
    max_tokens = 1024
    # 16 requests
    prompts = ["Explain how to best learn Rust.",
               "Please talk about deep learning in 100 words.",
               "Do you know the capital city of China? Talk the details of you known.",
               "Who is the best female actor in the world? Explain why.",
               "How to dealing with depression?",
               "How to make money in short time?",
               "What is the future trend of large language model?",
               "The famous tech companies in the world.",
               "Explain how to best learn Rust.",
               "Please talk about deep learning in 100 words.",
               "Do you know the capital city of China? Talk the details of you known.",
               "Who is the best female actor in the world? Explain why.",
               "How to dealing with depression?",
               "How to make money in short time?",
               "What is the future trend of large language model?",
               "The famous tech companies in the world."]

    # send 16 chat requests at the same time
    tasks: List[asyncio.Task] = []
    for i in range(len(prompts)):
        tasks.append(
            asyncio.create_task(
                chat_completion(model, max_tokens, prompts[i]))
        )

    # obtain the corresponding stream object for each request
    outputs: List[Stream[ChatCompletionChunk]] = await asyncio.gather(*tasks)

    # tasks for streaming the chat responses
    tasks_stream: List[asyncio.Task] = []
    for i in range(len(outputs)):
        tasks_stream.append(
            asyncio.create_task(
                stream_response(i, outputs[i]))
        )

    # gather the response texts
    results: List[Tuple[int, str]] = await asyncio.gather(*tasks_stream)

    # print the results; chat completion statistics are reported by the backend server (i.e., candle-vllm)
    for idx, output in results:
        print("\n\n Response {}: \n\n {}".format(idx, output))


asyncio.run(benchmark())
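
The example above issues its requests with the synchronous openai client from inside coroutines, so the event loop does not actually interleave the HTTP calls or the stream reads; the batching happens on the server side. Below is a minimal sketch of a fully asynchronous variant, assuming openai's AsyncOpenAI client (openai >= 1.x) and the same endpoint and model name as above:

import asyncio
from typing import List, Tuple

from openai import AsyncOpenAI

# Same (assumed) candle-vllm endpoint as in the example above.
client = AsyncOpenAI(api_key="EMPTY", base_url="http://localhost:2000/v1/")


async def stream_chat(idx: int, model: str, max_tokens: int, prompt: str) -> Tuple[int, str]:
    # Await the request, then iterate the stream asynchronously, so all requests
    # and their streams can be in flight on the event loop at the same time.
    stream = await client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": prompt}],
        max_tokens=max_tokens,
        stream=True,
    )
    text = ""
    async for chunk in stream:
        delta = chunk.choices[0].delta.content
        if delta is not None:
            text += delta
    return idx, text


async def main():
    prompts = ["Explain how to best learn Rust.",
               "Please talk about deep learning in 100 words."]
    results: List[Tuple[int, str]] = await asyncio.gather(
        *(stream_chat(i, "mistral7b", 1024, p) for i, p in enumerate(prompts))
    )
    for idx, text in results:
        print(f"\n\n Response {idx}: \n\n {text}")


asyncio.run(main())

With the async client, all streams are read concurrently, which more closely matches how the server interleaves the batched sequences.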
@@ -0,0 +1,152 @@
use crate::candle::D;
use crate::candle::{DType, Error, Result, Tensor};
use rand::{distributions::Distribution, SeedableRng};
#[derive(Clone, PartialEq, Debug)]
pub enum Sampling {
    ArgMax,
    All { temperature: f64 },
    TopK { k: usize, temperature: f64 },
    TopP { p: f64, temperature: f64 },
    TopKThenTopP { k: usize, p: f64, temperature: f64 },
}

pub struct LogitsProcessor {
    rng: rand::rngs::StdRng,
    sampling: Sampling,
}

impl LogitsProcessor {
    pub fn from_sampling(seed: u64, sampling: Sampling) -> Self {
        let rng = rand::rngs::StdRng::seed_from_u64(seed);
        Self { rng, sampling }
    }

    pub fn new(seed: u64, temperature: Option<f64>, top_p: Option<f64>) -> Self {
        let temperature = temperature.and_then(|v| if v < 1e-7 { None } else { Some(v) });
        let sampling = match temperature {
            None => Sampling::ArgMax,
            Some(temperature) => match top_p {
                None => Sampling::All { temperature },
                Some(p) => Sampling::TopP { p, temperature },
            },
        };
        Self::from_sampling(seed, sampling)
    }

    fn sample_argmax(&mut self, logits: Tensor) -> Result<u32> {
        // let logits_v: Vec<f32> = logits.to_vec1()?;
        // Use gpu kernel
        let next_token = logits.argmax(D::Minus1)?.to_scalar::<u32>()?;
        // let next_token = logits_v
        //     .iter()
        //     .enumerate()
        //     .max_by(|(_, u), (_, v)| u.total_cmp(v))
        //     .map(|(i, _)| i as u32)
        //     .unwrap();
        Ok(next_token)
    }

    fn sample_multinomial(&mut self, prs: &Vec<f32>) -> Result<u32> {
        let distr = rand::distributions::WeightedIndex::new(prs).map_err(Error::wrap)?;
        let next_token = distr.sample(&mut self.rng) as u32;
        Ok(next_token)
    }

    /// top-p sampling (or "nucleus sampling") samples from the smallest set of tokens that exceed
    /// probability top_p. This way we never sample tokens that have very low probabilities and are
    /// less likely to go "off the rails".
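    /// For example (illustrative numbers), with sorted probabilities [0.5, 0.3, 0.15, 0.05] and
    /// top_p = 0.8, the running sum reaches 0.8 after the second token, so the last two
    /// probabilities are clamped to zero before sampling.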
    fn sample_topp(&mut self, prs: &mut Vec<f32>, top_p: f32) -> Result<u32> {
        let mut argsort_indices = (0..prs.len()).collect::<Vec<_>>();

        // Sort by descending probability.
        argsort_indices.sort_by(|&i, &j| prs[j].total_cmp(&prs[i]));

        // Clamp smaller probabilities to zero.
        let mut cumsum = 0.;
        for index in &argsort_indices {
            if cumsum >= top_p {
                prs[*index] = 0.0;
            } else {
                cumsum += prs[*index];
            }
        }
        // Sample with clamped probabilities.
        self.sample_multinomial(prs)
    }

    // top-k sampling samples from the k tokens with the largest probabilities.
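    // For example (illustrative numbers), with k = 3 and probabilities [0.4, 0.25, 0.15, 0.12, 0.08],
    // only the three largest (0.4, 0.25, 0.15) are kept, and the next token is drawn from them in
    // proportion to their weights.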
    fn sample_topk(&mut self, prs: &mut Vec<f32>, top_k: usize) -> Result<u32> {
        if top_k >= prs.len() {
            self.sample_multinomial(prs)
        } else {
            let mut argsort_indices = (0..prs.len()).collect::<Vec<_>>();
            let (indices, _, _) =
                argsort_indices.select_nth_unstable_by(top_k, |&i, &j| prs[j].total_cmp(&prs[i]));
            let prs = indices.iter().map(|&i| prs[i]).collect::<Vec<_>>();
            let index = self.sample_multinomial(&prs)?;
            Ok(indices[index as usize] as u32)
        }
    }

    // top-k sampling samples from the k tokens with the largest probabilities.
    // then top-p sampling.
    fn sample_topk_topp(&mut self, prs: &mut Vec<f32>, top_k: usize, top_p: f32) -> Result<u32> {
        if top_k >= prs.len() {
            self.sample_topp(prs, top_p)
        } else {
            let mut argsort_indices = (0..prs.len()).collect::<Vec<_>>();
            let (indices, _, _) =
                argsort_indices.select_nth_unstable_by(top_k, |&i, &j| prs[j].total_cmp(&prs[i]));
            let mut prs = indices.iter().map(|&i| prs[i]).collect::<Vec<_>>();
            let sum_p = prs.iter().sum::<f32>();
            let index = if top_p <= 0.0 || top_p >= sum_p {
                self.sample_multinomial(&prs)?
            } else {
                self.sample_topp(&mut prs, top_p)?
            };
            Ok(indices[index as usize] as u32)
        }
    }

    pub fn sample(&mut self, logits: &Tensor) -> Result<u32> {
        self.sample_f(logits, |_| {})
    }

    pub fn sample_f(&mut self, logits: &Tensor, f: impl FnOnce(&mut [f32])) -> Result<u32> {
        let logits = logits.to_dtype(DType::F32)?;
        let prs = |temperature: f64| -> Result<Vec<f32>> {
            let logits = (&logits / temperature)?;
            let prs = candle_nn::ops::softmax_last_dim(&logits)?;
            let mut prs = prs.to_vec1()?;
            f(&mut prs);
            Ok(prs)
        };

        let next_token = match &self.sampling {
            Sampling::ArgMax => self.sample_argmax(logits)?,
            Sampling::All { temperature } => {
                let prs = prs(*temperature)?;
                self.sample_multinomial(&prs)?
            }
            Sampling::TopP { p, temperature } => {
                let mut prs = prs(*temperature)?;
                if *p <= 0.0 || *p >= 1.0 {
                    // simply sample from the predicted probability distribution
                    self.sample_multinomial(&prs)?
                } else {
                    // top-p (nucleus) sampling, clamping the least likely tokens to zero
                    self.sample_topp(&mut prs, *p as f32)?
                }
            }
            Sampling::TopK { k, temperature } => {
                let mut prs = prs(*temperature)?;
                self.sample_topk(&mut prs, *k)?
            }
            Sampling::TopKThenTopP { k, p, temperature } => {
                let mut prs = prs(*temperature)?;
                self.sample_topk_topp(&mut prs, *k, *p as f32)?
            }
        };
        Ok(next_token)
    }
}
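
A minimal sketch of how this LogitsProcessor might be driven for a single sampling step. It assumes candle_core is available directly (the module above goes through the crate's candle re-exports) and that LogitsProcessor and Sampling are in scope; the logits values are made up for illustration:

use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    // Made-up logits for a 4-token vocabulary; in the server these come from the
    // model's last decoding step.
    let logits = Tensor::new(&[1.0f32, 3.0, 0.5, 2.0], &Device::Cpu)?;

    // Top-k followed by nucleus (top-p) sampling at temperature 0.7; the fixed seed
    // makes the sampled token reproducible.
    let mut processor = LogitsProcessor::from_sampling(
        42,
        Sampling::TopKThenTopP { k: 3, p: 0.9, temperature: 0.7 },
    );

    let next_token = processor.sample(&logits)?;
    println!("sampled token id: {next_token}");
    Ok(())
}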