Skip to content

Commit

Permalink
Adding streaming function to mistralrs server. (#986)
Browse files Browse the repository at this point in the history
* Adding streaming function to mistralrs server.

* Adding simple_stream example
  • Loading branch information
Narsil authored Dec 12, 2024
1 parent c1e9268 commit 9d1f09f
Show file tree
Hide file tree
Showing 3 changed files with 141 additions and 1 deletion.
55 changes: 55 additions & 0 deletions mistralrs/examples/simple_stream/main.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
use anyhow::Result;
use mistralrs::{
IsqType, PagedAttentionMetaBuilder, RequestBuilder, Response, TextMessageRole, TextMessages,
TextModelBuilder,
};
use std::io::Write;

#[tokio::main]
async fn main() -> Result<()> {
    // Build a quantized Phi-3.5 text model with paged attention enabled.
    let model = TextModelBuilder::new("microsoft/Phi-3.5-mini-instruct")
        .with_isq(IsqType::Q8_0)
        .with_logging()
        .with_paged_attn(|| PagedAttentionMetaBuilder::default().build())?
        .build()
        .await?;

    let messages = TextMessages::new()
        .add_message(
            TextMessageRole::System,
            "You are an AI agent with a specialty in programming.",
        )
        .add_message(
            TextMessageRole::User,
            "Hello! How are you? Please write generic binary search function in Rust.",
        );

    // First example: a blocking request that returns the full completion at once.
    let response = model.send_chat_request(messages).await?;

    println!("{}", response.choices[0].message.content.as_ref().unwrap());
    dbg!(
        response.usage.avg_prompt_tok_per_sec,
        response.usage.avg_compl_tok_per_sec
    );

    // Next example: Return some logprobs with the `RequestBuilder`, which enables higher configurability.
    let request = RequestBuilder::new().return_logprobs(true).add_message(
        TextMessageRole::User,
        "Please write a mathematical equation where a few numbers are added.",
    );

    let mut stream = model.stream_chat_request(request).await?;

    // Lock stdout once and buffer writes so each chunk isn't a separate syscall.
    let stdout = std::io::stdout();
    let lock = stdout.lock();
    let mut buf = std::io::BufWriter::new(lock);
    while let Some(chunk) = stream.next().await {
        if let Response::Chunk(chunk) = chunk {
            // `write_all` guarantees the whole delta is written; a bare
            // `write` may write only part of the buffer (clippy: unused_io_amount).
            buf.write_all(chunk.choices[0].delta.content.as_bytes())?;
        } else {
            // Handle errors
        }
    }
    // BufWriter's Drop flushes but silently swallows errors — flush explicitly.
    buf.flush()?;

    Ok(())
}
39 changes: 39 additions & 0 deletions mistralrs/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,45 @@
//! Ok(())
//! }
//! ```
//!
//! ## Streaming example
//! ```no_run
//! use anyhow::Result;
//! use mistralrs::{
//!     IsqType, PagedAttentionMetaBuilder, Response, TextMessageRole, TextMessages, TextModelBuilder,
//! };
//!
//! #[tokio::main]
//! async fn main() -> Result<()> {
//! let model = TextModelBuilder::new("microsoft/Phi-3.5-mini-instruct".to_string())
//! .with_isq(IsqType::Q8_0)
//! .with_logging()
//! .with_paged_attn(|| PagedAttentionMetaBuilder::default().build())?
//! .build()
//! .await?;
//!
//! let messages = TextMessages::new()
//! .add_message(
//! TextMessageRole::System,
//! "You are an AI agent with a specialty in programming.",
//! )
//! .add_message(
//! TextMessageRole::User,
//! "Hello! How are you? Please write generic binary search function in Rust.",
//! );
//!
//! let mut stream = model.stream_chat_request(messages).await?;
//!
//! while let Some(chunk) = stream.next().await {
//!         if let Response::Chunk(chunk) = chunk {
//! print!("{}", chunk.choices[0].delta.content);

Check failure on line 83 in mistralrs/src/lib.rs

View workflow job for this annotation

GitHub Actions / Test Suite (macOS-latest, stable)

failed to resolve: use of undeclared type `Response`
//! }
//! // Handle the error cases.
//!
//! }
//! Ok(())
//! }
//! ```
mod anymoe;
mod diffusion_model;
Expand Down
48 changes: 47 additions & 1 deletion mistralrs/src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use candle_core::{Device, Result, Tensor};
use either::Either;
use mistralrs_core::*;
use std::sync::Arc;
use tokio::sync::mpsc::channel;
use tokio::sync::mpsc::{channel, Receiver};

use crate::{RequestLike, TextMessages};

Expand Down Expand Up @@ -47,11 +47,57 @@ pub struct Model {
runner: Arc<MistralRs>,
}

/// A live, in-progress generation returned by [`Model::stream_chat_request`].
///
/// Holds a borrow of the owning [`Model`] so the stream cannot outlive the
/// engine, plus the receiving half of the channel the engine pushes
/// `Response` chunks into.
pub struct Stream<'a> {
    // Keeps the model borrowed for the stream's lifetime; never read directly.
    _server: &'a Model,
    rx: Receiver<Response>,
}

// `impl Stream<'_>` instead of `impl<'a> Stream<'a>`: the lifetime is never
// referenced in the body, so naming it trips clippy's `needless_lifetimes`.
impl Stream<'_> {
    /// Receive the next response chunk.
    ///
    /// Returns `None` once the engine has closed the sending side of the
    /// channel, i.e. the generation is finished.
    pub async fn next(&mut self) -> Option<Response> {
        self.rx.recv().await
    }
}

impl Model {
    /// Wrap a running [`MistralRs`] engine in the high-level `Model` API.
    pub fn new(runner: Arc<MistralRs>) -> Self {
        Self { runner }
    }

    /// Send a chat request to the model and stream the response.
    ///
    /// Returns a [`Stream`] whose `next()` yields `Response` chunks as they
    /// are produced; the stream borrows `self` and ends when the engine
    /// closes the channel. Sampling params, constraint, adapters, tools and
    /// logits processors are all consumed from `request`.
    pub async fn stream_chat_request<R: RequestLike>(
        &self,
        mut request: R,
    ) -> anyhow::Result<Stream> {
        // Bounded capacity of 1: the engine's `send` waits until the consumer
        // has taken the previous chunk (per tokio's bounded mpsc semantics).
        let (tx, rx) = channel(1);

        let (tools, tool_choice) = if let Some((a, b)) = request.take_tools() {
            (Some(a), Some(b))
        } else {
            (None, None)
        };
        let request = Request::Normal(NormalRequest {
            messages: request.take_messages(),
            sampling_params: request.take_sampling_params(),
            response: tx,
            return_logprobs: request.return_logprobs(),
            // Ask the engine for incremental chunk delivery.
            is_streaming: true,
            id: 0,
            constraint: request.take_constraint(),
            suffix: None,
            adapters: request.take_adapters(),
            tools,
            tool_choice,
            logits_processors: request.take_logits_processors(),
            return_raw_logits: false,
        });

        self.runner.get_sender()?.send(request).await?;

        let stream = Stream { _server: self, rx };

        Ok(stream)
    }

/// Generate with the model.
pub async fn send_chat_request<R: RequestLike>(
&self,
Expand Down

0 comments on commit 9d1f09f

Please sign in to comment.