Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Store chat history with stream update #56

Merged
merged 10 commits into from
Jul 11, 2024
8 changes: 4 additions & 4 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,4 @@ function-calling = ["scraper", "text-splitter", "regex", "chat-history"]
tokio = { version = "1", features = ["full"] }
ollama-rs = { path = ".", features = ["stream", "chat-history", "function-calling"] }
base64 = "0.22.0"

4 changes: 2 additions & 2 deletions examples/chat_with_history.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
let result = ollama
.send_chat_messages_with_history(
ChatMessageRequest::new("llama2:latest".to_string(), vec![user_message]),
"default".to_string(),
"default",
)
.await?;

Expand All @@ -37,7 +37,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
}

// Display whole history of messages
dbg!(&ollama.get_messages_history("default".to_string()));
dbg!(&ollama.get_messages_history("default"));

Ok(())
}
25 changes: 9 additions & 16 deletions examples/chat_with_history_stream.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,16 @@
use ollama_rs::{
generation::chat::{request::ChatMessageRequest, ChatMessage},
generation::chat::{request::ChatMessageRequest, ChatMessage, ChatMessageResponseStream},
Ollama,
};
use tokio::io::{stdout, AsyncWriteExt};
use tokio_stream::StreamExt;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
let mut ollama = Ollama::new_default_with_history_async(30);
let mut ollama = Ollama::new_default_with_history(30);

let mut stdout = stdout();

let chat_id = "default".to_string();

loop {
stdout.write_all(b"\n> ").await?;
stdout.flush().await?;
Expand All @@ -25,12 +23,13 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
break;
}

let user_message = ChatMessage::user(input.to_string());

let mut stream = ollama
let mut stream: ChatMessageResponseStream = ollama
.send_chat_messages_with_history_stream(
ChatMessageRequest::new("llama2:latest".to_string(), vec![user_message]),
chat_id.clone(),
ChatMessageRequest::new(
"llama2:latest".to_string(),
vec![ChatMessage::user(input.to_string())],
),
"user".to_string(),
)
.await?;

Expand All @@ -44,14 +43,8 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
response += assistant_message.content.as_str();
}
}
dbg!(&ollama.get_messages_history("user"));
}

// Display whole history of messages
dbg!(
&ollama
.get_messages_history_async("default".to_string())
.await
);

Ok(())
}
190 changes: 83 additions & 107 deletions src/generation/chat/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@ use request::ChatMessageRequest;

#[cfg(feature = "chat-history")]
use crate::history::MessagesHistory;
#[cfg(all(feature = "chat-history", feature = "stream"))]
use crate::history_async::MessagesHistoryAsync;

