From 426b6d698a58a737f145054cd12555658b5e5e4b Mon Sep 17 00:00:00 2001
From: andthattoo
Date: Fri, 14 Jun 2024 10:43:18 +0300
Subject: [PATCH 1/9] added Send + Sync to request parser base

---
 README.md                                 | 2 +-
 src/generation/functions/pipelines/mod.rs | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/README.md b/README.md
index f55e3c0..fbff271 100644
--- a/README.md
+++ b/README.md
@@ -151,7 +151,7 @@ _Returns a `GenerateEmbeddingsResponse` struct containing the embeddings (a vect
 ### Make a function call
 
 ```rust
-let tools = vec![Arc::new(Scraper::new())];
+let tools = vec![Arc::new(Scraper::new()), Arc::new(DDGSearcher::new())];
 let parser = Arc::new(NousFunctionCall::new());
 let message = ChatMessage::user("What is the current oil price?".to_string());
 let res = ollama.send_function_call(
diff --git a/src/generation/functions/pipelines/mod.rs b/src/generation/functions/pipelines/mod.rs
index 057e347..6d205e5 100644
--- a/src/generation/functions/pipelines/mod.rs
+++ b/src/generation/functions/pipelines/mod.rs
@@ -8,7 +8,7 @@ pub mod nous_hermes;
 pub mod openai;
 
 #[async_trait]
-pub trait RequestParserBase {
+pub trait RequestParserBase: Send + Sync {
     async fn parse(
         &self,
         input: &str,
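With `RequestParserBase: Send + Sync`, a parser held behind `Arc<dyn RequestParserBase>` can cross task boundaries. A minimal sketch of what the bound enables — not part of the patch, and assuming a local Ollama instance with the `adrienbrault/nous-hermes2pro:Q8_0` model pulled:

```rust
use std::sync::Arc;

use ollama_rs::{
    generation::chat::ChatMessage,
    generation::functions::{tools::Scraper, FunctionCallRequest, NousFunctionCall},
    Ollama,
};

#[tokio::main]
async fn main() {
    let parser = Arc::new(NousFunctionCall::new());
    // tokio::spawn requires its future to be Send; with the new trait bound
    // this holds even when the parser is stored as Arc<dyn RequestParserBase>.
    let handle = tokio::spawn(async move {
        Ollama::default()
            .send_function_call(
                FunctionCallRequest::new(
                    "adrienbrault/nous-hermes2pro:Q8_0".to_string(),
                    vec![Arc::new(Scraper::new())],
                    vec![ChatMessage::user("What is the current oil price?".to_string())],
                ),
                parser,
            )
            .await
    });
    println!("{:?}", handle.await.unwrap().unwrap().message);
}
```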
From 1b6a5f04cdb25380ccf1b66c9451aa151390f205 Mon Sep 17 00:00:00 2001
From: andthattoo
Date: Tue, 30 Jul 2024 11:09:51 +0300
Subject: [PATCH 2/9] Added meta-llama function calling pipeline for llama3.1
 models

tool names are now underscored
added a few tests
added serper and browserless as builtin tools
---
 .gitignore                                    |   1 +
 src/generation/functions/mod.rs               |   3 +
 .../functions/pipelines/meta_llama/mod.rs     |   4 +
 .../functions/pipelines/meta_llama/prompts.rs |  14 +
 .../functions/pipelines/meta_llama/request.rs | 136 ++++++++
 src/generation/functions/pipelines/mod.rs     |   1 +
 src/generation/functions/tools/browserless.rs |  68 ++++
 src/generation/functions/tools/finance.rs     |   2 +-
 src/generation/functions/tools/mod.rs         |   4 +
 src/generation/functions/tools/scraper.rs     |   2 +-
 src/generation/functions/tools/search_ddg.rs  |   2 +-
 src/generation/functions/tools/serper.rs      | 297 ++++++++++++++++++
 tests/function_call.rs                        |  59 +++-
 13 files changed, 589 insertions(+), 4 deletions(-)
 create mode 100644 src/generation/functions/pipelines/meta_llama/mod.rs
 create mode 100644 src/generation/functions/pipelines/meta_llama/prompts.rs
 create mode 100644 src/generation/functions/pipelines/meta_llama/request.rs
 create mode 100644 src/generation/functions/tools/browserless.rs
 create mode 100644 src/generation/functions/tools/serper.rs

diff --git a/.gitignore b/.gitignore
index 83cdd23..3e9f11a 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,4 @@
 /target
 .vscode/settings.json
 shell.nix
+.idea
diff --git a/src/generation/functions/mod.rs b/src/generation/functions/mod.rs
index 56f2bb9..0f1839d 100644
--- a/src/generation/functions/mod.rs
+++ b/src/generation/functions/mod.rs
@@ -2,11 +2,14 @@ pub mod pipelines;
 pub mod request;
 pub mod tools;
 
+pub use crate::generation::functions::pipelines::meta_llama::request::LlamaFunctionCall;
 pub use crate::generation::functions::pipelines::nous_hermes::request::NousFunctionCall;
 pub use crate::generation::functions::pipelines::openai::request::OpenAIFunctionCall;
 pub use crate::generation::functions::request::FunctionCallRequest;
+pub use tools::Browserless;
 pub use tools::DDGSearcher;
 pub use tools::Scraper;
+pub use tools::SerperSearchTool;
 pub use tools::StockScraper;
 
 use crate::error::OllamaError;
diff --git a/src/generation/functions/pipelines/meta_llama/mod.rs b/src/generation/functions/pipelines/meta_llama/mod.rs
new file mode 100644
index 0000000..23c05af
--- /dev/null
+++ b/src/generation/functions/pipelines/meta_llama/mod.rs
@@ -0,0 +1,4 @@
+pub mod prompts;
+pub mod request;
+
+pub use prompts::DEFAULT_SYSTEM_TEMPLATE;
diff --git a/src/generation/functions/pipelines/meta_llama/prompts.rs b/src/generation/functions/pipelines/meta_llama/prompts.rs
new file mode 100644
index 0000000..fcb1a6d
--- /dev/null
+++ b/src/generation/functions/pipelines/meta_llama/prompts.rs
@@ -0,0 +1,14 @@
+pub const DEFAULT_SYSTEM_TEMPLATE: &str = r#"
+You have access to the following functions:
+{tools}
+If you choose to call a function ONLY reply in the following format with no prefix or suffix:
+
+<function=example_function_name>{{\"example_name\": \"example_value\"}}</function>
+
+Reminder:
+- Function calls MUST follow the specified format, start with <function= and end with </function>
+- Required parameters MUST be specified
+- Only call one function at a time
+- Put the entire function call reply on one line
+- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls
+"#;
diff --git a/src/generation/functions/pipelines/meta_llama/request.rs b/src/generation/functions/pipelines/meta_llama/request.rs
new file mode 100644
index 0000000..8ed2ebd
--- /dev/null
+++ b/src/generation/functions/pipelines/meta_llama/request.rs
@@ -0,0 +1,136 @@
+use crate::error::OllamaError;
+use crate::generation::chat::{ChatMessage, ChatMessageResponse};
+use crate::generation::functions::pipelines::meta_llama::DEFAULT_SYSTEM_TEMPLATE;
+use crate::generation::functions::pipelines::RequestParserBase;
+use crate::generation::functions::tools::Tool;
+use async_trait::async_trait;
+use regex::Regex;
+use serde::{Deserialize, Serialize};
+use serde_json::{json, Value};
+use std::collections::HashMap;
+use std::sync::Arc;
+
+pub fn convert_to_llama_tool(tool: &Arc<dyn Tool>) -> Value {
+    let mut function = HashMap::new();
+    function.insert("name".to_string(), Value::String(tool.name()));
+    function.insert("description".to_string(), Value::String(tool.description()));
+    function.insert("parameters".to_string(), tool.parameters());
+    json!(format!(
+        "Use the function '{name}' to '{description}': {json}",
+        name = tool.name(),
+        description = tool.description(),
+        json = json!(function)
+    ))
+}
+
+#[derive(Debug, Deserialize, Serialize)]
+pub struct LlamaFunctionCallSignature {
+    pub function: String, //name of the tool
+    pub arguments: Value,
+}
+
+pub struct LlamaFunctionCall {}
+
+impl LlamaFunctionCall {
+    pub async fn function_call_with_history(
+        &self,
+        model_name: String,
+        tool_params: Value,
+        tool: Arc<dyn Tool>,
+    ) -> Result<ChatMessageResponse, ChatMessageResponse> {
+        let result = tool.run(tool_params).await;
+        match result {
+            Ok(result) => Ok(ChatMessageResponse {
+                model: model_name.clone(),
+                created_at: "".to_string(),
+                message: Some(ChatMessage::assistant(result.to_string())),
+                done: true,
+                final_data: None,
+            }),
+            Err(e) => Err(self.error_handler(OllamaError::from(e))),
+        }
+    }
+
+    fn clean_tool_call(&self, json_str: &str) -> String {
+        json_str
+            .trim()
+            .trim_start_matches("```json")
+            .trim_end_matches("```")
+            .trim()
+            .to_string()
+            .replace("{{", "{")
+            .replace("}}", "}")
+    }
+
+    fn parse_tool_response(&self, response: &str) -> Option<LlamaFunctionCallSignature> {
+        let function_regex = Regex::new(r"<function=(\w+)>(.*?)</function>").unwrap();
+        println!("Response: {}", response);
+        if let Some(caps) = function_regex.captures(response) {
+            let function_name = caps.get(1).unwrap().as_str().to_string();
+            let args_string = caps.get(2).unwrap().as_str();
+            match serde_json::from_str(args_string) {
+                Ok(arguments) => Some(LlamaFunctionCallSignature {
+                    function: function_name,
+                    arguments,
+                }),
+                Err(error) => {
+                    println!("Error parsing function arguments: {}", error);
+                    None
+                }
+            }
+        } else {
+            None
+        }
+    }
+}
+
+#[async_trait]
+impl RequestParserBase for LlamaFunctionCall {
+    async fn parse(
+        &self,
+        input: &str,
+        model_name: String,
+        tools: Vec<Arc<dyn Tool>>,
+    ) -> Result<ChatMessageResponse, ChatMessageResponse> {
+        let response_value = self.parse_tool_response(&self.clean_tool_call(input));
+        match response_value {
+            Some(response) => {
+                if let Some(tool) = tools.iter().find(|t| t.name() == response.function) {
+                    let tool_params = response.arguments;
+                    let result = self
+                        .function_call_with_history(
+                            model_name.clone(),
+                            tool_params.clone(),
+                            tool.clone(),
+                        )
+                        .await?;
+                    return Ok(result);
+                } else {
+                    return Err(self.error_handler(OllamaError::from("Tool not found".to_string())));
+                }
+            }
+            None => {
+                return Err(self
+                    .error_handler(OllamaError::from("Error parsing function call".to_string())));
+            }
+        }
+    }
+
+    async fn get_system_message(&self, tools: &[Arc<dyn Tool>]) -> ChatMessage {
+        let tools_info: Vec<Value> = tools.iter().map(convert_to_llama_tool).collect();
+        let tools_json = serde_json::to_string(&tools_info).unwrap();
+        let system_message_content = DEFAULT_SYSTEM_TEMPLATE.replace("{tools}", &tools_json);
+        ChatMessage::system(system_message_content)
+    }
+
+    fn error_handler(&self, error: OllamaError) -> ChatMessageResponse {
+        ChatMessageResponse {
+            model: "".to_string(),
+            created_at: "".to_string(),
+            message: Some(ChatMessage::assistant(error.to_string())),
+            done: true,
+            final_data: None,
+        }
+    }
+}
diff --git a/src/generation/functions/pipelines/mod.rs b/src/generation/functions/pipelines/mod.rs
index 6d205e5..aa7fdfb 100644
--- a/src/generation/functions/pipelines/mod.rs
+++ b/src/generation/functions/pipelines/mod.rs
@@ -4,6 +4,7 @@ use crate::generation::functions::tools::Tool;
 use async_trait::async_trait;
 use std::sync::Arc;
 
+pub mod meta_llama;
 pub mod nous_hermes;
 pub mod openai;
 
diff --git a/src/generation/functions/tools/browserless.rs b/src/generation/functions/tools/browserless.rs
new file mode 100644
index 0000000..411df7f
--- /dev/null
+++ b/src/generation/functions/tools/browserless.rs
@@ -0,0 +1,68 @@
+use reqwest::Client;
+use scraper::{Html, Selector};
+use std::env;
+use text_splitter::TextSplitter;
+
+use crate::generation::functions::tools::Tool;
+use async_trait::async_trait;
+use serde_json::{json, Value};
+use std::error::Error;
+
+pub struct Browserless {}
+//Add headless utilities
+#[async_trait]
+impl Tool for Browserless {
+    fn name(&self) -> String {
+        "browserless_web_scraper".to_string()
+    }
+
+    fn description(&self) -> String {
+        "Scrapes text content from websites and splits it into manageable chunks.".to_string()
+    }
+
+    fn parameters(&self) -> Value {
+        json!({
+            "type": "object",
+            "properties": {
+                "website": {
+                    "type": "string",
+                    "description": "The URL of the website to scrape"
+                }
+            },
+            "required": ["website"]
+        })
+    }
+
+    async fn run(&self, input: Value) -> Result<String, Box<dyn Error>> {
+        let website = input["website"].as_str().ok_or("Website URL is required")?;
+        let browserless_token =
+            env::var("BROWSERLESS_TOKEN").expect("BROWSERLESS_TOKEN must be set");
+        let url = format!("http://0.0.0.0:3000/content?token={}", browserless_token);
+        let payload = json!({
+            "url": website
+        });
+        let client = Client::new();
+        let response = client
+            .post(&url)
+            .header("cache-control", "no-cache")
+            .header("content-type", "application/json")
+            .json(&payload)
+            .send()
+            .await?;
+
+        let response_text = response.text().await?;
+        let document = Html::parse_document(&response_text);
+        let selector = Selector::parse("p, h1, h2, h3, h4, h5, h6").unwrap();
+        let elements: Vec<String> = document
+            .select(&selector)
+            .map(|el| el.text().collect::<String>())
+            .collect();
+        let body = elements.join(" ");
+
+        let splitter = TextSplitter::new(1000);
+        let chunks = splitter.chunks(&body);
+        let sentences: Vec<String> = chunks.map(|s| s.to_string()).collect();
+        let sentences = sentences.join("\n \n");
+        Ok(sentences)
+    }
+}
diff --git a/src/generation/functions/tools/finance.rs b/src/generation/functions/tools/finance.rs
index 9161af4..6a7b654 100644
--- a/src/generation/functions/tools/finance.rs
+++ b/src/generation/functions/tools/finance.rs
@@ -64,7 +64,7 @@ impl StockScraper {
 #[async_trait]
 impl Tool for StockScraper {
     fn name(&self) -> String {
-        "Stock Scraper".to_string()
+        "stock_scraper".to_string()
     }
 
     fn description(&self) -> String {
diff --git a/src/generation/functions/tools/mod.rs b/src/generation/functions/tools/mod.rs
index b5f4223..2d0bd96 100644
--- a/src/generation/functions/tools/mod.rs
+++ b/src/generation/functions/tools/mod.rs
@@ -1,10 +1,14 @@
+pub mod browserless;
 pub mod finance;
 pub mod scraper;
 pub mod search_ddg;
+pub mod serper;
 
+pub use self::browserless::Browserless;
 pub use self::finance::StockScraper;
 pub use self::scraper::Scraper;
 pub use self::search_ddg::DDGSearcher;
+pub use self::serper::SerperSearchTool;
 
 use async_trait::async_trait;
 use serde_json::{json, Value};
diff --git a/src/generation/functions/tools/scraper.rs b/src/generation/functions/tools/scraper.rs
index 2398e72..9b4fe7f 100644
--- a/src/generation/functions/tools/scraper.rs
+++ b/src/generation/functions/tools/scraper.rs
@@ -22,7 +22,7 @@ impl Scraper {
 #[async_trait]
 impl Tool for Scraper {
     fn name(&self) -> String {
-        "Website Scraper".to_string()
+        "website_scraper".to_string()
     }
 
     fn description(&self) -> String {
diff --git a/src/generation/functions/tools/search_ddg.rs b/src/generation/functions/tools/search_ddg.rs
index 6793bb2..e31d55d 100644
--- a/src/generation/functions/tools/search_ddg.rs
+++ b/src/generation/functions/tools/search_ddg.rs
@@ -88,7 +88,7 @@ impl DDGSearcher {
 #[async_trait]
 impl Tool for DDGSearcher {
     fn name(&self) -> String {
-        "DDG Searcher".to_string()
+        "ddg_searcher".to_string()
     }
 
     fn description(&self) -> String {
diff --git a/src/generation/functions/tools/serper.rs b/src/generation/functions/tools/serper.rs
new file mode 100644
index 0000000..19436b5
--- /dev/null
+++ b/src/generation/functions/tools/serper.rs
@@ -0,0 +1,297 @@
+use crate::generation::functions::tools::Tool;
+use async_trait::async_trait;
+use reqwest::Client;
+use serde::{Deserialize, Serialize};
+use serde_json::{json, Value};
+use std::env;
+use std::error::Error;
+
+#[derive(Debug, Deserialize, Serialize)]
+pub struct SearchResult {
+    title: String,
+    link: String,
+    snippet: String,
+    date: String,
+    position: i32, // -1 indicates missing position
+}
+
+impl SearchResult {
+    pub fn from_result_data(result_data: &Value) -> Self {
+        Self {
+            title: result_data
+                .get("title")
+                .unwrap_or(&Value::String("none".to_string()))
+                .as_str()
+                .unwrap()
+                .to_string(),
+            link: result_data
+                .get("link")
+                .unwrap_or(&Value::String("none".to_string()))
+                .as_str()
+                .unwrap()
+                .to_string(),
+            snippet: result_data
+                .get("snippet")
+                .unwrap_or(&Value::String("none".to_string()))
+                .as_str()
+                .unwrap()
+                .to_string(),
+            date: result_data
+                .get("date")
+                .unwrap_or(&Value::String("none".to_string()))
+                .as_str()
+                .unwrap()
+                .to_string(),
+            position: result_data
+                .get("position")
+                .unwrap_or(&Value::Number(serde_json::Number::from(-1)))
+                .as_i64()
+                .unwrap() as i32,
+        }
+    }
+
+    pub fn to_formatted_string(&self) -> String {
+        format!(
+            "{}\n{}\n{}\n{}\n{}",
+            self.title, self.link, self.snippet, self.date, self.position
+        )
+    }
+}
+
+#[derive(Debug, Deserialize, Serialize)]
+pub struct ScholarResult {
+    title: String,
+    link: String,
+    publication_info: String,
+    snippet: String,
+    year: i32,
+    cited_by: i32,
+}
+
+impl ScholarResult {
+    pub fn from_result_data(result_data: &Value) -> Self {
+        Self {
+            title: result_data
+                .get("title")
+                .unwrap_or(&Value::String("none".to_string()))
+                .as_str()
+                .unwrap()
+                .to_string(),
+            link: result_data
+                .get("link")
+                .unwrap_or(&Value::String("none".to_string()))
+                .as_str()
+                .unwrap()
+                .to_string(),
+            publication_info: result_data
+                .get("publicationInfo")
+                .unwrap_or(&Value::String("none".to_string()))
+                .as_str()
+                .unwrap()
+                .to_string(),
+            snippet: result_data
+                .get("snippet")
+                .unwrap_or(&Value::String("none".to_string()))
+                .as_str()
+                .unwrap()
+                .to_string(),
+            year: result_data
+                .get("year")
+                .unwrap_or(&Value::Number(serde_json::Number::from(-1)))
+                .as_i64()
+                .unwrap() as i32,
+            cited_by: result_data
+                .get("citedBy")
+                .unwrap_or(&Value::Number(serde_json::Number::from(-1)))
+                .as_i64()
+                .unwrap() as i32,
+        }
+    }
+
+    pub fn to_formatted_string(&self) -> String {
+        format!(
+            "{}\n{}\n{}\n{}\n{}\n{}",
+            self.title, self.link, self.publication_info, self.snippet, self.year, self.cited_by
+        )
+    }
+}
+
+#[derive(Debug, Deserialize, Serialize)]
+pub struct NewsResult {
+    title: String,
+    link: String,
+    snippet: String,
+    date: String,
+    source: String,
+    image_url: String,
+    position: i32, // -1 indicates missing position
+}
+
+impl NewsResult {
+    pub fn from_result_data(result_data: &Value) -> Self {
+        Self {
+            title: result_data
+                .get("title")
+                .unwrap_or(&Value::String("none".to_string()))
+                .as_str()
+                .unwrap()
+                .to_string(),
+            link: result_data
+                .get("link")
+                .unwrap_or(&Value::String("none".to_string()))
+                .as_str()
+                .unwrap()
+                .to_string(),
+            snippet: result_data
+                .get("snippet")
+                .unwrap_or(&Value::String("none".to_string()))
+                .as_str()
+                .unwrap()
+                .to_string(),
+            date: result_data
+                .get("date")
+                .unwrap_or(&Value::String("none".to_string()))
+                .as_str()
+                .unwrap()
+                .to_string(),
+            source: result_data
+                .get("source")
+                .unwrap_or(&Value::String("none".to_string()))
+                .as_str()
+                .unwrap()
+                .to_string(),
+            image_url: result_data
+                .get("imageUrl")
+                .unwrap_or(&Value::String("none".to_string()))
+                .as_str()
+                .unwrap()
+                .to_string(),
+            position: result_data
+                .get("position")
+                .unwrap_or(&Value::Number(serde_json::Number::from(-1)))
+                .as_i64()
+                .unwrap() as i32,
+        }
+    }
+
+    pub fn to_formatted_string(&self) -> String {
+        format!(
+            "{}\n{}\n{}\n{}\n{}\n{}\n{}",
+            self.title,
+            self.link,
+            self.snippet,
+            self.date,
+            self.source,
+            self.image_url,
+            self.position
+        )
+    }
+}
+
+pub struct SerperSearchTool;
+
+#[async_trait]
+impl Tool for SerperSearchTool {
+    fn name(&self) -> String {
+        "google_search_tool".to_string()
+    }
+
+    fn description(&self) -> String {
+        "Conducts a web search using a specified search type and returns the results.".to_string()
+    }
+
+    fn parameters(&self) -> Value {
+        json!({
+            "type": "object",
+            "properties": {
+                "query": {
+                    "type": "string",
+                    "description": "The search query"
+                },
+                "lang": {
+                    "type": "string",
+ "description": "The language for the search" + }, + "n_results": { + "type": "integer", + "description": "The number of results to return" + } + }, + "required": ["query"] + }) + } + /* + "search_type": { + "type": "string", + "description": "The search type (search, scholar, or news)" + } + */ + + async fn run(&self, input: Value) -> Result> { + let query = input["query"].as_str().ok_or("Query is required")?; + let stype = input["search_type"].as_str().unwrap_or("search"); + let lang = input["lang"].as_str().unwrap_or("en"); + let n_result = input["n_results"].as_u64().unwrap_or(5); + + assert!( + ["search", "scholar", "news"].contains(&stype), + "Invalid search type" + ); + + let url = format!("https://google.serper.dev/{}", stype); + let gl = if lang != "en" { lang } else { "us" }; + let n_results = std::cmp::min(n_result, 10); + let mut payload = json!({ + "q": query, + "gl": gl, + "hl": lang, + "page": 1, + "num": n_results + }); + + if stype == "scholar" { + payload.as_object_mut().unwrap().remove("num"); + } + + let client = Client::new(); + let api_key = env::var("SERPER_API_KEY").expect("SERPER_API_KEY must be set"); + let response = client + .post(&url) + .header("X-API-KEY", api_key) + .header("Content-Type", "application/json") + .json(&payload) + .send() + .await? + .json::() + .await?; + + let results = response["organic"] + .as_array() + .ok_or("Invalid response format")?; + let formatted_results = match stype { + "search" => results + .iter() + .take(n_results as usize) + .map(|r| SearchResult::from_result_data(r).to_formatted_string()) + .collect::>(), + "scholar" => results + .iter() + .take(n_results as usize) + .map(|r| ScholarResult::from_result_data(r).to_formatted_string()) + .collect::>(), + "news" => results + .iter() + .take(n_results as usize) + .map(|r| NewsResult::from_result_data(r).to_formatted_string()) + .collect::>(), + _ => { + return Err(Box::new(std::io::Error::new( + std::io::ErrorKind::InvalidInput, + "Invalid search type", + ))) + } + }; + + Ok(formatted_results.join("\n")) + } +} diff --git a/tests/function_call.rs b/tests/function_call.rs index 034139c..7f965fc 100644 --- a/tests/function_call.rs +++ b/tests/function_call.rs @@ -3,7 +3,9 @@ use ollama_rs::{ generation::chat::ChatMessage, generation::functions::tools::{DDGSearcher, Scraper, StockScraper}, - generation::functions::{FunctionCallRequest, NousFunctionCall}, + generation::functions::{ + FunctionCallRequest, LlamaFunctionCall, NousFunctionCall, OpenAIFunctionCall, + }, Ollama, }; use std::sync::Arc; @@ -13,6 +15,7 @@ async fn test_send_function_call() { /// Model to be used, make sure it is tailored towards "function calling", such as: /// - OpenAIFunctionCall: not model specific, degraded performance /// - NousFunctionCall: adrienbrault/nous-hermes2pro:Q8_0 + /// - LlamaFunctionCall: llama3.1:latest const MODEL: &str = "adrienbrault/nous-hermes2pro:Q8_0"; const PROMPT: &str = "Aside from the Apple Remote, what other device can control the program Apple Remote was originally designed to interact with?"; @@ -43,6 +46,7 @@ async fn test_send_function_call_with_history() { /// Model to be used, make sure it is tailored towards "function calling", such as: /// - OpenAIFunctionCall: not model specific, degraded performance /// - NousFunctionCall: adrienbrault/nous-hermes2pro:Q8_0 + /// - LlamaFunctionCall: llama3.1:latest const MODEL: &str = "adrienbrault/nous-hermes2pro:Q8_0"; const PROMPT: &str = "Aside from the Apple Remote, what other device can control the program Apple Remote was 
From f793366697c69af83c6b9bbf35aca0d2bb94d381 Mon Sep 17 00:00:00 2001
From: andthattoo
Date: Tue, 30 Jul 2024 12:24:35 +0300
Subject: [PATCH 3/9] get_chat_messages_by_id had a borrow issue, fixed it

---
 src/generation/chat/mod.rs | 13 +++++++------
 1 file changed, 7 insertions(+), 6 deletions(-)

diff --git a/src/generation/chat/mod.rs b/src/generation/chat/mod.rs
index 4bf0239..9b399be 100644
--- a/src/generation/chat/mod.rs
+++ b/src/generation/chat/mod.rs
@@ -196,14 +196,15 @@ impl Ollama {
     /// Without impact for existing history
     /// Used to prepare history for request
     fn get_chat_messages_by_id(&mut self, history_id: impl ToString) -> Vec<ChatMessage> {
+        let mut binding = {
+            let new_history =
+                std::sync::Arc::new(std::sync::RwLock::new(MessagesHistory::default()));
+            self.messages_history = Some(new_history);
+            self.messages_history.clone().unwrap()
+            };
         let chat_history = match self.messages_history.as_mut() {
             Some(history) => history,
-            None => &mut {
-                let new_history =
-                    std::sync::Arc::new(std::sync::RwLock::new(MessagesHistory::default()));
-                self.messages_history = Some(new_history);
-                self.messages_history.clone().unwrap()
-            },
+            None => &mut binding,
         };
         // Clone the current chat messages to avoid borrowing issues
         // And not to add message to the history if the request fails
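The rejected pattern was `None => &mut { ... }`: the arm borrowed a temporary block result while the block itself also mutated `self.messages_history` during an active borrow. Hoisting the fallback into a named binding that outlives the match resolves this. A stripped-down sketch of the hoisting pattern, with hypothetical types rather than the ollama-rs API:

```rust
// Hypothetical reduction of the pattern fixed above.
fn push_message(slot: &mut Option<Vec<String>>, message: String) {
    // The fallback lives in a binding that outlives the match, so the arm
    // can hand out `&mut binding` instead of borrowing a temporary block.
    let mut binding = Vec::new();
    let history = match slot.as_mut() {
        Some(messages) => messages,
        None => &mut binding,
    };
    history.push(message);
}

fn main() {
    let mut slot = Some(vec!["hello".to_string()]);
    push_message(&mut slot, "world".to_string());
    assert_eq!(slot.unwrap().len(), 2);
}
```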
From dcdfc07a0974f414b3a86c9aed17d43f1bba7bd1 Mon Sep 17 00:00:00 2001
From: andthattoo
Date: Tue, 30 Jul 2024 12:25:05 +0300
Subject: [PATCH 4/9] fmt fix

---
 src/generation/chat/mod.rs | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/generation/chat/mod.rs b/src/generation/chat/mod.rs
index 9b399be..7fdceb0 100644
--- a/src/generation/chat/mod.rs
+++ b/src/generation/chat/mod.rs
@@ -201,7 +201,7 @@ impl Ollama {
                 std::sync::Arc::new(std::sync::RwLock::new(MessagesHistory::default()));
             self.messages_history = Some(new_history);
             self.messages_history.clone().unwrap()
-            };
+        };
         let chat_history = match self.messages_history.as_mut() {
             Some(history) => history,
             None => &mut binding,

From b994be08ea8a9b94e6b0f8da02c4c8c72fa9dc09 Mon Sep 17 00:00:00 2001
From: pepperoni21
Date: Sat, 3 Aug 2024 14:47:50 +0200
Subject: [PATCH 5/9] Updated embeddings generation to /api/embed endpoint

---
 src/generation/embeddings.rs | 16 +++++++++-------
 1 file changed, 9 insertions(+), 7 deletions(-)

diff --git a/src/generation/embeddings.rs b/src/generation/embeddings.rs
index f06eeb9..b6e1e3b 100644
--- a/src/generation/embeddings.rs
+++ b/src/generation/embeddings.rs
@@ -2,7 +2,7 @@ use serde::{Deserialize, Serialize};
 
 use crate::Ollama;
 
-use super::options::GenerationOptions;
+use super::{options::GenerationOptions, parameters::KeepAlive};
 
 impl Ollama {
     /// Generate embeddings from a model
@@ -16,11 +16,12 @@ impl Ollama {
     ) -> crate::error::Result<GenerateEmbeddingsResponse> {
         let request = GenerateEmbeddingsRequest {
             model_name,
-            prompt,
+            input: prompt,
             options,
+            ..Default::default()
         };
 
-        let url = format!("{}api/embeddings", self.url_str());
+        let url = format!("{}api/embed", self.url_str());
         let serialized = serde_json::to_string(&request).map_err(|e| e.to_string())?;
         let res = self
             .reqwest_client
@@ -43,18 +44,19 @@ impl Ollama {
 }
 
 /// An embeddings generation request to Ollama.
-#[derive(Debug, Serialize)]
+#[derive(Debug, Serialize, Default)]
 struct GenerateEmbeddingsRequest {
     #[serde(rename = "model")]
     model_name: String,
-    prompt: String,
+    input: String,
+    truncate: Option<bool>,
     options: Option<GenerationOptions>,
+    keep_alive: Option<KeepAlive>,
 }
 
 /// An embeddings generation response from Ollama.
 #[derive(Debug, Deserialize, Clone)]
 pub struct GenerateEmbeddingsResponse {
-    #[serde(rename = "embedding")]
     #[allow(dead_code)]
-    pub embeddings: Vec<f64>,
+    pub embeddings: Vec<Vec<f64>>,
 }

From 9c7cef5ac8c16585e9c4d7ba6eefe8849c0bde48 Mon Sep 17 00:00:00 2001
From: pepperoni21
Date: Sat, 3 Aug 2024 15:03:39 +0200
Subject: [PATCH 6/9] Separated request into the request module and added
 support for multiple inputs

---
 .../{embeddings.rs => embeddings/mod.rs}      | 28 ++-----
 src/generation/embeddings/request.rs          | 78 +++++++++++++++++++
 tests/embeddings_generation.rs                |  9 ++-
 3 files changed, 88 insertions(+), 27 deletions(-)
 rename src/generation/{embeddings.rs => embeddings/mod.rs} (63%)
 create mode 100644 src/generation/embeddings/request.rs

diff --git a/src/generation/embeddings.rs b/src/generation/embeddings/mod.rs
similarity index 63%
rename from src/generation/embeddings.rs
rename to src/generation/embeddings/mod.rs
index b6e1e3b..6c10b1a 100644
--- a/src/generation/embeddings.rs
+++ b/src/generation/embeddings/mod.rs
@@ -1,8 +1,10 @@
-use serde::{Deserialize, Serialize};
+use serde::Deserialize;
 
 use crate::Ollama;
 
-use super::{options::GenerationOptions, parameters::KeepAlive};
+use self::request::GenerateEmbeddingsRequest;
+
+pub mod request;
 
 impl Ollama {
     /// Generate embeddings from a model
@@ -10,17 +12,8 @@ impl Ollama {
     /// * `prompt` - Prompt to generate embeddings for
     pub async fn generate_embeddings(
         &self,
-        model_name: String,
-        prompt: String,
-        options: Option<GenerationOptions>,
+        request: GenerateEmbeddingsRequest,
     ) -> crate::error::Result<GenerateEmbeddingsResponse> {
-        let request = GenerateEmbeddingsRequest {
-            model_name,
-            input: prompt,
-            options,
-            ..Default::default()
-        };
-
         let url = format!("{}api/embed", self.url_str());
         let serialized = serde_json::to_string(&request).map_err(|e| e.to_string())?;
         let res = self
@@ -43,17 +36,6 @@ impl Ollama {
     }
 }
 
-/// An embeddings generation request to Ollama.
-#[derive(Debug, Serialize, Default)]
-struct GenerateEmbeddingsRequest {
-    #[serde(rename = "model")]
-    model_name: String,
-    input: String,
-    truncate: Option<bool>,
-    options: Option<GenerationOptions>,
-    keep_alive: Option<KeepAlive>,
-}
-
 /// An embeddings generation response from Ollama.
 #[derive(Debug, Deserialize, Clone)]
 pub struct GenerateEmbeddingsResponse {
diff --git a/src/generation/embeddings/request.rs b/src/generation/embeddings/request.rs
new file mode 100644
index 0000000..1de4f38
--- /dev/null
+++ b/src/generation/embeddings/request.rs
@@ -0,0 +1,78 @@
+use serde::{Serialize, Serializer};
+
+use crate::generation::{options::GenerationOptions, parameters::KeepAlive};
+
+#[derive(Debug)]
+pub enum EmbeddingsInput {
+    Single(String),
+    Multiple(Vec<String>),
+}
+
+impl Default for EmbeddingsInput {
+    fn default() -> Self {
+        Self::Single(String::default())
+    }
+}
+
+impl From<String> for EmbeddingsInput {
+    fn from(s: String) -> Self {
+        Self::Single(s)
+    }
+}
+
+impl From<&str> for EmbeddingsInput {
+    fn from(s: &str) -> Self {
+        Self::Single(s.to_string())
+    }
+}
+
+impl From<Vec<String>> for EmbeddingsInput {
+    fn from(v: Vec<String>) -> Self {
+        Self::Multiple(v)
+    }
+}
+
+impl Serialize for EmbeddingsInput {
+    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
+        match self {
+            EmbeddingsInput::Single(s) => s.serialize(serializer),
+            EmbeddingsInput::Multiple(v) => v.serialize(serializer),
+        }
+    }
+}
+
+/// An embeddings generation request to Ollama.
+#[derive(Debug, Serialize, Default)]
+pub struct GenerateEmbeddingsRequest {
+    #[serde(rename = "model")]
+    model_name: String,
+    input: EmbeddingsInput,
+    truncate: Option<bool>,
+    options: Option<GenerationOptions>,
+    keep_alive: Option<KeepAlive>,
+}
+
+impl GenerateEmbeddingsRequest {
+    pub fn new(model_name: String, input: EmbeddingsInput) -> Self {
+        Self {
+            model_name,
+            input,
+            ..Default::default()
+        }
+    }
+
+    pub fn options(mut self, options: GenerationOptions) -> Self {
+        self.options = Some(options);
+        self
+    }
+
+    pub fn keep_alive(mut self, keep_alive: KeepAlive) -> Self {
+        self.keep_alive = Some(keep_alive);
+        self
+    }
+
+    pub fn truncate(mut self, truncate: bool) -> Self {
+        self.truncate = Some(truncate);
+        self
+    }
+}
diff --git a/tests/embeddings_generation.rs b/tests/embeddings_generation.rs
index d71546a..242e7ee 100644
--- a/tests/embeddings_generation.rs
+++ b/tests/embeddings_generation.rs
@@ -1,13 +1,14 @@
-use ollama_rs::Ollama;
+use ollama_rs::{generation::embeddings::request::GenerateEmbeddingsRequest, Ollama};
 
 #[tokio::test]
 async fn test_embeddings_generation() {
     let ollama = Ollama::default();
 
-    let prompt = "Why is the sky blue?".to_string();
-
     let res = ollama
-        .generate_embeddings("llama2:latest".to_string(), prompt, None)
+        .generate_embeddings(GenerateEmbeddingsRequest::new(
+            "llama2:latest".to_string(),
+            "Why is the sky blue".into(),
+        ))
         .await
         .unwrap();

From 3443e963797c4f996e63e9608797e36da8e9695d Mon Sep 17 00:00:00 2001
From: pepperoni21
Date: Sat, 3 Aug 2024 15:15:51 +0200
Subject: [PATCH 7/9] Fixed batch embeddings and added test

---
 src/generation/embeddings/request.rs |  6 ++++++
 tests/embeddings_generation.rs       | 15 +++++++++++++++
 2 files changed, 21 insertions(+)

diff --git a/src/generation/embeddings/request.rs b/src/generation/embeddings/request.rs
index 1de4f38..e44a0c5 100644
--- a/src/generation/embeddings/request.rs
+++ b/src/generation/embeddings/request.rs
@@ -32,6 +32,12 @@ impl From<Vec<String>> for EmbeddingsInput {
     }
 }
 
+impl From<Vec<&str>> for EmbeddingsInput {
+    fn from(v: Vec<&str>) -> Self {
+        Self::Multiple(v.iter().map(|s| s.to_string()).collect())
+    }
+}
+
 impl Serialize for EmbeddingsInput {
     fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
         match self {
diff --git a/tests/embeddings_generation.rs b/tests/embeddings_generation.rs
index 242e7ee..ffe7b62 100644
--- a/tests/embeddings_generation.rs
+++ b/tests/embeddings_generation.rs
@@ -14,3 +14,18 @@ async fn test_embeddings_generation() {
 
     dbg!(res);
 }
+
+#[tokio::test]
+async fn test_batch_embeddings_generation() {
+    let ollama = Ollama::default();
+
+    let res = ollama
+        .generate_embeddings(GenerateEmbeddingsRequest::new(
+            "llama2:latest".to_string(),
+            vec!["Why is the sky blue?", "Why is the sky red?"].into(),
+        ))
+        .await
+        .unwrap();
+
+    dbg!(res);
+}
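After patches 5-7, single and batch embeddings both go through the one request type. A sketch combining the two tests above into a single program (assuming a local Ollama with `llama2:latest` pulled):

```rust
use ollama_rs::{
    generation::embeddings::request::{EmbeddingsInput, GenerateEmbeddingsRequest},
    Ollama,
};

#[tokio::main]
async fn main() {
    let ollama = Ollama::default();

    // Single input: From<&str> builds EmbeddingsInput::Single.
    let single = ollama
        .generate_embeddings(GenerateEmbeddingsRequest::new(
            "llama2:latest".to_string(),
            "Why is the sky blue?".into(),
        ))
        .await
        .unwrap();
    println!("1 input -> {} embedding vector(s)", single.embeddings.len());

    // Batch input: /api/embed returns one vector per input; truncate(true)
    // asks the server to clip inputs that exceed the context window.
    let batch = ollama
        .generate_embeddings(
            GenerateEmbeddingsRequest::new(
                "llama2:latest".to_string(),
                EmbeddingsInput::Multiple(vec![
                    "Why is the sky blue?".to_string(),
                    "Why is the sky red?".to_string(),
                ]),
            )
            .truncate(true),
        )
        .await
        .unwrap();
    println!("2 inputs -> {} embedding vector(s)", batch.embeddings.len());
}
```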
&str = "https://images.pexels.com/photos/1054655/pexels-photo-1054655.jpeg"; +const PROMPT: &str = "Describe this image"; + +fn main() { + let rt = Runtime::new().unwrap(); + rt.block_on(async { + // Download the image and encode it to base64 + let bytes = match download_image(IMAGE_URL).await { + Ok(b) => b, + Err(e) => { + eprintln!("Failed to download image: {}", e); + return; + }, + }; + let base64_image = base64::engine::general_purpose::STANDARD.encode(&bytes); + + // Create an Image struct from the base64 string + let image = Image::from_base64(&base64_image); + + // Create a GenerationRequest with the model and prompt, adding the image + let request = GenerationRequest::new("llava:latest".to_string(), PROMPT.to_string()) + .add_image(image); + + // Send the request to the model and get the response + let response = match send_request(request).await { + Ok(r) => r, + Err(e) => { + eprintln!("Failed to get response: {}", e); + return; + }, + }; + + // Print the response + println!("{}", response.response); + }); +} + +// Function to download the image +async fn download_image(url: &str) -> Result, reqwest::Error> { + let response = get(url).await?; + let bytes = response.bytes().await?; + Ok(bytes.to_vec()) +} + +// Function to send the request to the model +async fn send_request(request: GenerationRequest) -> Result> { + let ollama = Ollama::default(); + let response = ollama.generate(request).await?; + Ok(response) +} From ff46963b7ec22f52724d0fe69197f66b54346711 Mon Sep 17 00:00:00 2001 From: pepperoni21 Date: Sat, 17 Aug 2024 14:03:26 +0200 Subject: [PATCH 9/9] Fixed formatting --- examples/images_to_ollama.rs | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/examples/images_to_ollama.rs b/examples/images_to_ollama.rs index 2a1f542..2b9203f 100644 --- a/examples/images_to_ollama.rs +++ b/examples/images_to_ollama.rs @@ -21,7 +21,7 @@ fn main() { Err(e) => { eprintln!("Failed to download image: {}", e); return; - }, + } }; let base64_image = base64::engine::general_purpose::STANDARD.encode(&bytes); @@ -29,8 +29,8 @@ fn main() { let image = Image::from_base64(&base64_image); // Create a GenerationRequest with the model and prompt, adding the image - let request = GenerationRequest::new("llava:latest".to_string(), PROMPT.to_string()) - .add_image(image); + let request = + GenerationRequest::new("llava:latest".to_string(), PROMPT.to_string()).add_image(image); // Send the request to the model and get the response let response = match send_request(request).await { @@ -38,7 +38,7 @@ fn main() { Err(e) => { eprintln!("Failed to get response: {}", e); return; - }, + } }; // Print the response @@ -54,7 +54,9 @@ async fn download_image(url: &str) -> Result, reqwest::Error> { } // Function to send the request to the model -async fn send_request(request: GenerationRequest) -> Result> { +async fn send_request( + request: GenerationRequest, +) -> Result> { let ollama = Ollama::default(); let response = ollama.generate(request).await?; Ok(response)