diff --git a/.gitignore b/.gitignore
index 83cdd23..3e9f11a 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,4 @@
 /target
 .vscode/settings.json
 shell.nix
+.idea
diff --git a/README.md b/README.md
index d7272f8..77e53fa 100644
--- a/README.md
+++ b/README.md
@@ -194,7 +194,7 @@ _Returns a `GenerateEmbeddingsResponse` struct containing the embeddings (a vect
 ### Make a function call
 
 ```rust
-let tools = vec![Arc::new(Scraper::new())];
+let tools = vec![Arc::new(Scraper::new()), Arc::new(DDGSearcher::new())];
 let parser = Arc::new(NousFunctionCall::new());
 let message = ChatMessage::user("What is the current oil price?".to_string());
 let res = ollama.send_function_call(
diff --git a/src/generation/chat/mod.rs b/src/generation/chat/mod.rs
index 4bf0239..7fdceb0 100644
--- a/src/generation/chat/mod.rs
+++ b/src/generation/chat/mod.rs
@@ -196,14 +196,15 @@ impl Ollama {
     /// Without impact for existing history
     /// Used to prepare history for request
     fn get_chat_messages_by_id(&mut self, history_id: impl ToString) -> Vec<ChatMessage> {
+        let mut binding = {
+            // Create a default history only when none exists yet, so an
+            // existing history is not silently replaced.
+            if self.messages_history.is_none() {
+                let new_history =
+                    std::sync::Arc::new(std::sync::RwLock::new(MessagesHistory::default()));
+                self.messages_history = Some(new_history);
+            }
+            self.messages_history.clone().unwrap()
+        };
         let chat_history = match self.messages_history.as_mut() {
             Some(history) => history,
-            None => &mut {
-                let new_history =
-                    std::sync::Arc::new(std::sync::RwLock::new(MessagesHistory::default()));
-                self.messages_history = Some(new_history);
-                self.messages_history.clone().unwrap()
-            },
+            None => &mut binding,
         };
         // Clone the current chat messages to avoid borrowing issues
         // And not to add message to the history if the request fails
diff --git a/src/generation/functions/mod.rs b/src/generation/functions/mod.rs
index 8fd78c5..2e761c2 100644
--- a/src/generation/functions/mod.rs
+++ b/src/generation/functions/mod.rs
@@ -2,11 +2,14 @@ pub mod pipelines;
 pub mod request;
 pub mod tools;
 
+pub use crate::generation::functions::pipelines::meta_llama::request::LlamaFunctionCall;
 pub use crate::generation::functions::pipelines::nous_hermes::request::NousFunctionCall;
 pub use crate::generation::functions::pipelines::openai::request::OpenAIFunctionCall;
 pub use crate::generation::functions::request::FunctionCallRequest;
+pub use tools::Browserless;
 pub use tools::DDGSearcher;
 pub use tools::Scraper;
+pub use tools::SerperSearchTool;
 pub use tools::StockScraper;
 
 use crate::error::OllamaError;
diff --git a/src/generation/functions/pipelines/meta_llama/mod.rs b/src/generation/functions/pipelines/meta_llama/mod.rs
new file mode 100644
index 0000000..23c05af
--- /dev/null
+++ b/src/generation/functions/pipelines/meta_llama/mod.rs
@@ -0,0 +1,4 @@
+pub mod prompts;
+pub mod request;
+
+pub use prompts::DEFAULT_SYSTEM_TEMPLATE;
diff --git a/src/generation/functions/pipelines/meta_llama/prompts.rs b/src/generation/functions/pipelines/meta_llama/prompts.rs
new file mode 100644
index 0000000..fcb1a6d
--- /dev/null
+++ b/src/generation/functions/pipelines/meta_llama/prompts.rs
@@ -0,0 +1,14 @@
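+// NOTE: `{tools}` is substituted at request time via `str::replace` (see
+// `get_system_message` in request.rs); the doubled braces in the JSON example
+// below reach the model literally, which is why `clean_tool_call` in
+// request.rs also normalizes `{{`/`}}` in replies.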
+pub const DEFAULT_SYSTEM_TEMPLATE: &str = r#"
+You have access to the following functions:
+{tools}
+If you choose to call a function ONLY reply in the following format with no prefix or suffix:
+
+<function=example_function_name>{{"example_name": "example_value"}}</function>
+
+Reminder:
+- Function calls MUST follow the specified format, start with <function= and end with </function>
+- Required parameters MUST be specified
+- Only call one function at a time
+- Put the entire function call reply on one line
+- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls
+"#;
diff --git a/src/generation/functions/pipelines/meta_llama/request.rs b/src/generation/functions/pipelines/meta_llama/request.rs
new file mode 100644
index 0000000..8ed2ebd
--- /dev/null
+++ b/src/generation/functions/pipelines/meta_llama/request.rs
@@ -0,0 +1,136 @@
+use crate::error::OllamaError;
+use crate::generation::chat::{ChatMessage, ChatMessageResponse};
+use crate::generation::functions::pipelines::meta_llama::DEFAULT_SYSTEM_TEMPLATE;
+use crate::generation::functions::pipelines::RequestParserBase;
+use crate::generation::functions::tools::Tool;
+use async_trait::async_trait;
+use regex::Regex;
+use serde::{Deserialize, Serialize};
+use serde_json::{json, Value};
+use std::collections::HashMap;
+use std::sync::Arc;
+
+pub fn convert_to_llama_tool(tool: &Arc<dyn Tool>) -> Value {
+    let mut function = HashMap::new();
+    function.insert("name".to_string(), Value::String(tool.name()));
+    function.insert("description".to_string(), Value::String(tool.description()));
+    function.insert("parameters".to_string(), tool.parameters());
+    json!(format!(
+        "Use the function '{name}' to '{description}': {json}",
+        name = tool.name(),
+        description = tool.description(),
+        json = json!(function)
+    ))
+}
+
+#[derive(Debug, Deserialize, Serialize)]
+pub struct LlamaFunctionCallSignature {
+    pub function: String, // name of the tool
+    pub arguments: Value,
+}
+
+pub struct LlamaFunctionCall {}
+
+impl LlamaFunctionCall {
+    pub async fn function_call_with_history(
+        &self,
+        model_name: String,
+        tool_params: Value,
+        tool: Arc<dyn Tool>,
+    ) -> Result<ChatMessageResponse, ChatMessageResponse> {
+        let result = tool.run(tool_params).await;
+        match result {
+            Ok(result) => Ok(ChatMessageResponse {
+                model: model_name.clone(),
+                created_at: "".to_string(),
+                message: Some(ChatMessage::assistant(result.to_string())),
+                done: true,
+                final_data: None,
+            }),
+            Err(e) => Err(self.error_handler(OllamaError::from(e))),
+        }
+    }
+
+    fn clean_tool_call(&self, json_str: &str) -> String {
+        json_str
+            .trim()
+            .trim_start_matches("```json")
+            .trim_end_matches("```")
+            .trim()
+            .to_string()
+            .replace("{{", "{")
+            .replace("}}", "}")
+    }
+
+    fn parse_tool_response(&self, response: &str) -> Option<LlamaFunctionCallSignature> {
+        let function_regex = Regex::new(r"<function=(\w+)>(.*?)</function>").unwrap();
+        if let Some(caps) = function_regex.captures(response) {
+            let function_name = caps.get(1).unwrap().as_str().to_string();
+            let args_string = caps.get(2).unwrap().as_str();
+
+            match serde_json::from_str(args_string) {
+                Ok(arguments) => Some(LlamaFunctionCallSignature {
+                    function: function_name,
+                    arguments,
+                }),
+                Err(error) => {
+                    eprintln!("Error parsing function arguments: {}", error);
+                    None
+                }
+            }
+        } else {
+            None
+        }
+    }
+}
+
+#[async_trait]
+impl RequestParserBase for LlamaFunctionCall {
+    async fn parse(
+        &self,
+        input: &str,
+        model_name: String,
+        tools: Vec<Arc<dyn Tool>>,
+    ) -> Result<ChatMessageResponse, ChatMessageResponse> {
+        let response_value = self.parse_tool_response(&self.clean_tool_call(input));
+        match response_value {
+            Some(response) => {
+                if let Some(tool) = tools.iter().find(|t| t.name() == response.function) {
+                    let tool_params = response.arguments;
+                    let result = self
+                        .function_call_with_history(
+                            model_name.clone(),
+                            tool_params.clone(),
+                            tool.clone(),
+                        )
+                        .await?;
+                    Ok(result)
+                } else {
+                    Err(self.error_handler(OllamaError::from("Tool not found".to_string())))
+                }
+            }
+            None => Err(self
+                .error_handler(OllamaError::from("Error parsing function call".to_string()))),
+        }
+    }
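+
+    // For illustration: with the `ddg_searcher` tool registered, a conforming
+    // model reply is a single line such as
+    //   <function=ddg_searcher>{"query": "current oil price"}</function>
+    // `parse` strips code fences and doubled braces, extracts the name and
+    // argument JSON with the regex above, then dispatches to the named tool.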
+
+    async fn get_system_message(&self, tools: &[Arc<dyn Tool>]) -> ChatMessage {
+        let tools_info: Vec<Value> = tools.iter().map(convert_to_llama_tool).collect();
+        let tools_json = serde_json::to_string(&tools_info).unwrap();
+        let system_message_content = DEFAULT_SYSTEM_TEMPLATE.replace("{tools}", &tools_json);
+        ChatMessage::system(system_message_content)
+    }
+
+    fn error_handler(&self, error: OllamaError) -> ChatMessageResponse {
+        ChatMessageResponse {
+            model: "".to_string(),
+            created_at: "".to_string(),
+            message: Some(ChatMessage::assistant(error.to_string())),
+            done: true,
+            final_data: None,
+        }
+    }
+}
diff --git a/src/generation/functions/pipelines/mod.rs b/src/generation/functions/pipelines/mod.rs
index 057e347..aa7fdfb 100644
--- a/src/generation/functions/pipelines/mod.rs
+++ b/src/generation/functions/pipelines/mod.rs
@@ -4,11 +4,12 @@ use crate::generation::functions::tools::Tool;
 use async_trait::async_trait;
 use std::sync::Arc;
 
+pub mod meta_llama;
 pub mod nous_hermes;
 pub mod openai;
 
 #[async_trait]
-pub trait RequestParserBase {
+pub trait RequestParserBase: Send + Sync {
     async fn parse(
         &self,
         input: &str,
diff --git a/src/generation/functions/tools/browserless.rs b/src/generation/functions/tools/browserless.rs
new file mode 100644
index 0000000..411df7f
--- /dev/null
+++ b/src/generation/functions/tools/browserless.rs
@@ -0,0 +1,68 @@
+use reqwest::Client;
+use scraper::{Html, Selector};
+use std::env;
+use text_splitter::TextSplitter;
+
+use crate::generation::functions::tools::Tool;
+use async_trait::async_trait;
+use serde_json::{json, Value};
+use std::error::Error;
+
+pub struct Browserless {}
+// TODO: add headless browsing utilities
+#[async_trait]
+impl Tool for Browserless {
+    fn name(&self) -> String {
+        "browserless_web_scraper".to_string()
+    }
+
+    fn description(&self) -> String {
+        "Scrapes text content from websites and splits it into manageable chunks.".to_string()
+    }
+
+    fn parameters(&self) -> Value {
+        json!({
+            "type": "object",
+            "properties": {
+                "website": {
+                    "type": "string",
+                    "description": "The URL of the website to scrape"
+                }
+            },
+            "required": ["website"]
+        })
+    }
+
+    async fn run(&self, input: Value) -> Result<String, Box<dyn Error>> {
+        let website = input["website"].as_str().ok_or("Website URL is required")?;
+        let browserless_token =
+            env::var("BROWSERLESS_TOKEN").expect("BROWSERLESS_TOKEN must be set");
+        let url = format!("http://0.0.0.0:3000/content?token={}", browserless_token);
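+        // Assumes a Browserless instance reachable locally on port 3000; its
+        // /content endpoint is expected to return the rendered HTML for the
+        // requested URL.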
"stock_scraper".to_string() } fn description(&self) -> String { diff --git a/src/generation/functions/tools/mod.rs b/src/generation/functions/tools/mod.rs index b5f4223..2d0bd96 100644 --- a/src/generation/functions/tools/mod.rs +++ b/src/generation/functions/tools/mod.rs @@ -1,10 +1,14 @@ +pub mod browserless; pub mod finance; pub mod scraper; pub mod search_ddg; +pub mod serper; +pub use self::browserless::Browserless; pub use self::finance::StockScraper; pub use self::scraper::Scraper; pub use self::search_ddg::DDGSearcher; +pub use self::serper::SerperSearchTool; use async_trait::async_trait; use serde_json::{json, Value}; diff --git a/src/generation/functions/tools/scraper.rs b/src/generation/functions/tools/scraper.rs index 2398e72..9b4fe7f 100644 --- a/src/generation/functions/tools/scraper.rs +++ b/src/generation/functions/tools/scraper.rs @@ -22,7 +22,7 @@ impl Scraper { #[async_trait] impl Tool for Scraper { fn name(&self) -> String { - "Website Scraper".to_string() + "website_scraper".to_string() } fn description(&self) -> String { diff --git a/src/generation/functions/tools/search_ddg.rs b/src/generation/functions/tools/search_ddg.rs index 6793bb2..e31d55d 100644 --- a/src/generation/functions/tools/search_ddg.rs +++ b/src/generation/functions/tools/search_ddg.rs @@ -88,7 +88,7 @@ impl DDGSearcher { #[async_trait] impl Tool for DDGSearcher { fn name(&self) -> String { - "DDG Searcher".to_string() + "ddg_searcher".to_string() } fn description(&self) -> String { diff --git a/src/generation/functions/tools/serper.rs b/src/generation/functions/tools/serper.rs new file mode 100644 index 0000000..19436b5 --- /dev/null +++ b/src/generation/functions/tools/serper.rs @@ -0,0 +1,297 @@ +use crate::generation::functions::tools::Tool; +use async_trait::async_trait; +use reqwest::Client; +use serde::{Deserialize, Serialize}; +use serde_json::{json, Value}; +use std::env; +use std::error::Error; + +#[derive(Debug, Deserialize, Serialize)] +pub struct SearchResult { + title: String, + link: String, + snippet: String, + date: String, + position: i32, // -1 indicates missing position +} + +impl SearchResult { + pub fn from_result_data(result_data: &Value) -> Self { + Self { + title: result_data + .get("title") + .unwrap_or(&Value::String("none".to_string())) + .as_str() + .unwrap() + .to_string(), + link: result_data + .get("link") + .unwrap_or(&Value::String("none".to_string())) + .as_str() + .unwrap() + .to_string(), + snippet: result_data + .get("snippet") + .unwrap_or(&Value::String("none".to_string())) + .as_str() + .unwrap() + .to_string(), + date: result_data + .get("date") + .unwrap_or(&Value::String("none".to_string())) + .as_str() + .unwrap() + .to_string(), + position: result_data + .get("position") + .unwrap_or(&Value::Number(serde_json::Number::from(-1))) + .as_i64() + .unwrap() as i32, + } + } + + pub fn to_formatted_string(&self) -> String { + format!( + "{}\n{}\n{}\n{}\n{}", + self.title, self.link, self.snippet, self.date, self.position + ) + } +} + +#[derive(Debug, Deserialize, Serialize)] +pub struct ScholarResult { + title: String, + link: String, + publication_info: String, + snippet: String, + year: i32, + cited_by: i32, +} + +impl ScholarResult { + pub fn from_result_data(result_data: &Value) -> Self { + Self { + title: result_data + .get("title") + .unwrap_or(&Value::String("none".to_string())) + .as_str() + .unwrap() + .to_string(), + link: result_data + .get("link") + .unwrap_or(&Value::String("none".to_string())) + .as_str() + .unwrap() + .to_string(), + 
+
+#[derive(Debug, Deserialize, Serialize)]
+pub struct SearchResult {
+    title: String,
+    link: String,
+    snippet: String,
+    date: String,
+    position: i32, // -1 indicates a missing position
+}
+
+impl SearchResult {
+    pub fn from_result_data(result_data: &Value) -> Self {
+        // Missing or non-string fields fall back to "none" (and -1 for the
+        // position) instead of panicking on unexpected payloads.
+        Self {
+            title: result_data
+                .get("title")
+                .and_then(Value::as_str)
+                .unwrap_or("none")
+                .to_string(),
+            link: result_data
+                .get("link")
+                .and_then(Value::as_str)
+                .unwrap_or("none")
+                .to_string(),
+            snippet: result_data
+                .get("snippet")
+                .and_then(Value::as_str)
+                .unwrap_or("none")
+                .to_string(),
+            date: result_data
+                .get("date")
+                .and_then(Value::as_str)
+                .unwrap_or("none")
+                .to_string(),
+            position: result_data
+                .get("position")
+                .and_then(Value::as_i64)
+                .unwrap_or(-1) as i32,
+        }
+    }
+
+    pub fn to_formatted_string(&self) -> String {
+        format!(
+            "{}\n{}\n{}\n{}\n{}",
+            self.title, self.link, self.snippet, self.date, self.position
+        )
+    }
+}
+
+#[derive(Debug, Deserialize, Serialize)]
+pub struct ScholarResult {
+    title: String,
+    link: String,
+    publication_info: String,
+    snippet: String,
+    year: i32,     // -1 indicates a missing year
+    cited_by: i32, // -1 indicates a missing citation count
+}
+
+impl ScholarResult {
+    pub fn from_result_data(result_data: &Value) -> Self {
+        Self {
+            title: result_data
+                .get("title")
+                .and_then(Value::as_str)
+                .unwrap_or("none")
+                .to_string(),
+            link: result_data
+                .get("link")
+                .and_then(Value::as_str)
+                .unwrap_or("none")
+                .to_string(),
+            publication_info: result_data
+                .get("publicationInfo")
+                .and_then(Value::as_str)
+                .unwrap_or("none")
+                .to_string(),
+            snippet: result_data
+                .get("snippet")
+                .and_then(Value::as_str)
+                .unwrap_or("none")
+                .to_string(),
+            year: result_data
+                .get("year")
+                .and_then(Value::as_i64)
+                .unwrap_or(-1) as i32,
+            cited_by: result_data
+                .get("citedBy")
+                .and_then(Value::as_i64)
+                .unwrap_or(-1) as i32,
+        }
+    }
+
+    pub fn to_formatted_string(&self) -> String {
+        format!(
+            "{}\n{}\n{}\n{}\n{}\n{}",
+            self.title, self.link, self.publication_info, self.snippet, self.year, self.cited_by
+        )
+    }
+}
+
+#[derive(Debug, Deserialize, Serialize)]
+pub struct NewsResult {
+    title: String,
+    link: String,
+    snippet: String,
+    date: String,
+    source: String,
+    image_url: String,
+    position: i32, // -1 indicates a missing position
+}
+
+impl NewsResult {
+    pub fn from_result_data(result_data: &Value) -> Self {
+        Self {
+            title: result_data
+                .get("title")
+                .and_then(Value::as_str)
+                .unwrap_or("none")
+                .to_string(),
+            link: result_data
+                .get("link")
+                .and_then(Value::as_str)
+                .unwrap_or("none")
+                .to_string(),
+            snippet: result_data
+                .get("snippet")
+                .and_then(Value::as_str)
+                .unwrap_or("none")
+                .to_string(),
+            date: result_data
+                .get("date")
+                .and_then(Value::as_str)
+                .unwrap_or("none")
+                .to_string(),
+            source: result_data
+                .get("source")
+                .and_then(Value::as_str)
+                .unwrap_or("none")
+                .to_string(),
+            image_url: result_data
+                .get("imageUrl")
+                .and_then(Value::as_str)
+                .unwrap_or("none")
+                .to_string(),
+            position: result_data
+                .get("position")
+                .and_then(Value::as_i64)
+                .unwrap_or(-1) as i32,
+        }
+    }
+
+    pub fn to_formatted_string(&self) -> String {
+        format!(
+            "{}\n{}\n{}\n{}\n{}\n{}\n{}",
+            self.title,
+            self.link,
+            self.snippet,
+            self.date,
+            self.source,
+            self.image_url,
+            self.position
+        )
+    }
+}
+
+pub struct SerperSearchTool;
+
+#[async_trait]
+impl Tool for SerperSearchTool {
+    fn name(&self) -> String {
+        "google_search_tool".to_string()
+    }
+
+    fn description(&self) -> String {
+        "Conducts a web search using a specified search type and returns the results.".to_string()
+    }
+
+    fn parameters(&self) -> Value {
+        json!({
+            "type": "object",
+            "properties": {
+                "query": {
+                    "type": "string",
+                    "description": "The search query"
+                },
+                "lang": {
+                    "type": "string",
+                    "description": "The language for the search"
+                },
+                "n_results": {
+                    "type": "integer",
+                    "description": "The number of results to return"
+                }
+            },
+            "required": ["query"]
+        })
+    }
+    /*
+    "search_type": {
+        "type": "string",
+        "description": "The search type (search, scholar, or news)"
+    }
+    */
+
+    async fn run(&self, input: Value) -> Result<String, Box<dyn Error>> {
+        let query = input["query"].as_str().ok_or("Query is required")?;
+        let stype = input["search_type"].as_str().unwrap_or("search");
+        let lang = input["lang"].as_str().unwrap_or("en");
+        let n_result = input["n_results"].as_u64().unwrap_or(5);
+
+        // Reject unknown search types up front with an error instead of
+        // asserting, so a bad model-supplied value cannot panic the caller.
+        if !["search", "scholar", "news"].contains(&stype) {
+            return Err(Box::new(std::io::Error::new(
+                std::io::ErrorKind::InvalidInput,
+                "Invalid search type",
+            )));
+        }
+
+        let url = format!("https://google.serper.dev/{}", stype);
+        let gl = if lang != "en" { lang } else { "us" };
+        let n_results = std::cmp::min(n_result, 10);
+        let mut payload = json!({
+            "q": query,
+            "gl": gl,
+            "hl": lang,
+            "page": 1,
+            "num": n_results
+        });
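+        // Example request body for the defaults above (hypothetical query):
+        //   {"q": "apple inc risk factors", "gl": "us", "hl": "en", "page": 1, "num": 5}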
"page": 1, + "num": n_results + }); + + if stype == "scholar" { + payload.as_object_mut().unwrap().remove("num"); + } + + let client = Client::new(); + let api_key = env::var("SERPER_API_KEY").expect("SERPER_API_KEY must be set"); + let response = client + .post(&url) + .header("X-API-KEY", api_key) + .header("Content-Type", "application/json") + .json(&payload) + .send() + .await? + .json::() + .await?; + + let results = response["organic"] + .as_array() + .ok_or("Invalid response format")?; + let formatted_results = match stype { + "search" => results + .iter() + .take(n_results as usize) + .map(|r| SearchResult::from_result_data(r).to_formatted_string()) + .collect::>(), + "scholar" => results + .iter() + .take(n_results as usize) + .map(|r| ScholarResult::from_result_data(r).to_formatted_string()) + .collect::>(), + "news" => results + .iter() + .take(n_results as usize) + .map(|r| NewsResult::from_result_data(r).to_formatted_string()) + .collect::>(), + _ => { + return Err(Box::new(std::io::Error::new( + std::io::ErrorKind::InvalidInput, + "Invalid search type", + ))) + } + }; + + Ok(formatted_results.join("\n")) + } +} diff --git a/tests/function_call.rs b/tests/function_call.rs index 034139c..7f965fc 100644 --- a/tests/function_call.rs +++ b/tests/function_call.rs @@ -3,7 +3,9 @@ use ollama_rs::{ generation::chat::ChatMessage, generation::functions::tools::{DDGSearcher, Scraper, StockScraper}, - generation::functions::{FunctionCallRequest, NousFunctionCall}, + generation::functions::{ + FunctionCallRequest, LlamaFunctionCall, NousFunctionCall, OpenAIFunctionCall, + }, Ollama, }; use std::sync::Arc; @@ -13,6 +15,7 @@ async fn test_send_function_call() { /// Model to be used, make sure it is tailored towards "function calling", such as: /// - OpenAIFunctionCall: not model specific, degraded performance /// - NousFunctionCall: adrienbrault/nous-hermes2pro:Q8_0 + /// - LlamaFunctionCall: llama3.1:latest const MODEL: &str = "adrienbrault/nous-hermes2pro:Q8_0"; const PROMPT: &str = "Aside from the Apple Remote, what other device can control the program Apple Remote was originally designed to interact with?"; @@ -43,6 +46,7 @@ async fn test_send_function_call_with_history() { /// Model to be used, make sure it is tailored towards "function calling", such as: /// - OpenAIFunctionCall: not model specific, degraded performance /// - NousFunctionCall: adrienbrault/nous-hermes2pro:Q8_0 + /// - LlamaFunctionCall: llama3.1:latest const MODEL: &str = "adrienbrault/nous-hermes2pro:Q8_0"; const PROMPT: &str = "Aside from the Apple Remote, what other device can control the program Apple Remote was originally designed to interact with?"; @@ -74,6 +78,7 @@ async fn test_send_function_call_finance() { /// Model to be used, make sure it is tailored towards "function calling", such as: /// - OpenAIFunctionCall: not model specific, degraded performance /// - NousFunctionCall: adrienbrault/nous-hermes2pro:Q8_0 + /// - LlamaFunctionCall: llama3.1:latest const MODEL: &str = "adrienbrault/nous-hermes2pro:Q8_0"; const PROMPT: &str = "What are the current risk factors to $APPL?"; @@ -93,3 +98,55 @@ async fn test_send_function_call_finance() { assert!(result.done); } + +#[tokio::test] +async fn test_send_function_call_llama() { + /// Model to be used, make sure it is tailored towards "function calling", such as: + /// - OpenAIFunctionCall: not model specific, degraded performance + /// - NousFunctionCall: adrienbrault/nous-hermes2pro:Q8_0 + /// - LlamaFunctionCall: llama3.1:latest + const MODEL: &str = 
"llama3.1:latest"; + + const PROMPT: &str = "What are the current risk factors to Apple Inc?"; + let user_message = ChatMessage::user(PROMPT.to_string()); + + let search = Arc::new(DDGSearcher::new()); + let parser = Arc::new(LlamaFunctionCall {}); + + let ollama = Ollama::default(); + let result = ollama + .send_function_call( + FunctionCallRequest::new(MODEL.to_string(), vec![search], vec![user_message]), + parser, + ) + .await + .unwrap(); + + assert!(result.done); +} + +#[tokio::test] +async fn test_send_function_call_phi3_medium() { + /// Model to be used, make sure it is tailored towards "function calling", such as: + /// - OpenAIFunctionCall: not model specific, degraded performance + /// - NousFunctionCall: adrienbrault/nous-hermes2pro:Q8_0 + /// - LlamaFunctionCall: llama3.1:latest + const MODEL: &str = "phi3:14b-medium-4k-instruct-q4_1"; + + const PROMPT: &str = "What are the current risk factors to Apple Inc?"; + let user_message = ChatMessage::user(PROMPT.to_string()); + + let search = Arc::new(DDGSearcher::new()); + let parser = Arc::new(OpenAIFunctionCall {}); + + let ollama = Ollama::default(); + let result = ollama + .send_function_call( + FunctionCallRequest::new(MODEL.to_string(), vec![search], vec![user_message]), + parser, + ) + .await + .unwrap(); + + assert!(result.done); +}