From dd728d04eb91aeeb238c68d769015c80f367cffc Mon Sep 17 00:00:00 2001 From: Ushinnary Date: Sun, 5 May 2024 07:00:39 +0200 Subject: [PATCH 1/9] some rework for chat history --- Cargo.lock | 8 ++-- examples/chat_with_history.rs | 4 +- src/generation/chat/mod.rs | 67 +++++++++++++++++++------------- src/history.rs | 54 ++++++++++++++++--------- tests/chat_history_management.rs | 17 ++++---- tests/send_chat_messages.rs | 24 ++++++------ 6 files changed, 102 insertions(+), 72 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 9f61ef8..fe56832 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -799,18 +799,18 @@ dependencies = [ [[package]] name = "serde" -version = "1.0.198" +version = "1.0.199" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9846a40c979031340571da2545a4e5b7c4163bdae79b301d5f86d03979451fcc" +checksum = "0c9f6e76df036c77cd94996771fb40db98187f096dd0b9af39c6c6e452ba966a" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.198" +version = "1.0.199" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e88edab869b01783ba905e7d0153f9fc1a6505a96e4ad3018011eedb838566d9" +checksum = "11bd257a6541e141e42ca6d24ae26f7714887b47e89aa739099104c7e4d3b7fc" dependencies = [ "proc-macro2", "quote", diff --git a/examples/chat_with_history.rs b/examples/chat_with_history.rs index 6e74d24..69420e9 100644 --- a/examples/chat_with_history.rs +++ b/examples/chat_with_history.rs @@ -27,7 +27,7 @@ async fn main() -> Result<(), Box> { let result = ollama .send_chat_messages_with_history( ChatMessageRequest::new("llama2:latest".to_string(), vec![user_message]), - "default".to_string(), + "default", ) .await?; @@ -37,7 +37,7 @@ async fn main() -> Result<(), Box> { } // Display whole history of messages - dbg!(&ollama.get_messages_history("default".to_string())); + dbg!(&ollama.get_messages_history("default")); Ok(()) } diff --git a/src/generation/chat/mod.rs b/src/generation/chat/mod.rs index 
9b8458e..2ecfbd5 100644 --- a/src/generation/chat/mod.rs +++ b/src/generation/chat/mod.rs @@ -104,52 +104,65 @@ impl Ollama { pub async fn send_chat_messages_with_history( &mut self, mut request: ChatMessageRequest, - id: String, + history_id: &str, ) -> crate::error::Result { - let mut current_chat_messages = self.get_chat_messages_by_id(id.clone()); - - if let Some(message) = request.messages.first() { - current_chat_messages.push(message.clone()); - } - // The request is modified to include the current chat messages - request.messages = current_chat_messages.clone(); + request.messages = self.get_prefill_messages(history_id, request.messages.clone()); - let result = self.send_chat_messages(request.clone()).await; + let result = self.send_chat_messages(request).await; - if let Ok(result) = result { - // Message we sent to AI - if let Some(message) = request.messages.last() { - self.store_chat_message_by_id(id.clone(), message.clone()); - } - // AI's response store in the history - self.store_chat_message_by_id(id, result.message.clone().unwrap()); + match result { + Ok(result) => { + // Store AI's response in the history + self.store_chat_message_by_id(history_id, result.message.clone().unwrap()); - return Ok(result); + return Ok(result); + } + Err(_) => { + self.remove_history_last_message(history_id); + } } result } - /// Helper function to get chat messages by id - fn get_chat_messages_by_id(&mut self, id: String) -> Vec { + /// Helper function to store chat messages by id + fn store_chat_message_by_id(&mut self, id: &str, message: ChatMessage) { + if let Some(messages_history) = self.messages_history.as_mut() { + messages_history.add_message(id, message); + } + } + + /// Let get existing history with a new message in it + /// Without impact for existing history + /// Used to prepare history for request + fn get_prefill_messages( + &mut self, + history_id: &str, + request_messages: Vec, + ) -> Vec { let mut backup = MessagesHistory::default(); // Clone the 
current chat messages to avoid borrowing issues // And not to add message to the history if the request fails - self.messages_history + let current_chat_messages = self + .messages_history .as_mut() .unwrap_or(&mut backup) .messages_by_id - .entry(id.clone()) - .or_default() - .clone() + .entry(history_id.to_string()) + .or_default(); + + if let Some(message) = request_messages.first() { + current_chat_messages.push(message.clone()); + } + + current_chat_messages.clone() } - /// Helper function to store chat messages by id - fn store_chat_message_by_id(&mut self, id: String, message: ChatMessage) { - if let Some(messages_history) = self.messages_history.as_mut() { - messages_history.add_message(id, message); + fn remove_history_last_message(&mut self, history_id: &str) { + if let Some(history) = self.messages_history.as_mut() { + history.pop_last_message_for_id(history_id); } } } diff --git a/src/history.rs b/src/history.rs index 3525822..f27debf 100644 --- a/src/history.rs +++ b/src/history.rs @@ -11,7 +11,9 @@ pub struct MessagesHistory { pub(crate) messages_number_limit: u16, } +/// Store for messages history impl MessagesHistory { + /// Generate a MessagesHistory pub fn new(messages_number_limit: u16) -> Self { Self { messages_by_id: HashMap::new(), @@ -19,8 +21,9 @@ impl MessagesHistory { } } - pub fn add_message(&mut self, entry_id: String, message: ChatMessage) { - let messages = self.messages_by_id.entry(entry_id).or_default(); + /// Add message for entry even no history exists for an entry + pub fn add_message(&mut self, entry_id: &str, message: ChatMessage) { + let messages = self.messages_by_id.entry(entry_id.to_string()).or_default(); // Replacing the oldest message if the limit is reached // The oldest message is the first one, unless it's a system message @@ -40,13 +43,27 @@ impl MessagesHistory { } } + /// Get Option with list of ChatMessage pub fn get_messages(&self, entry_id: &str) -> Option<&Vec> { self.messages_by_id.get(entry_id) } - pub fn 
clear_messages(&mut self, entry_id: &str) { + /// Clear history for an entry + pub fn clear_messages_for_id(&mut self, entry_id: &str) { self.messages_by_id.remove(entry_id); } + + /// Remove last message added in history + pub fn pop_last_message_for_id(&mut self, entry_id: &str) { + if let Some(messages) = self.messages_by_id.get_mut(entry_id) { + messages.pop(); + } + } + + /// Remove a whole history + pub fn clear_all_messages(&mut self) { + self.messages_by_id = HashMap::new(); + } } impl Ollama { @@ -96,33 +113,32 @@ impl Ollama { } /// Add AI's message to a history - pub fn add_assistant_response(&mut self, entry_id: String, message: String) { - if let Some(messages_history) = self.messages_history.as_mut() { - messages_history.add_message(entry_id, ChatMessage::assistant(message)); - } + pub fn add_assistant_response(&mut self, entry_id: &str, message: String) { + self.add_history_message(entry_id, ChatMessage::assistant(message)); } /// Add user's message to a history - pub fn add_user_response(&mut self, entry_id: String, message: String) { - if let Some(messages_history) = self.messages_history.as_mut() { - messages_history.add_message(entry_id, ChatMessage::user(message)); - } + pub fn add_user_response(&mut self, entry_id: &str, message: String) { + self.add_history_message(entry_id, ChatMessage::user(message)); } /// Set system prompt for chat history - pub fn set_system_response(&mut self, entry_id: String, message: String) { + pub fn set_system_response(&mut self, entry_id: &str, message: String) { + self.add_history_message(entry_id, ChatMessage::system(message)); + } + + /// Helper for message add to history + fn add_history_message(&mut self, entry_id: &str, message: ChatMessage) { if let Some(messages_history) = self.messages_history.as_mut() { - messages_history.add_message(entry_id, ChatMessage::system(message)); + messages_history.add_message(entry_id, message); } } /// For tests purpose /// Getting list of messages in a history - pub fn 
get_messages_history(&mut self, entry_id: String) -> Option<&Vec> { - if let Some(messages_history) = self.messages_history.as_mut() { - messages_history.messages_by_id.get(&entry_id) - } else { - None - } + pub fn get_messages_history(&mut self, entry_id: &str) -> Option> { + self.messages_history + .clone() + .map(|message_history| message_history.get_messages(entry_id).cloned())? } } diff --git a/tests/chat_history_management.rs b/tests/chat_history_management.rs index 6bbd3e5..9c727df 100644 --- a/tests/chat_history_management.rs +++ b/tests/chat_history_management.rs @@ -5,18 +5,17 @@ fn test_chat_history_saved_as_should() { let mut ollama = Ollama::new_default_with_history(30); let chat_id = "default".to_string(); - ollama.add_user_response(chat_id.clone(), "Hello".to_string()); - ollama.add_assistant_response(chat_id.clone(), "Hi".to_string()); + ollama.add_user_response(&chat_id, "Hello".to_string()); + ollama.add_assistant_response(&chat_id, "Hi".to_string()); - ollama.add_user_response(chat_id.clone(), "Tell me 'hi' again".to_string()); - ollama.add_assistant_response(chat_id.clone(), "Hi again".to_string()); + ollama.add_user_response(&chat_id, "Tell me 'hi' again".to_string()); + ollama.add_assistant_response(&chat_id, "Hi again".to_string()); - assert_eq!( - ollama.get_messages_history(chat_id.clone()).unwrap().len(), - 4 - ); + let history = ollama.get_messages_history(&chat_id).unwrap(); - let last = ollama.get_messages_history(chat_id).unwrap().last(); + assert_eq!(history.len(), 4); + + let last = history.last(); assert!(last.is_some()); assert_eq!(last.unwrap().content, "Hi again".to_string()); } diff --git a/tests/send_chat_messages.rs b/tests/send_chat_messages.rs index 0e3cfd2..bc69f05 100644 --- a/tests/send_chat_messages.rs +++ b/tests/send_chat_messages.rs @@ -64,7 +64,7 @@ async fn test_send_chat_messages_with_history() { let res = ollama .send_chat_messages_with_history( ChatMessageRequest::new("llama2:latest".to_string(), 
messages.clone()), - id.clone(), + &id, ) .await .unwrap(); @@ -72,22 +72,24 @@ async fn test_send_chat_messages_with_history() { dbg!(&res); assert!(res.done); // Should have user's message as well as AI's response - assert_eq!(ollama.get_messages_history(id.clone()).unwrap().len(), 2); + assert_eq!(ollama.get_messages_history(&id).unwrap().len(), 2); let res = ollama .send_chat_messages_with_history( ChatMessageRequest::new("llama2:latest".to_string(), second_message.clone()), - id.clone(), + &id, ) .await .unwrap(); dbg!(&res); assert!(res.done); + + let history = ollama.get_messages_history(&id).unwrap(); // Should now have 2 user messages as well as AI's responses - assert_eq!(ollama.get_messages_history(id.clone()).unwrap().len(), 4); + assert_eq!(history.len(), 4); - let second_user_message_in_history = ollama.get_messages_history(id.clone()).unwrap().get(2); + let second_user_message_in_history = history.get(2); assert!(second_user_message_in_history.is_some()); assert_eq!( @@ -106,7 +108,7 @@ async fn test_send_chat_messages_remove_old_history_with_limit_less_than_min() { let res = ollama .send_chat_messages_with_history( ChatMessageRequest::new("llama2:latest".to_string(), messages.clone()), - id.clone(), + &id, ) .await .unwrap(); @@ -114,7 +116,7 @@ async fn test_send_chat_messages_remove_old_history_with_limit_less_than_min() { dbg!(&res); assert!(res.done); // Minimal history length is 2 - assert_eq!(ollama.get_messages_history(id.clone()).unwrap().len(), 2); + assert_eq!(ollama.get_messages_history(&id).unwrap().len(), 2); } #[tokio::test] @@ -126,7 +128,7 @@ async fn test_send_chat_messages_remove_old_history() { let res = ollama .send_chat_messages_with_history( ChatMessageRequest::new("llama2:latest".to_string(), messages.clone()), - id.clone(), + &id, ) .await .unwrap(); @@ -135,13 +137,13 @@ async fn test_send_chat_messages_remove_old_history() { assert!(res.done); - assert_eq!(ollama.get_messages_history(id.clone()).unwrap().len(), 2); + 
assert_eq!(ollama.get_messages_history(&id).unwrap().len(), 2); // Duplicate to check that we have 3 messages stored let res = ollama .send_chat_messages_with_history( ChatMessageRequest::new("llama2:latest".to_string(), messages), - id.clone(), + &id, ) .await .unwrap(); @@ -150,7 +152,7 @@ async fn test_send_chat_messages_remove_old_history() { assert!(res.done); - assert_eq!(ollama.get_messages_history(id.clone()).unwrap().len(), 3); + assert_eq!(ollama.get_messages_history(&id).unwrap().len(), 3); } const IMAGE_URL: &str = "https://images.pexels.com/photos/1054655/pexels-photo-1054655.jpeg"; From f40944bcbfdbec1121d03e0f8ab8c2bc92847545 Mon Sep 17 00:00:00 2001 From: andthattoo Date: Tue, 14 May 2024 17:34:11 +0300 Subject: [PATCH 2/9] function calling abilities, multiple tools and base traits for various function calling standards. --- Cargo.lock | 23 ---- Cargo.toml | 1 + src/functions/mod.rs | 53 +++++++++ src/functions/pipelines/mod.rs | 2 + src/functions/pipelines/nous_hermes/mod.rs | 2 + .../pipelines/nous_hermes/parsers.rs | 0 .../pipelines/nous_hermes/prompts.rs | 0 src/functions/pipelines/openai/mod.rs | 92 +++++++++++++++ src/functions/pipelines/openai/parsers.rs | 23 ++++ src/functions/pipelines/openai/prompts.rs | 29 +++++ src/functions/pipelines/openai/request.rs | 66 +++++++++++ src/functions/tools/mod.rs | 59 ++++++++++ src/functions/tools/scraper.rs | 69 +++++++++++ src/functions/tools/search_ddg.rs | 109 ++++++++++++++++++ src/functions/tools/weather.rs | 33 ++++++ src/lib.rs | 2 + 16 files changed, 540 insertions(+), 23 deletions(-) create mode 100644 src/functions/mod.rs create mode 100644 src/functions/pipelines/mod.rs create mode 100644 src/functions/pipelines/nous_hermes/mod.rs create mode 100644 src/functions/pipelines/nous_hermes/parsers.rs create mode 100644 src/functions/pipelines/nous_hermes/prompts.rs create mode 100644 src/functions/pipelines/openai/mod.rs create mode 100644 src/functions/pipelines/openai/parsers.rs create mode 
100644 src/functions/pipelines/openai/prompts.rs create mode 100644 src/functions/pipelines/openai/request.rs create mode 100644 src/functions/tools/mod.rs create mode 100644 src/functions/tools/scraper.rs create mode 100644 src/functions/tools/search_ddg.rs create mode 100644 src/functions/tools/weather.rs diff --git a/Cargo.lock b/Cargo.lock index 15d1071..3cf7458 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -39,28 +39,6 @@ dependencies = [ "memchr", ] -[[package]] -name = "async-stream" -version = "0.3.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cd56dd203fef61ac097dd65721a419ddccb106b2d2b70ba60a6b529f03961a51" -dependencies = [ - "async-stream-impl", - "futures-core", - "pin-project-lite", -] - -[[package]] -name = "async-stream-impl" -version = "0.3.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "16e62a023e7c117e27523144c5d2459f4397fcc3cab0085af8e2224f643a0193" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.59", -] - [[package]] name = "async-trait" version = "0.1.80" @@ -707,7 +685,6 @@ dependencies = [ name = "ollama-rs" version = "0.2.0" dependencies = [ - "async-stream", "async-trait", "base64", "log", diff --git a/Cargo.toml b/Cargo.toml index 910eecc..7fa08cc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -34,3 +34,4 @@ function-calling = ["scraper", "text-splitter", "regex", "chat-history"] tokio = { version = "1", features = ["full"] } ollama-rs = { path = ".", features = ["stream", "chat-history", "function-calling"] } base64 = "0.22.0" + diff --git a/src/functions/mod.rs b/src/functions/mod.rs new file mode 100644 index 0000000..d40df22 --- /dev/null +++ b/src/functions/mod.rs @@ -0,0 +1,53 @@ +pub mod tools; +pub mod pipelines; + +pub use tools::WeatherTool; +pub use tools::Scraper; +pub use tools::DDGSearcher; + +use async_trait::async_trait; +use serde_json::{Value, json}; +use std::error::Error; +use crate::generation::chat::ChatMessage; + + +pub trait FunctionCallBase: 
Send + Sync { + fn name(&self) -> String; +} + +#[async_trait] +pub trait FunctionCall: FunctionCallBase { + async fn call(&self, params: Value) -> Result>; +} + +pub struct DefaultFunctionCall {} + +impl FunctionCallBase for DefaultFunctionCall { + fn name(&self) -> String { + "default_function".to_string() + } +} + + +pub fn convert_to_ollama_tool(tool: &dyn crate::generation::functions::tools::Tool) -> Value { + let schema = tool.parameters(); + json!({ + "name": tool.name(), + "properties": schema["properties"], + "required": schema["required"] + }) +} + + +pub fn parse_response(message: &ChatMessage) -> Result { + let content = &message.content; + let value: Value = serde_json::from_str(content).map_err(|e| e.to_string())?; + + if let Some(function_call) = value.get("function_call") { + if let Some(arguments) = function_call.get("arguments") { + return Ok(arguments.to_string()); + } + return Err("`arguments` missing from `function_call`".to_string()); + } + Err("`function_call` missing from `content`".to_string()) +} diff --git a/src/functions/pipelines/mod.rs b/src/functions/pipelines/mod.rs new file mode 100644 index 0000000..727e688 --- /dev/null +++ b/src/functions/pipelines/mod.rs @@ -0,0 +1,2 @@ +pub mod openai; +pub mod nous_hermes; \ No newline at end of file diff --git a/src/functions/pipelines/nous_hermes/mod.rs b/src/functions/pipelines/nous_hermes/mod.rs new file mode 100644 index 0000000..6bf9d61 --- /dev/null +++ b/src/functions/pipelines/nous_hermes/mod.rs @@ -0,0 +1,2 @@ +pub mod prompts; +pub mod parsers; \ No newline at end of file diff --git a/src/functions/pipelines/nous_hermes/parsers.rs b/src/functions/pipelines/nous_hermes/parsers.rs new file mode 100644 index 0000000..e69de29 diff --git a/src/functions/pipelines/nous_hermes/prompts.rs b/src/functions/pipelines/nous_hermes/prompts.rs new file mode 100644 index 0000000..e69de29 diff --git a/src/functions/pipelines/openai/mod.rs b/src/functions/pipelines/openai/mod.rs new file mode 100644 
index 0000000..0ef6f78 --- /dev/null +++ b/src/functions/pipelines/openai/mod.rs @@ -0,0 +1,92 @@ +pub mod prompts; +pub mod parsers; +pub mod request; + +pub use prompts::{DEFAULT_SYSTEM_TEMPLATE ,DEFAULT_RESPONSE_FUNCTION}; +pub use request::FunctionCallRequest; +pub use parsers::{generate_system_message, parse_response}; + +use std::sync::Arc; +use async_trait::async_trait; +use serde_json::{json, Value}; +use std::error::Error; +use crate::generation::functions::{FunctionCall, FunctionCallBase}; +use crate::generation::chat::{ChatMessage, ChatMessageResponse}; +use crate::generation::chat::request::{ChatMessageRequest}; +use crate::generation::functions::tools::Tool; +use crate::error::OllamaError; + + +pub struct OpenAIFunctionCall { + pub name: String, +} + +impl OpenAIFunctionCall { + pub fn new(name: &str) -> Self { + OpenAIFunctionCall { + name: name.to_string(), + } + } +} + +impl FunctionCallBase for OpenAIFunctionCall { + fn name(&self) -> String { + "openai".to_string() + } +} + +#[async_trait] +impl FunctionCall for OpenAIFunctionCall { + async fn call(&self, params: Value) -> Result> { + // Simulate a function call by returning a simple JSON value + Ok(json!({ "result": format!("Function {} called with params: {}", self.name, params) })) + } +} + + +impl crate::Ollama { + pub async fn function_call_with_history( + &self, + request: ChatMessageRequest, + tool: Arc, + ) -> Result { + let function_call = OpenAIFunctionCall::new(&tool.name()); + let params = tool.parameters(); + let result = function_call.call(params).await?; + Ok(ChatMessageResponse { + model: request.model_name, + created_at: "".to_string(), + message: Some(ChatMessage::assistant(result.to_string())), + done: true, + final_data: None, + }) + } + + pub async fn function_call( + &self, + request: ChatMessageRequest, + ) -> crate::error::Result { + let mut request = request; + request.stream = false; + + let url = format!("{}api/chat", self.url_str()); + let serialized = 
serde_json::to_string(&request).map_err(|e| e.to_string())?; + let res = self + .reqwest_client + .post(url) + .body(serialized) + .send() + .await + .map_err(|e| e.to_string())?; + + if !res.status().is_success() { + return Err(res.text().await.unwrap_or_else(|e| e.to_string()).into()); + } + + let bytes = res.bytes().await.map_err(|e| e.to_string())?; + let res = + serde_json::from_slice::(&bytes).map_err(|e| e.to_string())?; + + Ok(res) + } +} \ No newline at end of file diff --git a/src/functions/pipelines/openai/parsers.rs b/src/functions/pipelines/openai/parsers.rs new file mode 100644 index 0000000..980ed95 --- /dev/null +++ b/src/functions/pipelines/openai/parsers.rs @@ -0,0 +1,23 @@ +use crate::generation::chat::ChatMessage; +use serde_json::Value; +use crate::generation::functions::pipelines::openai::DEFAULT_SYSTEM_TEMPLATE; +use crate::generation::functions::tools::Tool; + +pub fn parse_response(message: &ChatMessage) -> Result { + let content = &message.content; + let value: Value = serde_json::from_str(content).map_err(|e| e.to_string())?; + + if let Some(function_call) = value.get("function_call") { + Ok(function_call.clone()) + } else { + Ok(value) + } +} + +pub fn generate_system_message(tools: &[&dyn Tool]) -> ChatMessage { + let tools_info: Vec = tools.iter().map(|tool| tool.parameters()).collect(); + let tools_json = serde_json::to_string(&tools_info).unwrap(); + let system_message_content = DEFAULT_SYSTEM_TEMPLATE.replace("{tools}", &tools_json); + ChatMessage::system(system_message_content) +} + diff --git a/src/functions/pipelines/openai/prompts.rs b/src/functions/pipelines/openai/prompts.rs new file mode 100644 index 0000000..94e0749 --- /dev/null +++ b/src/functions/pipelines/openai/prompts.rs @@ -0,0 +1,29 @@ +pub const DEFAULT_SYSTEM_TEMPLATE: &str = r#" +You have access to the following tools: + +{tools} + +You must always select one of the above tools and respond with only a JSON object matching the following schema: + +{ + "tool": , + 
"tool_input": +} +"#; + +pub const DEFAULT_RESPONSE_FUNCTION: &str = r#" +{ + "name": "__conversational_response", + "description": "Respond conversationally if no other tools should be called for a given query.", + "parameters": { + "type": "object", + "properties": { + "response": { + "type": "string", + "description": "Conversational response to the user." + } + }, + "required": ["response"] + } +} +"#; diff --git a/src/functions/pipelines/openai/request.rs b/src/functions/pipelines/openai/request.rs new file mode 100644 index 0000000..301b6ae --- /dev/null +++ b/src/functions/pipelines/openai/request.rs @@ -0,0 +1,66 @@ +use serde_json::Value; +use std::sync::Arc; +use crate::generation::chat::{ChatMessage, ChatMessageResponse}; +use crate::generation::chat::request::{ChatMessageRequest}; +use crate::generation::functions::pipelines::openai::DEFAULT_SYSTEM_TEMPLATE; +use crate::generation::functions::tools::Tool; +use crate::Ollama; +use crate::error::OllamaError; + +#[derive(Clone)] +pub struct FunctionCallRequest { + model_name: String, + tools: Vec>, +} + +impl FunctionCallRequest { + pub fn new(model_name: &str, tools: Vec>) -> Self { + FunctionCallRequest { + model_name: model_name.to_string(), + tools, + } + } + + pub async fn send(&self, ollama: &mut Ollama, input: &str) -> Result { + let system_message = self.get_system_message(); + ollama.send_chat_messages_with_history( + ChatMessageRequest::new(self.model_name.clone(), vec![system_message.clone()]), + "default".to_string(), + ).await?; + + let user_message = ChatMessage::user(input.to_string()); + + let result = ollama + .send_chat_messages_with_history( + ChatMessageRequest::new(self.model_name.clone(), vec![user_message]), + "default".to_string(), + ).await?; + + let response_content = result.message.clone().unwrap().content; + let response_value: Value = match serde_json::from_str(&response_content) { + Ok(value) => value, + Err(e) => return Err(OllamaError::from(e.to_string())), + }; + + if let 
Some(function_call) = response_value.get("function_call") { + if let Some(tool_name) = function_call.get("tool").and_then(Value::as_str) { + if let Some(tool) = self.tools.iter().find(|t| t.name() == tool_name) { + let result = ollama.function_call_with_history( + ChatMessageRequest::new(self.model_name.clone(), vec![ChatMessage::user(tool_name.to_string())]), + tool.clone(), + ).await?; + return Ok(result); + } + } + } + + Ok(result) + } + + pub fn get_system_message(&self) -> ChatMessage { + let tools_info: Vec = self.tools.iter().map(|tool| tool.parameters()).collect(); + let tools_json = serde_json::to_string(&tools_info).unwrap(); + let system_message_content = DEFAULT_SYSTEM_TEMPLATE.replace("{tools}", &tools_json); + ChatMessage::system(system_message_content) + } +} diff --git a/src/functions/tools/mod.rs b/src/functions/tools/mod.rs new file mode 100644 index 0000000..d824fc9 --- /dev/null +++ b/src/functions/tools/mod.rs @@ -0,0 +1,59 @@ +pub mod search_ddg; +pub mod weather; +pub mod scraper; + +pub use self::weather::WeatherTool; +pub use self::scraper::Scraper; +pub use self::search_ddg::DDGSearcher; + +use async_trait::async_trait; +use serde_json::{json, Value}; +use std::error::Error; +use std::string::String; + +#[async_trait] +pub trait Tool: Send + Sync { + /// Returns the name of the tool. + fn name(&self) -> String; + + /// Provides a description of what the tool does and when to use it. + fn description(&self) -> String; + + /// This are the parameters for OpenAI-like function call. + fn parameters(&self) -> Value { + json!({ + "type": "object", + "properties": { + "input": { + "type": "string", + "description": self.description() + } + }, + "required": ["input"] + }) + } + + /// Processes an input string and executes the tool's functionality, returning a `Result`. + async fn call(&self, input: &str) -> Result> { + let input = self.parse_input(input).await; + self.run(input).await + } + + /// Executes the core functionality of the tool. 
+ async fn run(&self, input: Value) -> Result>; + + /// Parses the input string. + async fn parse_input(&self, input: &str) -> Value { + log::info!("Using default implementation: {}", input); + match serde_json::from_str::(input) { + Ok(input) => { + if input["input"].is_string() { + Value::String(input["input"].as_str().unwrap().to_string()) + } else { + Value::String(input.to_string()) + } + } + Err(_) => Value::String(input.to_string()), + } + } +} \ No newline at end of file diff --git a/src/functions/tools/scraper.rs b/src/functions/tools/scraper.rs new file mode 100644 index 0000000..2e8b4c0 --- /dev/null +++ b/src/functions/tools/scraper.rs @@ -0,0 +1,69 @@ +use reqwest::Client; +use scraper::{Html, Selector}; +use std::env; +use text_splitter::TextSplitter; + +use std::error::Error; +use serde_json::{Value, json}; +use crate::generation::functions::tools::Tool; +use async_trait::async_trait; + +pub struct Scraper {} + + +#[async_trait] +impl Tool for Scraper { + fn name(&self) -> String { + "Website Scraper".to_string() + } + + fn description(&self) -> String { + "Scrapes text content from websites and splits it into manageable chunks.".to_string() + } + + fn parameters(&self) -> Value { + json!({ + "type": "object", + "properties": { + "website": { + "type": "string", + "description": "The URL of the website to scrape" + } + }, + "required": ["website"] + }) + } + + async fn run(&self, input: Value) -> Result> { + let website = input["website"].as_str().ok_or("Website URL is required")?; + let browserless_token = env::var("BROWSERLESS_TOKEN").expect("BROWSERLESS_TOKEN must be set"); + let url = format!("http://0.0.0.0:3000/content?token={}", browserless_token); + let payload = json!({ + "url": website + }); + let client = Client::new(); + let response = client + .post(&url) + .header("cache-control", "no-cache") + .header("content-type", "application/json") + .json(&payload) + .send() + .await?; + + let response_text = response.text().await?; + let 
document = Html::parse_document(&response_text); + let selector = Selector::parse("p, h1, h2, h3, h4, h5, h6").unwrap(); + let elements: Vec = document + .select(&selector) + .map(|el| el.text().collect::()) + .collect(); + let body = elements.join(" "); + + let splitter = TextSplitter::new(1000); + let chunks = splitter.chunks(&body); + let sentences: Vec = chunks.map(|s| s.to_string()).collect(); + let sentences = sentences.join("\n \n"); + Ok(sentences) + } +} + diff --git a/src/functions/tools/search_ddg.rs b/src/functions/tools/search_ddg.rs new file mode 100644 index 0000000..9407300 --- /dev/null +++ b/src/functions/tools/search_ddg.rs @@ -0,0 +1,109 @@ +use reqwest; + +use url::Url; + +use scraper::{Html, Selector}; +use std::error::Error; + +use crate::generation::functions::tools::Tool; +use serde::{Deserialize, Serialize}; +use serde_json::{Value, json}; +use async_trait::async_trait; + + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SearchResult { + title: String, + link: String, + snippet: String, +} + +impl SearchResult { + fn extract_domain(url: &str) -> Option { + Url::parse(url).ok()?.domain().map(|d| d.to_string()) + } +} + +pub struct DDGSearcher { + pub client: reqwest::Client, + pub base_url: String, +} + +impl DDGSearcher { + pub fn new() -> Self { + DDGSearcher { + client: reqwest::Client::new(), + base_url: "https://duckduckgo.com".to_string(), + } + } + + pub async fn search(&self, query: &str) -> Result, Box> { + let url = format!("{}/html/?q={}", self.base_url, query); + let resp = self.client.get(&url).send().await?; + let body = resp.text().await?; + let document = Html::parse_document(&body); + + let result_selector = Selector::parse(".web-result").unwrap(); + let result_title_selector = Selector::parse(".result__a").unwrap(); + let result_url_selector = Selector::parse(".result__url").unwrap(); + let result_snippet_selector = Selector::parse(".result__snippet").unwrap(); + + let results = 
document.select(&result_selector).map(|result| { + + let title = result.select(&result_title_selector).next().unwrap().text().collect::>().join(""); + let link = result.select(&result_url_selector).next().unwrap().text().collect::>().join("").trim().to_string(); + let snippet = result.select(&result_snippet_selector).next().unwrap().text().collect::>().join(""); + + SearchResult { + title, + link, + //url: String::from(url.value().attr("href").unwrap()), + snippet, + } + }).collect::>(); + + Ok(results) + } +} + +#[async_trait] +impl Tool for DDGSearcher { + fn name(&self) -> String { + "DDG Searcher".to_string() + } + + fn description(&self) -> String { + "Searches the web using DuckDuckGo's HTML interface.".to_string() + } + + fn parameters(&self) -> Value { + json!({ + "description": "This tool lets you search the web using DuckDuckGo. The input should be a search query.", + "type": "object", + "properties": { + "query": { + "description": "The search query to send to DuckDuckGo", + "type": "string" + } + }, + "required": ["query"] + }) + } + + async fn call(&self, input: &str) -> Result> { + let input_value = self.parse_input(input).await; + self.run(input_value).await + } + + async fn run(&self, input: Value) -> Result> { + let query = input.as_str().ok_or("Input should be a string")?; + let results = self.search(query).await?; + let results_json = serde_json::to_string(&results)?; + Ok(results_json) + } + + async fn parse_input(&self, input: &str) -> Value { + // Use default implementation provided in the Tool trait + Tool::parse_input(self, input).await + } +} \ No newline at end of file diff --git a/src/functions/tools/weather.rs b/src/functions/tools/weather.rs new file mode 100644 index 0000000..2671677 --- /dev/null +++ b/src/functions/tools/weather.rs @@ -0,0 +1,33 @@ +use async_trait::async_trait; +use serde_json::{json, Value}; +use std::error::Error; +use crate::generation::functions::tools::Tool; + +pub struct WeatherTool; + +#[async_trait] +impl 
Tool for WeatherTool { + fn name(&self) -> String { + "WeatherTool".to_string() + } + + fn description(&self) -> String { + "Get the current weather in a given location.".to_string() + } + + async fn run(&self, input: Value) -> Result> { + let location = input.as_str().ok_or("Input should be a string")?; + let unit = "fahrenheit"; // Default unit + let result = if location.to_lowercase().contains("tokyo") { + json!({"location": "Tokyo", "temperature": "10", "unit": unit}) + } else if location.to_lowercase().contains("san francisco") { + json!({"location": "San Francisco", "temperature": "72", "unit": unit}) + } else if location.to_lowercase().contains("paris") { + json!({"location": "Paris", "temperature": "22", "unit": unit}) + } else { + json!({"location": location, "temperature": "unknown"}) + }; + + Ok(result.to_string()) + } +} diff --git a/src/lib.rs b/src/lib.rs index 633266b..97d63d9 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -5,6 +5,8 @@ pub mod history; #[cfg(all(feature = "chat-history", feature = "stream"))] pub mod history_async; pub mod models; +#[cfg(feature = "function-calling")] +pub mod functions; use url::Url; From ccdf4b6376dd83d25dd9a2aa6e4d81b74cef5552 Mon Sep 17 00:00:00 2001 From: Ushinnary Date: Thu, 27 Jun 2024 23:05:22 +0200 Subject: [PATCH 3/9] Added stream option with chat history Using code suggested by ZBcheng --- Cargo.lock | 23 ++++ examples/chat_with_history_stream.rs | 25 ++-- src/functions/mod.rs | 53 --------- src/functions/pipelines/mod.rs | 2 - src/functions/pipelines/nous_hermes/mod.rs | 2 - .../pipelines/nous_hermes/parsers.rs | 0 .../pipelines/nous_hermes/prompts.rs | 0 src/functions/pipelines/openai/mod.rs | 92 --------------- src/functions/pipelines/openai/parsers.rs | 23 ---- src/functions/pipelines/openai/prompts.rs | 29 ----- src/functions/pipelines/openai/request.rs | 66 ----------- src/functions/tools/mod.rs | 59 ---------- src/functions/tools/scraper.rs | 69 ----------- src/functions/tools/search_ddg.rs | 109 
------------------ src/functions/tools/weather.rs | 33 ------ src/generation/chat/mod.rs | 79 +++++++++++-- src/generation/functions/mod.rs | 2 +- src/history.rs | 27 +++-- src/lib.rs | 7 +- 19 files changed, 124 insertions(+), 576 deletions(-) delete mode 100644 src/functions/mod.rs delete mode 100644 src/functions/pipelines/mod.rs delete mode 100644 src/functions/pipelines/nous_hermes/mod.rs delete mode 100644 src/functions/pipelines/nous_hermes/parsers.rs delete mode 100644 src/functions/pipelines/nous_hermes/prompts.rs delete mode 100644 src/functions/pipelines/openai/mod.rs delete mode 100644 src/functions/pipelines/openai/parsers.rs delete mode 100644 src/functions/pipelines/openai/prompts.rs delete mode 100644 src/functions/pipelines/openai/request.rs delete mode 100644 src/functions/tools/mod.rs delete mode 100644 src/functions/tools/scraper.rs delete mode 100644 src/functions/tools/search_ddg.rs delete mode 100644 src/functions/tools/weather.rs diff --git a/Cargo.lock b/Cargo.lock index 3cf7458..15d1071 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -39,6 +39,28 @@ dependencies = [ "memchr", ] +[[package]] +name = "async-stream" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cd56dd203fef61ac097dd65721a419ddccb106b2d2b70ba60a6b529f03961a51" +dependencies = [ + "async-stream-impl", + "futures-core", + "pin-project-lite", +] + +[[package]] +name = "async-stream-impl" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "16e62a023e7c117e27523144c5d2459f4397fcc3cab0085af8e2224f643a0193" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.59", +] + [[package]] name = "async-trait" version = "0.1.80" @@ -685,6 +707,7 @@ dependencies = [ name = "ollama-rs" version = "0.2.0" dependencies = [ + "async-stream", "async-trait", "base64", "log", diff --git a/examples/chat_with_history_stream.rs b/examples/chat_with_history_stream.rs index 7b00be8..2d54a42 100644 --- 
a/examples/chat_with_history_stream.rs +++ b/examples/chat_with_history_stream.rs @@ -1,5 +1,5 @@ use ollama_rs::{ - generation::chat::{request::ChatMessageRequest, ChatMessage}, + generation::chat::{request::ChatMessageRequest, ChatMessage, ChatMessageResponseStream}, Ollama, }; use tokio::io::{stdout, AsyncWriteExt}; @@ -7,12 +7,10 @@ use tokio_stream::StreamExt; #[tokio::main] async fn main() -> Result<(), Box> { - let mut ollama = Ollama::new_default_with_history_async(30); + let mut ollama = Ollama::new_default_with_history(30); let mut stdout = stdout(); - let chat_id = "default".to_string(); - loop { stdout.write_all(b"\n> ").await?; stdout.flush().await?; @@ -25,12 +23,13 @@ async fn main() -> Result<(), Box> { break; } - let user_message = ChatMessage::user(input.to_string()); - - let mut stream = ollama + let mut stream: ChatMessageResponseStream = ollama .send_chat_messages_with_history_stream( - ChatMessageRequest::new("llama2:latest".to_string(), vec![user_message]), - chat_id.clone(), + ChatMessageRequest::new( + "llama2:latest".to_string(), + vec![ChatMessage::user(input.to_string())], + ), + "user".to_string(), ) .await?; @@ -44,14 +43,8 @@ async fn main() -> Result<(), Box> { response += assistant_message.content.as_str(); } } + dbg!(&ollama.get_messages_history("user")); } - // Display whole history of messages - dbg!( - &ollama - .get_messages_history_async("default".to_string()) - .await - ); - Ok(()) } diff --git a/src/functions/mod.rs b/src/functions/mod.rs deleted file mode 100644 index d40df22..0000000 --- a/src/functions/mod.rs +++ /dev/null @@ -1,53 +0,0 @@ -pub mod tools; -pub mod pipelines; - -pub use tools::WeatherTool; -pub use tools::Scraper; -pub use tools::DDGSearcher; - -use async_trait::async_trait; -use serde_json::{Value, json}; -use std::error::Error; -use crate::generation::chat::ChatMessage; - - -pub trait FunctionCallBase: Send + Sync { - fn name(&self) -> String; -} - -#[async_trait] -pub trait FunctionCall: 
FunctionCallBase { - async fn call(&self, params: Value) -> Result>; -} - -pub struct DefaultFunctionCall {} - -impl FunctionCallBase for DefaultFunctionCall { - fn name(&self) -> String { - "default_function".to_string() - } -} - - -pub fn convert_to_ollama_tool(tool: &dyn crate::generation::functions::tools::Tool) -> Value { - let schema = tool.parameters(); - json!({ - "name": tool.name(), - "properties": schema["properties"], - "required": schema["required"] - }) -} - - -pub fn parse_response(message: &ChatMessage) -> Result { - let content = &message.content; - let value: Value = serde_json::from_str(content).map_err(|e| e.to_string())?; - - if let Some(function_call) = value.get("function_call") { - if let Some(arguments) = function_call.get("arguments") { - return Ok(arguments.to_string()); - } - return Err("`arguments` missing from `function_call`".to_string()); - } - Err("`function_call` missing from `content`".to_string()) -} diff --git a/src/functions/pipelines/mod.rs b/src/functions/pipelines/mod.rs deleted file mode 100644 index 727e688..0000000 --- a/src/functions/pipelines/mod.rs +++ /dev/null @@ -1,2 +0,0 @@ -pub mod openai; -pub mod nous_hermes; \ No newline at end of file diff --git a/src/functions/pipelines/nous_hermes/mod.rs b/src/functions/pipelines/nous_hermes/mod.rs deleted file mode 100644 index 6bf9d61..0000000 --- a/src/functions/pipelines/nous_hermes/mod.rs +++ /dev/null @@ -1,2 +0,0 @@ -pub mod prompts; -pub mod parsers; \ No newline at end of file diff --git a/src/functions/pipelines/nous_hermes/parsers.rs b/src/functions/pipelines/nous_hermes/parsers.rs deleted file mode 100644 index e69de29..0000000 diff --git a/src/functions/pipelines/nous_hermes/prompts.rs b/src/functions/pipelines/nous_hermes/prompts.rs deleted file mode 100644 index e69de29..0000000 diff --git a/src/functions/pipelines/openai/mod.rs b/src/functions/pipelines/openai/mod.rs deleted file mode 100644 index 0ef6f78..0000000 --- a/src/functions/pipelines/openai/mod.rs 
+++ /dev/null @@ -1,92 +0,0 @@ -pub mod prompts; -pub mod parsers; -pub mod request; - -pub use prompts::{DEFAULT_SYSTEM_TEMPLATE ,DEFAULT_RESPONSE_FUNCTION}; -pub use request::FunctionCallRequest; -pub use parsers::{generate_system_message, parse_response}; - -use std::sync::Arc; -use async_trait::async_trait; -use serde_json::{json, Value}; -use std::error::Error; -use crate::generation::functions::{FunctionCall, FunctionCallBase}; -use crate::generation::chat::{ChatMessage, ChatMessageResponse}; -use crate::generation::chat::request::{ChatMessageRequest}; -use crate::generation::functions::tools::Tool; -use crate::error::OllamaError; - - -pub struct OpenAIFunctionCall { - pub name: String, -} - -impl OpenAIFunctionCall { - pub fn new(name: &str) -> Self { - OpenAIFunctionCall { - name: name.to_string(), - } - } -} - -impl FunctionCallBase for OpenAIFunctionCall { - fn name(&self) -> String { - "openai".to_string() - } -} - -#[async_trait] -impl FunctionCall for OpenAIFunctionCall { - async fn call(&self, params: Value) -> Result> { - // Simulate a function call by returning a simple JSON value - Ok(json!({ "result": format!("Function {} called with params: {}", self.name, params) })) - } -} - - -impl crate::Ollama { - pub async fn function_call_with_history( - &self, - request: ChatMessageRequest, - tool: Arc, - ) -> Result { - let function_call = OpenAIFunctionCall::new(&tool.name()); - let params = tool.parameters(); - let result = function_call.call(params).await?; - Ok(ChatMessageResponse { - model: request.model_name, - created_at: "".to_string(), - message: Some(ChatMessage::assistant(result.to_string())), - done: true, - final_data: None, - }) - } - - pub async fn function_call( - &self, - request: ChatMessageRequest, - ) -> crate::error::Result { - let mut request = request; - request.stream = false; - - let url = format!("{}api/chat", self.url_str()); - let serialized = serde_json::to_string(&request).map_err(|e| e.to_string())?; - let res = self - 
.reqwest_client - .post(url) - .body(serialized) - .send() - .await - .map_err(|e| e.to_string())?; - - if !res.status().is_success() { - return Err(res.text().await.unwrap_or_else(|e| e.to_string()).into()); - } - - let bytes = res.bytes().await.map_err(|e| e.to_string())?; - let res = - serde_json::from_slice::(&bytes).map_err(|e| e.to_string())?; - - Ok(res) - } -} \ No newline at end of file diff --git a/src/functions/pipelines/openai/parsers.rs b/src/functions/pipelines/openai/parsers.rs deleted file mode 100644 index 980ed95..0000000 --- a/src/functions/pipelines/openai/parsers.rs +++ /dev/null @@ -1,23 +0,0 @@ -use crate::generation::chat::ChatMessage; -use serde_json::Value; -use crate::generation::functions::pipelines::openai::DEFAULT_SYSTEM_TEMPLATE; -use crate::generation::functions::tools::Tool; - -pub fn parse_response(message: &ChatMessage) -> Result { - let content = &message.content; - let value: Value = serde_json::from_str(content).map_err(|e| e.to_string())?; - - if let Some(function_call) = value.get("function_call") { - Ok(function_call.clone()) - } else { - Ok(value) - } -} - -pub fn generate_system_message(tools: &[&dyn Tool]) -> ChatMessage { - let tools_info: Vec = tools.iter().map(|tool| tool.parameters()).collect(); - let tools_json = serde_json::to_string(&tools_info).unwrap(); - let system_message_content = DEFAULT_SYSTEM_TEMPLATE.replace("{tools}", &tools_json); - ChatMessage::system(system_message_content) -} - diff --git a/src/functions/pipelines/openai/prompts.rs b/src/functions/pipelines/openai/prompts.rs deleted file mode 100644 index 94e0749..0000000 --- a/src/functions/pipelines/openai/prompts.rs +++ /dev/null @@ -1,29 +0,0 @@ -pub const DEFAULT_SYSTEM_TEMPLATE: &str = r#" -You have access to the following tools: - -{tools} - -You must always select one of the above tools and respond with only a JSON object matching the following schema: - -{ - "tool": , - "tool_input": -} -"#; - -pub const DEFAULT_RESPONSE_FUNCTION: &str = r#" 
-{ - "name": "__conversational_response", - "description": "Respond conversationally if no other tools should be called for a given query.", - "parameters": { - "type": "object", - "properties": { - "response": { - "type": "string", - "description": "Conversational response to the user." - } - }, - "required": ["response"] - } -} -"#; diff --git a/src/functions/pipelines/openai/request.rs b/src/functions/pipelines/openai/request.rs deleted file mode 100644 index 301b6ae..0000000 --- a/src/functions/pipelines/openai/request.rs +++ /dev/null @@ -1,66 +0,0 @@ -use serde_json::Value; -use std::sync::Arc; -use crate::generation::chat::{ChatMessage, ChatMessageResponse}; -use crate::generation::chat::request::{ChatMessageRequest}; -use crate::generation::functions::pipelines::openai::DEFAULT_SYSTEM_TEMPLATE; -use crate::generation::functions::tools::Tool; -use crate::Ollama; -use crate::error::OllamaError; - -#[derive(Clone)] -pub struct FunctionCallRequest { - model_name: String, - tools: Vec>, -} - -impl FunctionCallRequest { - pub fn new(model_name: &str, tools: Vec>) -> Self { - FunctionCallRequest { - model_name: model_name.to_string(), - tools, - } - } - - pub async fn send(&self, ollama: &mut Ollama, input: &str) -> Result { - let system_message = self.get_system_message(); - ollama.send_chat_messages_with_history( - ChatMessageRequest::new(self.model_name.clone(), vec![system_message.clone()]), - "default".to_string(), - ).await?; - - let user_message = ChatMessage::user(input.to_string()); - - let result = ollama - .send_chat_messages_with_history( - ChatMessageRequest::new(self.model_name.clone(), vec![user_message]), - "default".to_string(), - ).await?; - - let response_content = result.message.clone().unwrap().content; - let response_value: Value = match serde_json::from_str(&response_content) { - Ok(value) => value, - Err(e) => return Err(OllamaError::from(e.to_string())), - }; - - if let Some(function_call) = response_value.get("function_call") { - if let 
Some(tool_name) = function_call.get("tool").and_then(Value::as_str) { - if let Some(tool) = self.tools.iter().find(|t| t.name() == tool_name) { - let result = ollama.function_call_with_history( - ChatMessageRequest::new(self.model_name.clone(), vec![ChatMessage::user(tool_name.to_string())]), - tool.clone(), - ).await?; - return Ok(result); - } - } - } - - Ok(result) - } - - pub fn get_system_message(&self) -> ChatMessage { - let tools_info: Vec = self.tools.iter().map(|tool| tool.parameters()).collect(); - let tools_json = serde_json::to_string(&tools_info).unwrap(); - let system_message_content = DEFAULT_SYSTEM_TEMPLATE.replace("{tools}", &tools_json); - ChatMessage::system(system_message_content) - } -} diff --git a/src/functions/tools/mod.rs b/src/functions/tools/mod.rs deleted file mode 100644 index d824fc9..0000000 --- a/src/functions/tools/mod.rs +++ /dev/null @@ -1,59 +0,0 @@ -pub mod search_ddg; -pub mod weather; -pub mod scraper; - -pub use self::weather::WeatherTool; -pub use self::scraper::Scraper; -pub use self::search_ddg::DDGSearcher; - -use async_trait::async_trait; -use serde_json::{json, Value}; -use std::error::Error; -use std::string::String; - -#[async_trait] -pub trait Tool: Send + Sync { - /// Returns the name of the tool. - fn name(&self) -> String; - - /// Provides a description of what the tool does and when to use it. - fn description(&self) -> String; - - /// This are the parameters for OpenAI-like function call. - fn parameters(&self) -> Value { - json!({ - "type": "object", - "properties": { - "input": { - "type": "string", - "description": self.description() - } - }, - "required": ["input"] - }) - } - - /// Processes an input string and executes the tool's functionality, returning a `Result`. - async fn call(&self, input: &str) -> Result> { - let input = self.parse_input(input).await; - self.run(input).await - } - - /// Executes the core functionality of the tool. 
- async fn run(&self, input: Value) -> Result>; - - /// Parses the input string. - async fn parse_input(&self, input: &str) -> Value { - log::info!("Using default implementation: {}", input); - match serde_json::from_str::(input) { - Ok(input) => { - if input["input"].is_string() { - Value::String(input["input"].as_str().unwrap().to_string()) - } else { - Value::String(input.to_string()) - } - } - Err(_) => Value::String(input.to_string()), - } - } -} \ No newline at end of file diff --git a/src/functions/tools/scraper.rs b/src/functions/tools/scraper.rs deleted file mode 100644 index 2e8b4c0..0000000 --- a/src/functions/tools/scraper.rs +++ /dev/null @@ -1,69 +0,0 @@ -use reqwest::Client; -use scraper::{Html, Selector}; -use std::env; -use text_splitter::TextSplitter; - -use std::error::Error; -use serde_json::{Value, json}; -use crate::generation::functions::tools::Tool; -use async_trait::async_trait; - -pub struct Scraper {} - - -#[async_trait] -impl Tool for Scraper { - fn name(&self) -> String { - "Website Scraper".to_string() - } - - fn description(&self) -> String { - "Scrapes text content from websites and splits it into manageable chunks.".to_string() - } - - fn parameters(&self) -> Value { - json!({ - "type": "object", - "properties": { - "website": { - "type": "string", - "description": "The URL of the website to scrape" - } - }, - "required": ["website"] - }) - } - - async fn run(&self, input: Value) -> Result> { - let website = input["website"].as_str().ok_or("Website URL is required")?; - let browserless_token = env::var("BROWSERLESS_TOKEN").expect("BROWSERLESS_TOKEN must be set"); - let url = format!("http://0.0.0.0:3000/content?token={}", browserless_token); - let payload = json!({ - "url": website - }); - let client = Client::new(); - let response = client - .post(&url) - .header("cache-control", "no-cache") - .header("content-type", "application/json") - .json(&payload) - .send() - .await?; - - let response_text = response.text().await?; - let 
document = Html::parse_document(&response_text); - let selector = Selector::parse("p, h1, h2, h3, h4, h5, h6").unwrap(); - let elements: Vec = document - .select(&selector) - .map(|el| el.text().collect::()) - .collect(); - let body = elements.join(" "); - - let splitter = TextSplitter::new(1000); - let chunks = splitter.chunks(&body); - let sentences: Vec = chunks.map(|s| s.to_string()).collect(); - let sentences = sentences.join("\n \n"); - Ok(sentences) - } -} - diff --git a/src/functions/tools/search_ddg.rs b/src/functions/tools/search_ddg.rs deleted file mode 100644 index 9407300..0000000 --- a/src/functions/tools/search_ddg.rs +++ /dev/null @@ -1,109 +0,0 @@ -use reqwest; - -use url::Url; - -use scraper::{Html, Selector}; -use std::error::Error; - -use crate::generation::functions::tools::Tool; -use serde::{Deserialize, Serialize}; -use serde_json::{Value, json}; -use async_trait::async_trait; - - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct SearchResult { - title: String, - link: String, - snippet: String, -} - -impl SearchResult { - fn extract_domain(url: &str) -> Option { - Url::parse(url).ok()?.domain().map(|d| d.to_string()) - } -} - -pub struct DDGSearcher { - pub client: reqwest::Client, - pub base_url: String, -} - -impl DDGSearcher { - pub fn new() -> Self { - DDGSearcher { - client: reqwest::Client::new(), - base_url: "https://duckduckgo.com".to_string(), - } - } - - pub async fn search(&self, query: &str) -> Result, Box> { - let url = format!("{}/html/?q={}", self.base_url, query); - let resp = self.client.get(&url).send().await?; - let body = resp.text().await?; - let document = Html::parse_document(&body); - - let result_selector = Selector::parse(".web-result").unwrap(); - let result_title_selector = Selector::parse(".result__a").unwrap(); - let result_url_selector = Selector::parse(".result__url").unwrap(); - let result_snippet_selector = Selector::parse(".result__snippet").unwrap(); - - let results = 
document.select(&result_selector).map(|result| { - - let title = result.select(&result_title_selector).next().unwrap().text().collect::>().join(""); - let link = result.select(&result_url_selector).next().unwrap().text().collect::>().join("").trim().to_string(); - let snippet = result.select(&result_snippet_selector).next().unwrap().text().collect::>().join(""); - - SearchResult { - title, - link, - //url: String::from(url.value().attr("href").unwrap()), - snippet, - } - }).collect::>(); - - Ok(results) - } -} - -#[async_trait] -impl Tool for DDGSearcher { - fn name(&self) -> String { - "DDG Searcher".to_string() - } - - fn description(&self) -> String { - "Searches the web using DuckDuckGo's HTML interface.".to_string() - } - - fn parameters(&self) -> Value { - json!({ - "description": "This tool lets you search the web using DuckDuckGo. The input should be a search query.", - "type": "object", - "properties": { - "query": { - "description": "The search query to send to DuckDuckGo", - "type": "string" - } - }, - "required": ["query"] - }) - } - - async fn call(&self, input: &str) -> Result> { - let input_value = self.parse_input(input).await; - self.run(input_value).await - } - - async fn run(&self, input: Value) -> Result> { - let query = input.as_str().ok_or("Input should be a string")?; - let results = self.search(query).await?; - let results_json = serde_json::to_string(&results)?; - Ok(results_json) - } - - async fn parse_input(&self, input: &str) -> Value { - // Use default implementation provided in the Tool trait - Tool::parse_input(self, input).await - } -} \ No newline at end of file diff --git a/src/functions/tools/weather.rs b/src/functions/tools/weather.rs deleted file mode 100644 index 2671677..0000000 --- a/src/functions/tools/weather.rs +++ /dev/null @@ -1,33 +0,0 @@ -use async_trait::async_trait; -use serde_json::{json, Value}; -use std::error::Error; -use crate::generation::functions::tools::Tool; - -pub struct WeatherTool; - -#[async_trait] 
-impl Tool for WeatherTool { - fn name(&self) -> String { - "WeatherTool".to_string() - } - - fn description(&self) -> String { - "Get the current weather in a given location.".to_string() - } - - async fn run(&self, input: Value) -> Result> { - let location = input.as_str().ok_or("Input should be a string")?; - let unit = "fahrenheit"; // Default unit - let result = if location.to_lowercase().contains("tokyo") { - json!({"location": "Tokyo", "temperature": "10", "unit": unit}) - } else if location.to_lowercase().contains("san francisco") { - json!({"location": "San Francisco", "temperature": "72", "unit": unit}) - } else if location.to_lowercase().contains("paris") { - json!({"location": "Paris", "temperature": "22", "unit": unit}) - } else { - json!({"location": location, "temperature": "unknown"}) - }; - - Ok(result.to_string()) - } -} diff --git a/src/generation/chat/mod.rs b/src/generation/chat/mod.rs index 67b7714..d8bb140 100644 --- a/src/generation/chat/mod.rs +++ b/src/generation/chat/mod.rs @@ -99,6 +99,60 @@ impl Ollama { #[cfg(feature = "chat-history")] impl Ollama { + #[cfg(feature = "stream")] + pub async fn send_chat_messages_with_history_stream( + &mut self, + mut request: ChatMessageRequest, + history_id: String, + ) -> crate::error::Result { + use async_stream::stream; + use tokio_stream::StreamExt; + + request.messages = self.get_prefill_messages(history_id.clone(), request.messages); + + let mut resp_stream: ChatMessageResponseStream = + self.send_chat_messages_stream(request.clone()).await?; + + let (tx, mut rx) = tokio::sync::mpsc::channel::(10); + + let id_copy = history_id.clone(); + let messages_history = self.messages_history.clone(); + + tokio::spawn(async move { + let mut result = String::new(); + + while let Some(item) = rx.recv().await { + if item.done { + if let Some(history) = messages_history.clone() { + let mut inner = history.write().unwrap(); + inner.add_message(id_copy.clone(), ChatMessage::assistant(result)); + } + result = 
String::new(); + } else { + result.push_str(&item.message.clone().unwrap().content); + } + } + }); + + let messages_history = self.messages_history.clone(); + + let s = stream! { + while let Some(item) = resp_stream.try_next().await.unwrap() { + if let Err(e) = tx.send(item.clone()).await { + eprintln!("Failed to send stream response: {}", e); + if let Some(history) = messages_history.clone() { + let mut inner = history.write().unwrap(); + inner.pop_last_message_for_id(history_id.clone()); + } + }; + + yield Ok(item); + } + }; + + Ok(Box::pin(s)) + } + /// Chat message generation /// Returns a `ChatMessageResponse` object /// Manages the history of messages for the given `id` @@ -130,7 +184,7 @@ impl Ollama { /// Helper function to store chat messages by id fn store_chat_message_by_id>(&mut self, id: S, message: ChatMessage) { if let Some(messages_history) = self.messages_history.as_mut() { - messages_history.add_message(id, message); + messages_history.write().unwrap().add_message(id, message); } } @@ -142,28 +196,33 @@ impl Ollama { history_id: S, request_messages: Vec, ) -> Vec { - let mut backup = MessagesHistory::default(); - + let chat_history = match self.messages_history.as_mut() { + Some(history) => history, + None => &mut { + let new_history = + std::sync::Arc::new(std::sync::RwLock::new(MessagesHistory::default())); + self.messages_history = Some(new_history); + self.messages_history.clone().unwrap() + }, + }; // Clone the current chat messages to avoid borrowing issues // And not to add message to the history if the request fails - let current_chat_messages = self - .messages_history - .as_mut() - .unwrap_or(&mut backup) + let mut history_instance = chat_history.write().unwrap(); + let chat_history = history_instance .messages_by_id .entry(history_id.into()) .or_default(); if let Some(message) = request_messages.first() { - current_chat_messages.push(message.clone()); + chat_history.push(message.clone()); } - current_chat_messages.clone() + 
chat_history.clone() } fn remove_history_last_message>(&mut self, history_id: S) { if let Some(history) = self.messages_history.as_mut() { - history.pop_last_message_for_id(history_id); + history.write().unwrap().pop_last_message_for_id(history_id); } } } diff --git a/src/generation/functions/mod.rs b/src/generation/functions/mod.rs index 56f2bb9..ddec011 100644 --- a/src/generation/functions/mod.rs +++ b/src/generation/functions/mod.rs @@ -24,7 +24,7 @@ impl crate::Ollama { } fn has_system_prompt_history(&mut self) -> bool { - return self.get_messages_history("default".to_string()).is_some(); + return self.get_messages_history("default").is_some(); } #[cfg(feature = "chat-history")] diff --git a/src/history.rs b/src/history.rs index 3650bb5..333890d 100644 --- a/src/history.rs +++ b/src/history.rs @@ -70,7 +70,9 @@ impl Ollama { /// Create default instance with chat history pub fn new_default_with_history(messages_number_limit: u16) -> Self { Self { - messages_history: Some(MessagesHistory::new(messages_number_limit)), + messages_history: Some(std::sync::Arc::new(std::sync::RwLock::new( + MessagesHistory::new(messages_number_limit), + ))), ..Default::default() } } @@ -95,7 +97,9 @@ impl Ollama { pub fn new_with_history_from_url(url: url::Url, messages_number_limit: u16) -> Self { Self { url, - messages_history: Some(MessagesHistory::new(messages_number_limit)), + messages_history: Some(std::sync::Arc::new(std::sync::RwLock::new( + MessagesHistory::new(messages_number_limit), + ))), ..Default::default() } } @@ -107,7 +111,9 @@ impl Ollama { ) -> Result { Ok(Self { url: url.into_url()?, - messages_history: Some(MessagesHistory::new(messages_number_limit)), + messages_history: Some(std::sync::Arc::new(std::sync::RwLock::new( + MessagesHistory::new(messages_number_limit), + ))), ..Default::default() }) } @@ -130,15 +136,22 @@ impl Ollama { /// Helper for message add to history fn add_history_message>(&mut self, entry_id: S, message: ChatMessage) { if let 
Some(messages_history) = self.messages_history.as_mut() { - messages_history.add_message(entry_id, message); + messages_history + .write() + .unwrap() + .add_message(entry_id, message); } } /// For tests purpose /// Getting list of messages in a history pub fn get_messages_history(&mut self, entry_id: &str) -> Option> { - self.messages_history - .clone() - .map(|message_history| message_history.get_messages(entry_id).cloned())? + self.messages_history.clone().map(|message_history| { + message_history + .write() + .unwrap() + .get_messages(entry_id) + .cloned() + })? } } diff --git a/src/lib.rs b/src/lib.rs index 97d63d9..0f93b88 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -5,8 +5,6 @@ pub mod history; #[cfg(all(feature = "chat-history", feature = "stream"))] pub mod history_async; pub mod models; -#[cfg(feature = "function-calling")] -pub mod functions; use url::Url; @@ -72,9 +70,8 @@ pub struct Ollama { pub(crate) url: Url, pub(crate) reqwest_client: reqwest::Client, #[cfg(feature = "chat-history")] - pub(crate) messages_history: Option, - #[cfg(all(feature = "chat-history", feature = "stream"))] - pub(crate) messages_history_async: Option, + pub(crate) messages_history: + Option>>, } impl Ollama { From 6d05d90819d8b324eb31a0901e4edf5020264d06 Mon Sep 17 00:00:00 2001 From: Ushinnary Date: Sat, 29 Jun 2024 15:55:26 +0200 Subject: [PATCH 4/9] Simplify stream with history Going back to idea to store responses on success only Remove async history Added tests Improve tests coverage and history management --- src/generation/chat/mod.rs | 178 +++++++------------------------ src/history.rs | 11 +- src/history_async.rs | 141 ------------------------ src/lib.rs | 8 +- tests/chat_history_management.rs | 20 ++++ tests/send_chat_messages.rs | 37 ++++++- 6 files changed, 102 insertions(+), 293 deletions(-) delete mode 100644 src/history_async.rs diff --git a/src/generation/chat/mod.rs b/src/generation/chat/mod.rs index d8bb140..8a5cd21 100644 --- a/src/generation/chat/mod.rs 
+++ b/src/generation/chat/mod.rs @@ -9,8 +9,6 @@ use request::ChatMessageRequest; #[cfg(feature = "chat-history")] use crate::history::MessagesHistory; -#[cfg(all(feature = "chat-history", feature = "stream"))] -use crate::history_async::MessagesHistoryAsync; #[cfg(feature = "stream")] /// A stream of `ChatMessageResponse` objects @@ -100,51 +98,50 @@ impl Ollama { #[cfg(feature = "chat-history")] impl Ollama { #[cfg(feature = "stream")] - pub async fn send_chat_messages_with_history_stream( + pub async fn send_chat_messages_with_history_stream + Clone>( &mut self, mut request: ChatMessageRequest, - history_id: String, + history_id: S, ) -> crate::error::Result { use async_stream::stream; use tokio_stream::StreamExt; + let id_copy = history_id.clone().into(); + + let mut current_chat_messages = self.get_chat_messages_by_id(id_copy.clone()); + + if let Some(message) = request.messages.first() { + current_chat_messages.push(message.clone()); + } - request.messages = self.get_prefill_messages(history_id.clone(), request.messages); + // The request is modified to include the current chat messages + request.messages.clone_from(¤t_chat_messages); + request.stream = true; let mut resp_stream: ChatMessageResponseStream = self.send_chat_messages_stream(request.clone()).await?; - let (tx, mut rx) = tokio::sync::mpsc::channel::(10); - - let id_copy = history_id.clone(); let messages_history = self.messages_history.clone(); - tokio::spawn(async move { + let s = stream! 
{ let mut result = String::new(); - while let Some(item) = rx.recv().await { + while let Some(item) = resp_stream.try_next().await.unwrap() { + let msg_part = item.clone().message.unwrap().content; + if item.done { if let Some(history) = messages_history.clone() { let mut inner = history.write().unwrap(); - inner.add_message(id_copy.clone(), ChatMessage::assistant(result)); + // Message we sent to AI + if let Some(message) = request.messages.last() { + inner.add_message(id_copy.clone(), message.clone()); + } + + // AI's response + inner.add_message(id_copy.clone(), ChatMessage::assistant(result.clone())); } - result = String::new(); } else { - result.push_str(&item.message.clone().unwrap().content); + result.push_str(&msg_part); } - } - }); - - let messages_history = self.messages_history.clone(); - - let s = stream! { - while let Some(item) = resp_stream.try_next().await.unwrap() { - if let Err(e) = tx.send(item.clone()).await { - eprintln!("Failed to send stream response: {}", e); - if let Some(history) = messages_history.clone() { - let mut inner = history.write().unwrap(); - inner.pop_last_message_for_id(history_id.clone()); - } - }; yield Ok(item); } @@ -162,20 +159,26 @@ impl Ollama { history_id: S, ) -> crate::error::Result { // The request is modified to include the current chat messages - request.messages = self.get_prefill_messages(history_id.clone(), request.messages); + let mut current_chat_messages = self.get_chat_messages_by_id(history_id.clone()); + + if let Some(message) = request.messages.first() { + current_chat_messages.push(message.clone()); + } - let result = self.send_chat_messages(request).await; + // The request is modified to include the current chat messages + request.messages.clone_from(¤t_chat_messages); - match result { - Ok(result) => { - // Store AI's response in the history - self.store_chat_message_by_id(history_id, result.message.clone().unwrap()); + let result = self.send_chat_messages(request.clone()).await; - return Ok(result); 
- } - Err(_) => { - self.remove_history_last_message(history_id); + if let Ok(result) = result { + // Message we sent to AI + if let Some(message) = request.messages.last() { + self.store_chat_message_by_id(history_id.clone(), message.clone()); } + // Store AI's response in the history + self.store_chat_message_by_id(history_id, result.message.clone().unwrap()); + + return Ok(result); } result @@ -191,10 +194,9 @@ impl Ollama { /// Let get existing history with a new message in it /// Without impact for existing history /// Used to prepare history for request - fn get_prefill_messages>( + fn get_chat_messages_by_id + Clone>( &mut self, history_id: S, - request_messages: Vec, ) -> Vec { let chat_history = match self.messages_history.as_mut() { Some(history) => history, @@ -213,106 +215,8 @@ impl Ollama { .entry(history_id.into()) .or_default(); - if let Some(message) = request_messages.first() { - chat_history.push(message.clone()); - } - chat_history.clone() } - - fn remove_history_last_message>(&mut self, history_id: S) { - if let Some(history) = self.messages_history.as_mut() { - history.write().unwrap().pop_last_message_for_id(history_id); - } - } -} - -#[cfg(all(feature = "chat-history", feature = "stream"))] -impl Ollama { - async fn get_chat_messages_by_id_async(&mut self, id: String) -> Vec { - // Clone the current chat messages to avoid borrowing issues - // And not to add message to the history if the request fails - self.messages_history_async - .as_mut() - .unwrap_or(&mut MessagesHistoryAsync::default()) - .messages_by_id - .lock() - .await - .entry(id.clone()) - .or_default() - .clone() - } - - pub async fn store_chat_message_by_id_async(&mut self, id: String, message: ChatMessage) { - if let Some(messages_history_async) = self.messages_history_async.as_mut() { - messages_history_async.add_message(id, message).await; - } - } - - pub async fn send_chat_messages_with_history_stream( - &mut self, - mut request: ChatMessageRequest, - id: String, - ) -> 
crate::error::Result { - use tokio_stream::StreamExt; - - let (tx, mut rx) = - tokio::sync::mpsc::unbounded_channel::>(); // create a channel for sending and receiving messages - - let mut current_chat_messages = self.get_chat_messages_by_id_async(id.clone()).await; - - if let Some(messaeg) = request.messages.first() { - current_chat_messages.push(messaeg.clone()); - } - - request.messages.clone_from(¤t_chat_messages); - - let mut stream = self.send_chat_messages_stream(request.clone()).await?; - - let message_history_async = self.messages_history_async.clone(); - - tokio::spawn(async move { - let mut result = String::new(); - while let Some(res) = rx.recv().await { - match res { - Ok(res) => { - if let Some(message) = res.message.clone() { - result += message.content.as_str(); - } - } - Err(_) => { - break; - } - } - } - - if let Some(message_history_async) = message_history_async { - message_history_async - .add_message(id.clone(), ChatMessage::assistant(result)) - .await; - } else { - eprintln!("not using chat-history and stream features"); // this should not happen if the features are enabled - } - }); - - let s = stream! 
{ - while let Some(res) = stream.next().await { - match res { - Ok(res) => { - if let Err(e) = tx.send(Ok(res.clone())) { - eprintln!("Failed to send response: {}", e); - }; - yield Ok(res); - } - Err(_) => { - yield Err(()); - } - } - } - }; - - Ok(Box::pin(s)) - } } #[derive(Debug, Clone, Deserialize, Serialize)] diff --git a/src/history.rs b/src/history.rs index 333890d..a28dc85 100644 --- a/src/history.rs +++ b/src/history.rs @@ -23,6 +23,10 @@ impl MessagesHistory { /// Add message for entry even no history exists for an entry pub fn add_message>(&mut self, entry_id: S, message: ChatMessage) { + if message.content.is_empty() && message.images.is_none() { + return; + } + let messages = self.messages_by_id.entry(entry_id.into()).or_default(); // Replacing the oldest message if the limit is reached @@ -53,13 +57,6 @@ impl MessagesHistory { self.messages_by_id.remove(entry_id); } - /// Remove last message added in history - pub fn pop_last_message_for_id>(&mut self, entry_id: S) { - if let Some(messages) = self.messages_by_id.get_mut(&entry_id.into()) { - messages.pop(); - } - } - /// Remove a whole history pub fn clear_all_messages(&mut self) { self.messages_by_id = HashMap::new(); diff --git a/src/history_async.rs b/src/history_async.rs deleted file mode 100644 index d631aae..0000000 --- a/src/history_async.rs +++ /dev/null @@ -1,141 +0,0 @@ -use std::{collections::HashMap, sync::Arc}; -use tokio::sync::Mutex; - -use crate::{ - generation::chat::{ChatMessage, MessageRole}, - Ollama, -}; - -#[derive(Debug, Clone, Default)] -pub struct MessagesHistoryAsync { - pub(crate) messages_by_id: Arc>>>, - pub(crate) messages_number_limit: u16, -} - -impl MessagesHistoryAsync { - pub fn new(messages_number_limit: u16) -> Self { - Self { - messages_by_id: Arc::new(Mutex::new(HashMap::new())), - messages_number_limit: messages_number_limit.max(2), - } - } - - pub async fn add_message(&self, entry_id: String, message: ChatMessage) { - let mut messages_lock = 
self.messages_by_id.lock().await; - let messages = messages_lock.entry(entry_id).or_default(); - - // Replacing the oldest message if the limit is reached - // The oldest message is the first one, unless it's a system message - if messages.len() >= self.messages_number_limit as usize { - let index_to_remove = messages - .first() - .map(|m| if m.role == MessageRole::System { 1 } else { 0 }) - .unwrap_or(0); - - messages.remove(index_to_remove); - } - - if message.role == MessageRole::System { - messages.insert(0, message); - } else { - messages.push(message); - } - } - - pub async fn get_messages(&self, entry_id: &str) -> Option> { - let messages_lock = self.messages_by_id.lock().await; - messages_lock.get(entry_id).cloned() - } - - pub async fn clear_messages(&self, entry_id: &str) { - let mut messages_lock = self.messages_by_id.lock().await; - messages_lock.remove(entry_id); - } -} - -impl Ollama { - /// Create default instance with chat history - pub fn new_default_with_history_async(messages_number_limit: u16) -> Self { - Self { - messages_history_async: Some(MessagesHistoryAsync::new(messages_number_limit)), - ..Default::default() - } - } - - /// Create new instance with chat history - /// - /// # Panics - /// - /// Panics if the host is not a valid URL or if the URL cannot have a port. - pub fn new_with_history_async( - host: impl crate::IntoUrl, - port: u16, - messages_number_limit: u16, - ) -> Self { - let mut url = host.into_url().unwrap(); - url.set_port(Some(port)).unwrap(); - Self::new_with_history_from_url(url, messages_number_limit) - } - - /// Create new instance with chat history from a [`url::Url`]. 
- #[inline] - pub fn new_with_history_from_url_async(url: url::Url, messages_number_limit: u16) -> Self { - Self { - url, - messages_history_async: Some(MessagesHistoryAsync::new(messages_number_limit)), - ..Default::default() - } - } - - #[inline] - pub fn try_new_with_history_async( - url: impl crate::IntoUrl, - messages_number_limit: u16, - ) -> Result { - Ok(Self { - url: url.into_url()?, - messages_history_async: Some(MessagesHistoryAsync::new(messages_number_limit)), - ..Default::default() - }) - } - - /// Add AI's message to a history - pub async fn add_assistant_response_async(&mut self, entry_id: String, message: String) { - if let Some(messages_history) = self.messages_history_async.as_mut() { - messages_history - .add_message(entry_id, ChatMessage::assistant(message)) - .await; - } - } - - /// Add user's message to a history - pub async fn add_user_response_async(&mut self, entry_id: String, message: String) { - if let Some(messages_history) = self.messages_history_async.as_mut() { - messages_history - .add_message(entry_id, ChatMessage::user(message)) - .await; - } - } - - /// Set system prompt for chat history - pub async fn set_system_response_async(&mut self, entry_id: String, message: String) { - if let Some(messages_history) = self.messages_history_async.as_mut() { - messages_history - .add_message(entry_id, ChatMessage::system(message)) - .await; - } - } - - /// For tests purpose - /// Getting list of messages in a history - pub async fn get_messages_history_async( - &mut self, - entry_id: String, - ) -> Option> { - if let Some(messages_history_async) = self.messages_history_async.as_mut() { - messages_history_async.get_messages(&entry_id).await - } else { - None - } - } -} diff --git a/src/lib.rs b/src/lib.rs index 0f93b88..0b6caad 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,13 +1,11 @@ +use url::Url; + pub mod error; pub mod generation; #[cfg(feature = "chat-history")] pub mod history; -#[cfg(all(feature = "chat-history", feature = 
"stream"))] -pub mod history_async; pub mod models; -use url::Url; - /// A trait to try to convert some type into a [`Url`]. /// /// This trait is "sealed", such that only types within ollama-rs can @@ -148,8 +146,6 @@ impl Default for Ollama { reqwest_client: reqwest::Client::new(), #[cfg(feature = "chat-history")] messages_history: None, - #[cfg(all(feature = "chat-history", feature = "stream"))] - messages_history_async: None, } } } diff --git a/tests/chat_history_management.rs b/tests/chat_history_management.rs index 5c3ff30..d2d25fe 100644 --- a/tests/chat_history_management.rs +++ b/tests/chat_history_management.rs @@ -19,3 +19,23 @@ fn test_chat_history_saved_as_should() { assert!(last.is_some()); assert_eq!(last.unwrap().content, "Hi again".to_string()); } + +#[test] +fn chat_history_not_stored_if_no_content() { + let mut ollama = Ollama::new_default_with_history(30); + let chat_id = "default"; + + ollama.add_user_response(chat_id, "Hello"); + ollama.add_assistant_response(chat_id, ""); + + ollama.add_user_response(chat_id, ""); + ollama.add_assistant_response(chat_id, "Hi again"); + + let history = ollama.get_messages_history(chat_id).unwrap(); + + assert_eq!(history.len(), 2); + + let last = history.last(); + assert!(last.is_some()); + assert_eq!(last.unwrap().content, "Hi again".to_string()); +} diff --git a/tests/send_chat_messages.rs b/tests/send_chat_messages.rs index bc69f05..bc91800 100644 --- a/tests/send_chat_messages.rs +++ b/tests/send_chat_messages.rs @@ -1,12 +1,13 @@ use base64::Engine; +use tokio_stream::StreamExt; + use ollama_rs::{ generation::{ - chat::{request::ChatMessageRequest, ChatMessage}, + chat::{ChatMessage, request::ChatMessageRequest}, images::Image, }, Ollama, }; -use tokio_stream::StreamExt; #[allow(dead_code)] const PROMPT: &str = "Why is the sky blue?"; @@ -54,6 +55,38 @@ async fn test_send_chat_messages() { assert!(res.done); } +#[tokio::test] +async fn test_send_chat_messages_with_history_stream() { + let mut ollama = 
Ollama::new_default_with_history(30); + let id = "default".to_string(); + + let messages = vec![ChatMessage::user(PROMPT.to_string())]; + + let mut done = false; + + let mut res = ollama + .send_chat_messages_with_history_stream( + ChatMessageRequest::new("llama2:latest".to_string(), messages), + id.clone(), + ) + .await + .unwrap(); + + while let Some(res) = res.next().await { + let res = res.unwrap(); + + if res.done { + done = true; + break; + } + } + + assert!(done); + // Should have user's message as well as AI's response + dbg!(&ollama.get_messages_history(&id).unwrap()); + assert_eq!(ollama.get_messages_history(&id).unwrap().len(), 2); +} + #[tokio::test] async fn test_send_chat_messages_with_history() { let mut ollama = Ollama::new_default_with_history(30); From f046426b7e910d13cb92bed7532eb816ff5b3cd8 Mon Sep 17 00:00:00 2001 From: Ushinnary Date: Sat, 29 Jun 2024 16:58:42 +0200 Subject: [PATCH 5/9] Cargo format run to fix CI problem --- tests/send_chat_messages.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/send_chat_messages.rs b/tests/send_chat_messages.rs index bc91800..9730d29 100644 --- a/tests/send_chat_messages.rs +++ b/tests/send_chat_messages.rs @@ -3,7 +3,7 @@ use tokio_stream::StreamExt; use ollama_rs::{ generation::{ - chat::{ChatMessage, request::ChatMessageRequest}, + chat::{request::ChatMessageRequest, ChatMessage}, images::Image, }, Ollama, From 2a95ab574b0d0964999dbae8634a0fa24ff35f79 Mon Sep 17 00:00:00 2001 From: Ushinnary Date: Sat, 29 Jun 2024 17:01:50 +0200 Subject: [PATCH 6/9] Fix clippy for CI --- src/generation/functions/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/generation/functions/mod.rs b/src/generation/functions/mod.rs index ddec011..8fd78c5 100644 --- a/src/generation/functions/mod.rs +++ b/src/generation/functions/mod.rs @@ -24,7 +24,7 @@ impl crate::Ollama { } fn has_system_prompt_history(&mut self) -> bool { - return 
self.get_messages_history("default").is_some(); + self.get_messages_history("default").is_some() } #[cfg(feature = "chat-history")] From 4ed083d052d48110f72df2067b94ea99321f65c4 Mon Sep 17 00:00:00 2001 From: Ushinnary Date: Sat, 29 Jun 2024 22:31:49 +0200 Subject: [PATCH 7/9] Added wrapper type for history --- src/history.rs | 22 ++++++++-------------- src/lib.rs | 5 +++-- 2 files changed, 11 insertions(+), 16 deletions(-) diff --git a/src/history.rs b/src/history.rs index a28dc85..887db59 100644 --- a/src/history.rs +++ b/src/history.rs @@ -11,14 +11,16 @@ pub struct MessagesHistory { pub(crate) messages_number_limit: u16, } +pub type WrappedMessageHistory = std::sync::Arc>; + /// Store for messages history impl MessagesHistory { /// Generate a MessagesHistory - pub fn new(messages_number_limit: u16) -> Self { - Self { + pub fn new(messages_number_limit: u16) -> WrappedMessageHistory { + std::sync::Arc::new(std::sync::RwLock::new(Self { messages_by_id: HashMap::new(), messages_number_limit: messages_number_limit.max(2), - } + })) } /// Add message for entry even no history exists for an entry @@ -67,9 +69,7 @@ impl Ollama { /// Create default instance with chat history pub fn new_default_with_history(messages_number_limit: u16) -> Self { Self { - messages_history: Some(std::sync::Arc::new(std::sync::RwLock::new( - MessagesHistory::new(messages_number_limit), - ))), + messages_history: Some(MessagesHistory::new(messages_number_limit)), ..Default::default() } } @@ -94,10 +94,7 @@ impl Ollama { pub fn new_with_history_from_url(url: url::Url, messages_number_limit: u16) -> Self { Self { url, - messages_history: Some(std::sync::Arc::new(std::sync::RwLock::new( - MessagesHistory::new(messages_number_limit), - ))), - ..Default::default() + ..Ollama::new_default_with_history(messages_number_limit) } } @@ -108,10 +105,7 @@ impl Ollama { ) -> Result { Ok(Self { url: url.into_url()?, - messages_history: Some(std::sync::Arc::new(std::sync::RwLock::new( - 
MessagesHistory::new(messages_number_limit), - ))), - ..Default::default() + ..Ollama::new_default_with_history(messages_number_limit) }) } diff --git a/src/lib.rs b/src/lib.rs index 0b6caad..6e3369c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,3 +1,5 @@ +#[cfg(feature = "chat-history")] +use crate::history::WrappedMessageHistory; use url::Url; pub mod error; @@ -68,8 +70,7 @@ pub struct Ollama { pub(crate) url: Url, pub(crate) reqwest_client: reqwest::Client, #[cfg(feature = "chat-history")] - pub(crate) messages_history: - Option>>, + pub(crate) messages_history: Option, } impl Ollama { From fc0a771b03e36f6414bc88286a25e17bbc5c51c3 Mon Sep 17 00:00:00 2001 From: Ushinnary Date: Mon, 1 Jul 2024 18:34:33 +0200 Subject: [PATCH 8/9] Remove all Into by imp ToString Thanks again to ZBcheng --- src/generation/chat/mod.rs | 26 ++++++++++++-------------- src/history.rs | 18 +++++++++--------- tests/send_chat_messages.rs | 8 ++++---- 3 files changed, 25 insertions(+), 27 deletions(-) diff --git a/src/generation/chat/mod.rs b/src/generation/chat/mod.rs index 8a5cd21..4bf0239 100644 --- a/src/generation/chat/mod.rs +++ b/src/generation/chat/mod.rs @@ -98,14 +98,14 @@ impl Ollama { #[cfg(feature = "chat-history")] impl Ollama { #[cfg(feature = "stream")] - pub async fn send_chat_messages_with_history_stream + Clone>( + pub async fn send_chat_messages_with_history_stream( &mut self, mut request: ChatMessageRequest, - history_id: S, + history_id: impl ToString, ) -> crate::error::Result { use async_stream::stream; use tokio_stream::StreamExt; - let id_copy = history_id.clone().into(); + let id_copy = history_id.to_string().clone(); let mut current_chat_messages = self.get_chat_messages_by_id(id_copy.clone()); @@ -153,13 +153,14 @@ impl Ollama { /// Chat message generation /// Returns a `ChatMessageResponse` object /// Manages the history of messages for the given `id` - pub async fn send_chat_messages_with_history + Clone>( + pub async fn send_chat_messages_with_history( &mut 
self, mut request: ChatMessageRequest, - history_id: S, + history_id: impl ToString, ) -> crate::error::Result { // The request is modified to include the current chat messages - let mut current_chat_messages = self.get_chat_messages_by_id(history_id.clone()); + let id_copy = history_id.to_string().clone(); + let mut current_chat_messages = self.get_chat_messages_by_id(id_copy.clone()); if let Some(message) = request.messages.first() { current_chat_messages.push(message.clone()); @@ -173,10 +174,10 @@ impl Ollama { if let Ok(result) = result { // Message we sent to AI if let Some(message) = request.messages.last() { - self.store_chat_message_by_id(history_id.clone(), message.clone()); + self.store_chat_message_by_id(id_copy.clone(), message.clone()); } // Store AI's response in the history - self.store_chat_message_by_id(history_id, result.message.clone().unwrap()); + self.store_chat_message_by_id(id_copy, result.message.clone().unwrap()); return Ok(result); } @@ -185,7 +186,7 @@ impl Ollama { } /// Helper function to store chat messages by id - fn store_chat_message_by_id>(&mut self, id: S, message: ChatMessage) { + fn store_chat_message_by_id(&mut self, id: impl ToString, message: ChatMessage) { if let Some(messages_history) = self.messages_history.as_mut() { messages_history.write().unwrap().add_message(id, message); } @@ -194,10 +195,7 @@ impl Ollama { /// Let get existing history with a new message in it /// Without impact for existing history /// Used to prepare history for request - fn get_chat_messages_by_id + Clone>( - &mut self, - history_id: S, - ) -> Vec { + fn get_chat_messages_by_id(&mut self, history_id: impl ToString) -> Vec { let chat_history = match self.messages_history.as_mut() { Some(history) => history, None => &mut { @@ -212,7 +210,7 @@ impl Ollama { let mut history_instance = chat_history.write().unwrap(); let chat_history = history_instance .messages_by_id - .entry(history_id.into()) + .entry(history_id.to_string()) .or_default(); 
chat_history.clone() diff --git a/src/history.rs b/src/history.rs index 887db59..157557b 100644 --- a/src/history.rs +++ b/src/history.rs @@ -24,12 +24,12 @@ impl MessagesHistory { } /// Add message for entry even no history exists for an entry - pub fn add_message>(&mut self, entry_id: S, message: ChatMessage) { + pub fn add_message(&mut self, entry_id: impl ToString, message: ChatMessage) { if message.content.is_empty() && message.images.is_none() { return; } - let messages = self.messages_by_id.entry(entry_id.into()).or_default(); + let messages = self.messages_by_id.entry(entry_id.to_string()).or_default(); // Replacing the oldest message if the limit is reached // The oldest message is the first one, unless it's a system message @@ -110,22 +110,22 @@ impl Ollama { } /// Add AI's message to a history - pub fn add_assistant_response>(&mut self, entry_id: S, message: S) { - self.add_history_message(entry_id, ChatMessage::assistant(message.into())); + pub fn add_assistant_response(&mut self, entry_id: impl ToString, message: impl ToString) { + self.add_history_message(entry_id, ChatMessage::assistant(message.to_string())); } /// Add user's message to a history - pub fn add_user_response>(&mut self, entry_id: S, message: S) { - self.add_history_message(entry_id, ChatMessage::user(message.into())); + pub fn add_user_response(&mut self, entry_id: impl ToString, message: impl ToString) { + self.add_history_message(entry_id, ChatMessage::user(message.to_string())); } /// Set system prompt for chat history - pub fn set_system_response>(&mut self, entry_id: S, message: S) { - self.add_history_message(entry_id, ChatMessage::system(message.into())); + pub fn set_system_response(&mut self, entry_id: impl ToString, message: impl ToString) { + self.add_history_message(entry_id, ChatMessage::system(message.to_string())); } /// Helper for message add to history - fn add_history_message>(&mut self, entry_id: S, message: ChatMessage) { + fn add_history_message(&mut self, 
entry_id: impl ToString, message: ChatMessage) { if let Some(messages_history) = self.messages_history.as_mut() { messages_history .write() diff --git a/tests/send_chat_messages.rs b/tests/send_chat_messages.rs index 9730d29..ccd72c3 100644 --- a/tests/send_chat_messages.rs +++ b/tests/send_chat_messages.rs @@ -58,7 +58,7 @@ async fn test_send_chat_messages() { #[tokio::test] async fn test_send_chat_messages_with_history_stream() { let mut ollama = Ollama::new_default_with_history(30); - let id = "default".to_string(); + let id = "default"; let messages = vec![ChatMessage::user(PROMPT.to_string())]; @@ -67,7 +67,7 @@ async fn test_send_chat_messages_with_history_stream() { let mut res = ollama .send_chat_messages_with_history_stream( ChatMessageRequest::new("llama2:latest".to_string(), messages), - id.clone(), + id, ) .await .unwrap(); @@ -83,8 +83,8 @@ async fn test_send_chat_messages_with_history_stream() { assert!(done); // Should have user's message as well as AI's response - dbg!(&ollama.get_messages_history(&id).unwrap()); - assert_eq!(ollama.get_messages_history(&id).unwrap().len(), 2); + dbg!(&ollama.get_messages_history(id).unwrap()); + assert_eq!(ollama.get_messages_history(id).unwrap().len(), 2); } #[tokio::test] From 9ad1a09ca7001bb934823c5384140192c40c750d Mon Sep 17 00:00:00 2001 From: Ushinnary Date: Sun, 7 Jul 2024 20:47:49 +0200 Subject: [PATCH 9/9] Update readme, add clear history feature, tests update --- README.md | 43 ++++++++++++++++++++++++++++ src/history.rs | 24 ++++++++++++---- tests/chat_history_management.rs | 49 ++++++++++++++++++++++++++++++++ 3 files changed, 111 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index 8664f69..d7272f8 100644 --- a/README.md +++ b/README.md @@ -88,6 +88,49 @@ if let Ok(res) = res { **OUTPUTS:** _1. 
Sun emits white sunlight: The sun consists primarily ..._ +### Chat mode +Description: _Every message sent and received will be stored in the library's history._ +_Each time you want to store history, you have to provide an ID for a chat._ +_It can be unique for each user or the same every time, depending on your needs_ + +Example with history: +```rust +let model = "llama2:latest".to_string(); +let prompt = "Why is the sky blue?".to_string(); +let history_id = "USER_ID_OR_WHATEVER"; + +let res = ollama + .send_chat_messages_with_history( + ChatMessageRequest::new( + model, + vec![ChatMessage::user(prompt)], // <- You should provide only one message + ), + history_id // <- This entry saves all the history for us + ).await; + +if let Ok(res) = res { +println!("{}", res.message.unwrap().content); +} +``` + +Getting history for some ID: +```rust +let history_id = "USER_ID_OR_WHATEVER"; +let history = ollama.get_messages_history(history_id); // <- Option<Vec<ChatMessage>> +// Act +``` + +Clear history if we no longer need it: +```rust +// Clear history for an ID +let history_id = "USER_ID_OR_WHATEVER"; +ollama.clear_messages_for_id(history_id); +// Clear history for all chats +ollama.clear_all_messages(); +``` + +_Check chat with history examples for [default](https://github.com/pepperoni21/ollama-rs/blob/master/examples/chat_with_history.rs) and [stream](https://github.com/pepperoni21/ollama-rs/blob/master/examples/chat_with_history_stream.rs)_ + ### List local models ```rust diff --git a/src/history.rs b/src/history.rs index 157557b..2fe3eea 100644 --- a/src/history.rs +++ b/src/history.rs @@ -50,13 +50,13 @@ impl MessagesHistory { } /// Get Option with list of ChatMessage - pub fn get_messages(&self, entry_id: &str) -> Option<&Vec<ChatMessage>> { - self.messages_by_id.get(entry_id) + pub fn get_messages(&self, entry_id: impl ToString) -> Option<&Vec<ChatMessage>> { + self.messages_by_id.get(&entry_id.to_string()) } /// Clear history for an entry - pub fn clear_messages_for_id(&mut self, entry_id: &str) { 
self.messages_by_id.remove(entry_id); + pub fn clear_messages_for_id(&mut self, entry_id: impl ToString) { + self.messages_by_id.remove(&entry_id.to_string()); } /// Remove a whole history @@ -136,7 +136,7 @@ impl Ollama { /// For tests purpose /// Getting list of messages in a history - pub fn get_messages_history(&mut self, entry_id: &str) -> Option> { + pub fn get_messages_history(&mut self, entry_id: impl ToString) -> Option> { self.messages_history.clone().map(|message_history| { message_history .write() @@ -145,4 +145,18 @@ impl Ollama { .cloned() })? } + + /// Clear history for an entry + pub fn clear_messages_for_id(&mut self, entry_id: impl ToString) { + if let Some(history) = self.messages_history.clone() { + history.write().unwrap().clear_messages_for_id(entry_id) + } + } + + /// Remove a whole history + pub fn clear_all_messages(&mut self) { + if let Some(history) = self.messages_history.clone() { + history.write().unwrap().clear_all_messages() + } + } } diff --git a/tests/chat_history_management.rs b/tests/chat_history_management.rs index d2d25fe..cd98d03 100644 --- a/tests/chat_history_management.rs +++ b/tests/chat_history_management.rs @@ -39,3 +39,52 @@ fn chat_history_not_stored_if_no_content() { assert!(last.is_some()); assert_eq!(last.unwrap().content, "Hi again".to_string()); } + +#[test] +fn clear_chat_history_for_one_id_only() { + let mut ollama = Ollama::new_default_with_history(30); + let first_chat_id = "default"; + + ollama.add_user_response(first_chat_id, "Hello"); + + let another_chat_id = "not_default"; + + ollama.add_user_response(another_chat_id, "Hello"); + + assert_eq!(ollama.get_messages_history(first_chat_id).unwrap().len(), 1); + assert_eq!( + ollama.get_messages_history(another_chat_id).unwrap().len(), + 1 + ); + + ollama.clear_messages_for_id(first_chat_id); + + assert!(ollama.get_messages_history(first_chat_id).is_none()); + assert_eq!( + ollama.get_messages_history(another_chat_id).unwrap().len(), + 1 + ); +} + +#[test] +fn 
clear_chat_history_for_all() { + let mut ollama = Ollama::new_default_with_history(30); + let first_chat_id = "default"; + + ollama.add_user_response(first_chat_id, "Hello"); + + let another_chat_id = "not_default"; + + ollama.add_user_response(another_chat_id, "Hello"); + + assert_eq!(ollama.get_messages_history(first_chat_id).unwrap().len(), 1); + assert_eq!( + ollama.get_messages_history(another_chat_id).unwrap().len(), + 1 + ); + + ollama.clear_all_messages(); + + assert!(ollama.get_messages_history(first_chat_id).is_none()); + assert!(ollama.get_messages_history(another_chat_id).is_none()); +}