#[cfg(feature = "stream")]
/// A stream of `ChatMessageResponse` objects
Expand Down Expand Up @@ -99,15 +97,69 @@ impl Ollama {

#[cfg(feature = "chat-history")]
impl Ollama {
#[cfg(feature = "stream")]
pub async fn send_chat_messages_with_history_stream<S: Into<String> + Clone>(
&mut self,
mut request: ChatMessageRequest,
history_id: S,
ushinnary marked this conversation as resolved.
Show resolved Hide resolved
) -> crate::error::Result<ChatMessageResponseStream> {
use async_stream::stream;
use tokio_stream::StreamExt;
let id_copy = history_id.clone().into();

let mut current_chat_messages = self.get_chat_messages_by_id(id_copy.clone());

if let Some(message) = request.messages.first() {
current_chat_messages.push(message.clone());
}

// The request is modified to include the current chat messages
request.messages.clone_from(&current_chat_messages);
request.stream = true;

let mut resp_stream: ChatMessageResponseStream =
self.send_chat_messages_stream(request.clone()).await?;

let messages_history = self.messages_history.clone();

let s = stream! {
let mut result = String::new();

while let Some(item) = resp_stream.try_next().await.unwrap() {
let msg_part = item.clone().message.unwrap().content;

if item.done {
if let Some(history) = messages_history.clone() {
let mut inner = history.write().unwrap();
// Message we sent to AI
if let Some(message) = request.messages.last() {
inner.add_message(id_copy.clone(), message.clone());
}

// AI's response
inner.add_message(id_copy.clone(), ChatMessage::assistant(result.clone()));
}
} else {
result.push_str(&msg_part);
}

yield Ok(item);
}
};

Ok(Box::pin(s))
}

/// Chat message generation
/// Returns a `ChatMessageResponse` object
/// Manages the history of messages for the given `id`
pub async fn send_chat_messages_with_history(
pub async fn send_chat_messages_with_history<S: Into<String> + Clone>(
&mut self,
mut request: ChatMessageRequest,
id: String,
history_id: S,
) -> crate::error::Result<ChatMessageResponse> {
let mut current_chat_messages = self.get_chat_messages_by_id(id.clone());
// The request is modified to include the current chat messages
let mut current_chat_messages = self.get_chat_messages_by_id(history_id.clone());

if let Some(message) = request.messages.first() {
current_chat_messages.push(message.clone());
Expand All @@ -121,125 +173,49 @@ impl Ollama {
if let Ok(result) = result {
// Message we sent to AI
if let Some(message) = request.messages.last() {
self.store_chat_message_by_id(id.clone(), message.clone());
self.store_chat_message_by_id(history_id.clone(), message.clone());
}
// AI's response store in the history
self.store_chat_message_by_id(id, result.message.clone().unwrap());
// Store AI's response in the history
self.store_chat_message_by_id(history_id, result.message.clone().unwrap());

return Ok(result);
}

result
}

/// Helper function to get chat messages by id
fn get_chat_messages_by_id(&mut self, id: String) -> Vec<ChatMessage> {
let mut backup = MessagesHistory::default();

// Clone the current chat messages to avoid borrowing issues
// And not to add message to the history if the request fails
self.messages_history
.as_mut()
.unwrap_or(&mut backup)
.messages_by_id
.entry(id.clone())
.or_default()
.clone()
}

/// Helper function to store chat messages by id
fn store_chat_message_by_id(&mut self, id: String, message: ChatMessage) {
fn store_chat_message_by_id<S: Into<String>>(&mut self, id: S, message: ChatMessage) {
if let Some(messages_history) = self.messages_history.as_mut() {
messages_history.add_message(id, message);
messages_history.write().unwrap().add_message(id, message);
}
}
}

#[cfg(all(feature = "chat-history", feature = "stream"))]
impl Ollama {
async fn get_chat_messages_by_id_async(&mut self, id: String) -> Vec<ChatMessage> {
/// Let get existing history with a new message in it
/// Without impact for existing history
/// Used to prepare history for request
fn get_chat_messages_by_id<S: Into<String> + Clone>(
&mut self,
history_id: S,
) -> Vec<ChatMessage> {
let chat_history = match self.messages_history.as_mut() {
Some(history) => history,
None => &mut {
let new_history =
std::sync::Arc::new(std::sync::RwLock::new(MessagesHistory::default()));
self.messages_history = Some(new_history);
self.messages_history.clone().unwrap()
},
};
// Clone the current chat messages to avoid borrowing issues
// And not to add message to the history if the request fails
self.messages_history_async
.as_mut()
.unwrap_or(&mut MessagesHistoryAsync::default())
let mut history_instance = chat_history.write().unwrap();
let chat_history = history_instance
.messages_by_id
.lock()
.await
.entry(id.clone())
.or_default()
.clone()
}

pub async fn store_chat_message_by_id_async(&mut self, id: String, message: ChatMessage) {
if let Some(messages_history_async) = self.messages_history_async.as_mut() {
messages_history_async.add_message(id, message).await;
}
}

pub async fn send_chat_messages_with_history_stream(
&mut self,
mut request: ChatMessageRequest,
id: String,
) -> crate::error::Result<ChatMessageResponseStream> {
use tokio_stream::StreamExt;

let (tx, mut rx) =
tokio::sync::mpsc::unbounded_channel::<Result<ChatMessageResponse, ()>>(); // create a channel for sending and receiving messages

let mut current_chat_messages = self.get_chat_messages_by_id_async(id.clone()).await;

if let Some(messaeg) = request.messages.first() {
current_chat_messages.push(messaeg.clone());
}

request.messages.clone_from(&current_chat_messages);

let mut stream = self.send_chat_messages_stream(request.clone()).await?;
.entry(history_id.into())
.or_default();

let message_history_async = self.messages_history_async.clone();

tokio::spawn(async move {
let mut result = String::new();
while let Some(res) = rx.recv().await {
match res {
Ok(res) => {
if let Some(message) = res.message.clone() {
result += message.content.as_str();
}
}
Err(_) => {
break;
}
}
}

if let Some(message_history_async) = message_history_async {
message_history_async
.add_message(id.clone(), ChatMessage::assistant(result))
.await;
} else {
eprintln!("not using chat-history and stream features"); // this should not happen if the features are enabled
}
});

let s = stream! {
while let Some(res) = stream.next().await {
match res {
Ok(res) => {
if let Err(e) = tx.send(Ok(res.clone())) {
eprintln!("Failed to send response: {}", e);
};
yield Ok(res);
}
Err(_) => {
yield Err(());
}
}
}
};

Ok(Box::pin(s))
chat_history.clone()
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/generation/functions/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ impl crate::Ollama {
}

fn has_system_prompt_history(&mut self) -> bool {
return self.get_messages_history("default".to_string()).is_some();
self.get_messages_history("default").is_some()
}

#[cfg(feature = "chat-history")]
Expand Down
Loading
Loading