Skip to content

Commit

Permalink
Merge pull request #56 from ushinnary/store_chat_history_update
Browse files Browse the repository at this point in the history
Store chat history with stream update
  • Loading branch information
pepperoni21 authored Jul 11, 2024
2 parents 953232a + 9ad1a09 commit c5f6928
Show file tree
Hide file tree
Showing 12 changed files with 332 additions and 329 deletions.
8 changes: 4 additions & 4 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,4 @@ function-calling = ["scraper", "text-splitter", "regex", "chat-history"]
tokio = { version = "1", features = ["full"] }
ollama-rs = { path = ".", features = ["stream", "chat-history", "function-calling"] }
base64 = "0.22.0"

43 changes: 43 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,49 @@ if let Ok(res) = res {

**OUTPUTS:** _1. Sun emits white sunlight: The sun consists primarily ..._

### Chat mode
Description: _Every message sent and received will be stored in the library's history._
_Each time you want to use the history, you have to provide an ID for the chat._
_It can be unique for each user or the same every time, depending on your needs._

Example with history:
```rust
let model = "llama2:latest".to_string();
let prompt = "Why is the sky blue?".to_string();
let history_id = "USER_ID_OR_WHATEVER";

let res = ollama
.send_chat_messages_with_history(
ChatMessageRequest::new(
model,
vec![ChatMessage::user(prompt)], // <- You should provide only one message
),
        history_id // <- This entry stores the whole history for us
).await;

if let Ok(res) = res {
println!("{}", res.response);
}
```

Getting history for some ID:
```rust
let history_id = "USER_ID_OR_WHATEVER";
let history = ollama.get_message_history(history_id); // <- Option<Vec<ChatMessage>>
// Act
```

Clear the history when we no longer need it:
```rust
// Clear history for an ID
let history_id = "USER_ID_OR_WHATEVER";
ollama.clear_messages_for_id(history_id);
// Clear history for all chats
ollama.clear_all_messages();
```

_Check chat with history examples for [default](https://github.com/pepperoni21/ollama-rs/blob/master/examples/chat_with_history.rs) and [stream](https://github.com/pepperoni21/ollama-rs/blob/master/examples/chat_with_history_stream.rs)_

### List local models

```rust
Expand Down
4 changes: 2 additions & 2 deletions examples/chat_with_history.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
let result = ollama
.send_chat_messages_with_history(
ChatMessageRequest::new("llama2:latest".to_string(), vec![user_message]),
"default".to_string(),
"default",
)
.await?;

Expand All @@ -37,7 +37,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
}

// Display whole history of messages
dbg!(&ollama.get_messages_history("default".to_string()));
dbg!(&ollama.get_messages_history("default"));

Ok(())
}
25 changes: 9 additions & 16 deletions examples/chat_with_history_stream.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,16 @@
use ollama_rs::{
generation::chat::{request::ChatMessageRequest, ChatMessage},
generation::chat::{request::ChatMessageRequest, ChatMessage, ChatMessageResponseStream},
Ollama,
};
use tokio::io::{stdout, AsyncWriteExt};
use tokio_stream::StreamExt;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
let mut ollama = Ollama::new_default_with_history_async(30);
let mut ollama = Ollama::new_default_with_history(30);

let mut stdout = stdout();

let chat_id = "default".to_string();

loop {
stdout.write_all(b"\n> ").await?;
stdout.flush().await?;
Expand All @@ -25,12 +23,13 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
break;
}

let user_message = ChatMessage::user(input.to_string());

let mut stream = ollama
let mut stream: ChatMessageResponseStream = ollama
.send_chat_messages_with_history_stream(
ChatMessageRequest::new("llama2:latest".to_string(), vec![user_message]),
chat_id.clone(),
ChatMessageRequest::new(
"llama2:latest".to_string(),
vec![ChatMessage::user(input.to_string())],
),
"user".to_string(),
)
.await?;

Expand All @@ -44,14 +43,8 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
response += assistant_message.content.as_str();
}
}
dbg!(&ollama.get_messages_history("user"));
}

// Display whole history of messages
dbg!(
&ollama
.get_messages_history_async("default".to_string())
.await
);

Ok(())
}
186 changes: 80 additions & 106 deletions src/generation/chat/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@ use request::ChatMessageRequest;

#[cfg(feature = "chat-history")]
use crate::history::MessagesHistory;
#[cfg(all(feature = "chat-history", feature = "stream"))]
use crate::history_async::MessagesHistoryAsync;

#[cfg(feature = "stream")]
/// A stream of `ChatMessageResponse` objects
Expand Down Expand Up @@ -99,15 +97,70 @@ impl Ollama {

#[cfg(feature = "chat-history")]
impl Ollama {
#[cfg(feature = "stream")]
pub async fn send_chat_messages_with_history_stream(
&mut self,
mut request: ChatMessageRequest,
history_id: impl ToString,
) -> crate::error::Result<ChatMessageResponseStream> {
use async_stream::stream;
use tokio_stream::StreamExt;
let id_copy = history_id.to_string().clone();

let mut current_chat_messages = self.get_chat_messages_by_id(id_copy.clone());

if let Some(message) = request.messages.first() {
current_chat_messages.push(message.clone());
}

// The request is modified to include the current chat messages
request.messages.clone_from(&current_chat_messages);
request.stream = true;

let mut resp_stream: ChatMessageResponseStream =
self.send_chat_messages_stream(request.clone()).await?;

let messages_history = self.messages_history.clone();

let s = stream! {
let mut result = String::new();

while let Some(item) = resp_stream.try_next().await.unwrap() {
let msg_part = item.clone().message.unwrap().content;

if item.done {
if let Some(history) = messages_history.clone() {
let mut inner = history.write().unwrap();
// Message we sent to AI
if let Some(message) = request.messages.last() {
inner.add_message(id_copy.clone(), message.clone());
}

// AI's response
inner.add_message(id_copy.clone(), ChatMessage::assistant(result.clone()));
}
} else {
result.push_str(&msg_part);
}

yield Ok(item);
}
};

Ok(Box::pin(s))
}

/// Chat message generation
/// Returns a `ChatMessageResponse` object
/// Manages the history of messages for the given `id`
pub async fn send_chat_messages_with_history(
&mut self,
mut request: ChatMessageRequest,
id: String,
history_id: impl ToString,
) -> crate::error::Result<ChatMessageResponse> {
let mut current_chat_messages = self.get_chat_messages_by_id(id.clone());
// The request is modified to include the current chat messages
let id_copy = history_id.to_string().clone();
let mut current_chat_messages = self.get_chat_messages_by_id(id_copy.clone());

if let Some(message) = request.messages.first() {
current_chat_messages.push(message.clone());
Expand All @@ -121,125 +174,46 @@ impl Ollama {
if let Ok(result) = result {
// Message we sent to AI
if let Some(message) = request.messages.last() {
self.store_chat_message_by_id(id.clone(), message.clone());
self.store_chat_message_by_id(id_copy.clone(), message.clone());
}
// AI's response store in the history
self.store_chat_message_by_id(id, result.message.clone().unwrap());
// Store AI's response in the history
self.store_chat_message_by_id(id_copy, result.message.clone().unwrap());

return Ok(result);
}

result
}

/// Helper function to get chat messages by id
fn get_chat_messages_by_id(&mut self, id: String) -> Vec<ChatMessage> {
let mut backup = MessagesHistory::default();

// Clone the current chat messages to avoid borrowing issues
// And not to add message to the history if the request fails
self.messages_history
.as_mut()
.unwrap_or(&mut backup)
.messages_by_id
.entry(id.clone())
.or_default()
.clone()
}

/// Helper function to store chat messages by id
fn store_chat_message_by_id(&mut self, id: String, message: ChatMessage) {
fn store_chat_message_by_id(&mut self, id: impl ToString, message: ChatMessage) {
if let Some(messages_history) = self.messages_history.as_mut() {
messages_history.add_message(id, message);
messages_history.write().unwrap().add_message(id, message);
}
}
}

#[cfg(all(feature = "chat-history", feature = "stream"))]
impl Ollama {
async fn get_chat_messages_by_id_async(&mut self, id: String) -> Vec<ChatMessage> {
    /// Returns the existing history for `history_id` so a new message can be
    /// appended to the request, without modifying the stored history itself.
    /// Used to prepare the message list for a request.
fn get_chat_messages_by_id(&mut self, history_id: impl ToString) -> Vec<ChatMessage> {
let chat_history = match self.messages_history.as_mut() {
Some(history) => history,
None => &mut {
let new_history =
std::sync::Arc::new(std::sync::RwLock::new(MessagesHistory::default()));
self.messages_history = Some(new_history);
self.messages_history.clone().unwrap()
},
};
// Clone the current chat messages to avoid borrowing issues
// And not to add message to the history if the request fails
self.messages_history_async
.as_mut()
.unwrap_or(&mut MessagesHistoryAsync::default())
let mut history_instance = chat_history.write().unwrap();
let chat_history = history_instance
.messages_by_id
.lock()
.await
.entry(id.clone())
.or_default()
.clone()
}

pub async fn store_chat_message_by_id_async(&mut self, id: String, message: ChatMessage) {
if let Some(messages_history_async) = self.messages_history_async.as_mut() {
messages_history_async.add_message(id, message).await;
}
}

pub async fn send_chat_messages_with_history_stream(
&mut self,
mut request: ChatMessageRequest,
id: String,
) -> crate::error::Result<ChatMessageResponseStream> {
use tokio_stream::StreamExt;

let (tx, mut rx) =
tokio::sync::mpsc::unbounded_channel::<Result<ChatMessageResponse, ()>>(); // create a channel for sending and receiving messages

let mut current_chat_messages = self.get_chat_messages_by_id_async(id.clone()).await;

if let Some(messaeg) = request.messages.first() {
current_chat_messages.push(messaeg.clone());
}

request.messages.clone_from(&current_chat_messages);

let mut stream = self.send_chat_messages_stream(request.clone()).await?;
.entry(history_id.to_string())
.or_default();

let message_history_async = self.messages_history_async.clone();

tokio::spawn(async move {
let mut result = String::new();
while let Some(res) = rx.recv().await {
match res {
Ok(res) => {
if let Some(message) = res.message.clone() {
result += message.content.as_str();
}
}
Err(_) => {
break;
}
}
}

if let Some(message_history_async) = message_history_async {
message_history_async
.add_message(id.clone(), ChatMessage::assistant(result))
.await;
} else {
eprintln!("not using chat-history and stream features"); // this should not happen if the features are enabled
}
});

let s = stream! {
while let Some(res) = stream.next().await {
match res {
Ok(res) => {
if let Err(e) = tx.send(Ok(res.clone())) {
eprintln!("Failed to send response: {}", e);
};
yield Ok(res);
}
Err(_) => {
yield Err(());
}
}
}
};

Ok(Box::pin(s))
chat_history.clone()
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/generation/functions/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ impl crate::Ollama {
}

fn has_system_prompt_history(&mut self) -> bool {
return self.get_messages_history("default".to_string()).is_some();
self.get_messages_history("default").is_some()
}

#[cfg(feature = "chat-history")]
Expand Down
Loading

0 comments on commit c5f6928

Please sign in to comment.