Llama3.1 Function Calling + New Tools #59

Merged · 5 commits · Aug 3, 2024
1 change: 1 addition & 0 deletions .gitignore
@@ -1,3 +1,4 @@
/target
.vscode/settings.json
shell.nix
+.idea
2 changes: 1 addition & 1 deletion README.md
@@ -194,7 +194,7 @@ _Returns a `GenerateEmbeddingsResponse` struct containing the embeddings (a vect
### Make a function call

```rust
-let tools = vec![Arc::new(Scraper::new())];
+let tools = vec![Arc::new(Scraper::new()), Arc::new(DDGSearcher::new())];
let parser = Arc::new(NousFunctionCall::new());
let message = ChatMessage::user("What is the current oil price?".to_string());
let res = ollama.send_function_call(
13 changes: 7 additions & 6 deletions src/generation/chat/mod.rs
@@ -196,14 +196,15 @@ impl Ollama {
/// Without impact for existing history
/// Used to prepare history for request
    fn get_chat_messages_by_id(&mut self, history_id: impl ToString) -> Vec<ChatMessage> {
+        let mut binding = {
+            let new_history =
+                std::sync::Arc::new(std::sync::RwLock::new(MessagesHistory::default()));
+            self.messages_history = Some(new_history);
+            self.messages_history.clone().unwrap()
+        };
        let chat_history = match self.messages_history.as_mut() {
            Some(history) => history,
-            None => &mut {
-                let new_history =
-                    std::sync::Arc::new(std::sync::RwLock::new(MessagesHistory::default()));
-                self.messages_history = Some(new_history);
-                self.messages_history.clone().unwrap()
-            },
+            None => &mut binding,
        };
// Clone the current chat messages to avoid borrowing issues
// And not to add message to the history if the request fails
3 changes: 3 additions & 0 deletions src/generation/functions/mod.rs
@@ -2,11 +2,14 @@ pub mod pipelines;
pub mod request;
pub mod tools;

+pub use crate::generation::functions::pipelines::meta_llama::request::LlamaFunctionCall;
pub use crate::generation::functions::pipelines::nous_hermes::request::NousFunctionCall;
pub use crate::generation::functions::pipelines::openai::request::OpenAIFunctionCall;
pub use crate::generation::functions::request::FunctionCallRequest;
+pub use tools::Browserless;
pub use tools::DDGSearcher;
pub use tools::Scraper;
+pub use tools::SerperSearchTool;
pub use tools::StockScraper;

use crate::error::OllamaError;
4 changes: 4 additions & 0 deletions src/generation/functions/pipelines/meta_llama/mod.rs
@@ -0,0 +1,4 @@
pub mod prompts;
pub mod request;

pub use prompts::DEFAULT_SYSTEM_TEMPLATE;
14 changes: 14 additions & 0 deletions src/generation/functions/pipelines/meta_llama/prompts.rs
@@ -0,0 +1,14 @@
pub const DEFAULT_SYSTEM_TEMPLATE: &str = r#"
You have access to the following functions:
{tools}
If you choose to call a function ONLY reply in the following format with no prefix or suffix:

<function=example_function_name>{{\"example_name\": \"example_value\"}}</function>

Reminder:
- Function calls MUST follow the specified format, start with <function= and end with </function>
- Required parameters MUST be specified
- Only call one function at a time
- Put the entire function call reply on one line
- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls
"#;
136 changes: 136 additions & 0 deletions src/generation/functions/pipelines/meta_llama/request.rs
@@ -0,0 +1,136 @@
use crate::error::OllamaError;
use crate::generation::chat::{ChatMessage, ChatMessageResponse};
use crate::generation::functions::pipelines::meta_llama::DEFAULT_SYSTEM_TEMPLATE;
use crate::generation::functions::pipelines::RequestParserBase;
use crate::generation::functions::tools::Tool;
use async_trait::async_trait;
use regex::Regex;
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use std::collections::HashMap;
use std::sync::Arc;

pub fn convert_to_llama_tool(tool: &Arc<dyn Tool>) -> Value {
    let mut function = HashMap::new();
    function.insert("name".to_string(), Value::String(tool.name()));
    function.insert("description".to_string(), Value::String(tool.description()));
    function.insert("parameters".to_string(), tool.parameters());
    json!(format!(
        "Use the function '{name}' to '{description}': {json}",
        name = tool.name(),
        description = tool.description(),
        json = json!(function)
    ))
}

#[derive(Debug, Deserialize, Serialize)]
pub struct LlamaFunctionCallSignature {
    pub function: String, // name of the tool
    pub arguments: Value,
}

pub struct LlamaFunctionCall {}

impl LlamaFunctionCall {
    pub async fn function_call_with_history(
        &self,
        model_name: String,
        tool_params: Value,
        tool: Arc<dyn Tool>,
    ) -> Result<ChatMessageResponse, ChatMessageResponse> {
        let result = tool.run(tool_params).await;
        match result {
            Ok(result) => Ok(ChatMessageResponse {
                model: model_name.clone(),
                created_at: "".to_string(),
                message: Some(ChatMessage::assistant(result.to_string())),
                done: true,
                final_data: None,
            }),
            Err(e) => Err(self.error_handler(OllamaError::from(e))),
        }
    }

    fn clean_tool_call(&self, json_str: &str) -> String {
        json_str
            .trim()
            .trim_start_matches("```json")
            .trim_end_matches("```")
            .trim()
            .to_string()
            .replace("{{", "{")
            .replace("}}", "}")
    }

    fn parse_tool_response(&self, response: &str) -> Option<LlamaFunctionCallSignature> {
        let function_regex = Regex::new(r"<function=(\w+)>(.*?)</function>").unwrap();
        println!("Response: {}", response);
        if let Some(caps) = function_regex.captures(response) {
            let function_name = caps.get(1).unwrap().as_str().to_string();
            let args_string = caps.get(2).unwrap().as_str();

            match serde_json::from_str(args_string) {
                Ok(arguments) => Some(LlamaFunctionCallSignature {
                    function: function_name,
                    arguments,
                }),
                Err(error) => {
                    println!("Error parsing function arguments: {}", error);
                    None
                }
            }
        } else {
            None
        }
    }
}

#[async_trait]
impl RequestParserBase for LlamaFunctionCall {
    async fn parse(
        &self,
        input: &str,
        model_name: String,
        tools: Vec<Arc<dyn Tool>>,
    ) -> Result<ChatMessageResponse, ChatMessageResponse> {
        let response_value = self.parse_tool_response(&self.clean_tool_call(input));
        match response_value {
            Some(response) => {
                if let Some(tool) = tools.iter().find(|t| t.name() == response.function) {
                    let tool_params = response.arguments;
                    let result = self
                        .function_call_with_history(
                            model_name.clone(),
                            tool_params.clone(),
                            tool.clone(),
                        )
                        .await?;
                    return Ok(result);
                } else {
                    return Err(self.error_handler(OllamaError::from("Tool not found".to_string())));
                }
            }
            None => {
                return Err(self
                    .error_handler(OllamaError::from("Error parsing function call".to_string())));
            }
        }
    }

    async fn get_system_message(&self, tools: &[Arc<dyn Tool>]) -> ChatMessage {
        let tools_info: Vec<Value> = tools.iter().map(convert_to_llama_tool).collect();
        let tools_json = serde_json::to_string(&tools_info).unwrap();
        let system_message_content = DEFAULT_SYSTEM_TEMPLATE.replace("{tools}", &tools_json);
        ChatMessage::system(system_message_content)
    }

    fn error_handler(&self, error: OllamaError) -> ChatMessageResponse {
        ChatMessageResponse {
            model: "".to_string(),
            created_at: "".to_string(),
            message: Some(ChatMessage::assistant(error.to_string())),
            done: true,
            final_data: None,
        }
    }
}
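Putting the new pipeline together: `get_system_message` renders `DEFAULT_SYSTEM_TEMPLATE` with one `convert_to_llama_tool` line per tool, the model is expected to answer with a `<function=...>` block, and `parse` cleans that reply, resolves the named tool, runs it, and returns its output as an assistant message. A minimal usage sketch follows; it mirrors the README example above, but the model tag, the `tokio` runtime, and the exact `FunctionCallRequest::new` / `send_function_call` argument order (truncated in the README diff) are assumptions, not part of this PR.

```rust
use std::sync::Arc;

use ollama_rs::generation::chat::ChatMessage;
use ollama_rs::generation::functions::{
    DDGSearcher, FunctionCallRequest, LlamaFunctionCall, Scraper,
};
use ollama_rs::Ollama;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let ollama = Ollama::default();

    // Same tool set as the README example above.
    let tools = vec![Arc::new(Scraper::new()), Arc::new(DDGSearcher::new())];
    // LlamaFunctionCall is a unit struct in this PR, so it is constructed directly.
    let parser = Arc::new(LlamaFunctionCall {});
    let message = ChatMessage::user("What is the current oil price?".to_string());

    let res = ollama
        .send_function_call(
            FunctionCallRequest::new(
                "llama3.1:latest".to_string(), // illustrative model tag
                tools,
                vec![message],
            ),
            parser,
        )
        .await?;

    if let Some(message) = res.message {
        println!("{}", message.content);
    }
    Ok(())
}
```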
3 changes: 2 additions & 1 deletion src/generation/functions/pipelines/mod.rs
@@ -4,11 +4,12 @@ use crate::generation::functions::tools::Tool;
use async_trait::async_trait;
use std::sync::Arc;

+pub mod meta_llama;
pub mod nous_hermes;
pub mod openai;

#[async_trait]
-pub trait RequestParserBase {
+pub trait RequestParserBase: Send + Sync {
    async fn parse(
        &self,
        input: &str,
68 changes: 68 additions & 0 deletions src/generation/functions/tools/browserless.rs
@@ -0,0 +1,68 @@
use reqwest::Client;
use scraper::{Html, Selector};
use std::env;
use text_splitter::TextSplitter;

use crate::generation::functions::tools::Tool;
use async_trait::async_trait;
use serde_json::{json, Value};
use std::error::Error;

pub struct Browserless {}
// Add headless utilities
#[async_trait]
impl Tool for Browserless {
    fn name(&self) -> String {
        "browserless_web_scraper".to_string()
    }

    fn description(&self) -> String {
        "Scrapes text content from websites and splits it into manageable chunks.".to_string()
    }

    fn parameters(&self) -> Value {
        json!({
            "type": "object",
            "properties": {
                "website": {
                    "type": "string",
                    "description": "The URL of the website to scrape"
                }
            },
            "required": ["website"]
        })
    }

    async fn run(&self, input: Value) -> Result<String, Box<dyn Error>> {
        let website = input["website"].as_str().ok_or("Website URL is required")?;
        let browserless_token =
            env::var("BROWSERLESS_TOKEN").expect("BROWSERLESS_TOKEN must be set");
        let url = format!("http://0.0.0.0:3000/content?token={}", browserless_token);
        let payload = json!({
            "url": website
        });
        let client = Client::new();
        let response = client
            .post(&url)
            .header("cache-control", "no-cache")
            .header("content-type", "application/json")
            .json(&payload)
            .send()
            .await?;

        let response_text = response.text().await?;
        let document = Html::parse_document(&response_text);
        let selector = Selector::parse("p, h1, h2, h3, h4, h5, h6").unwrap();
        let elements: Vec<String> = document
            .select(&selector)
            .map(|el| el.text().collect::<String>())
            .collect();
        let body = elements.join(" ");

        let splitter = TextSplitter::new(1000);
        let chunks = splitter.chunks(&body);
        let sentences: Vec<String> = chunks.map(|s| s.to_string()).collect();
        let sentences = sentences.join("\n \n");
        Ok(sentences)
    }
}
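A minimal sketch of invoking the new tool directly through the `Tool` trait (not part of the PR): it assumes `BROWSERLESS_TOKEN` is set, a Browserless instance is reachable at the hard-coded `http://0.0.0.0:3000` endpoint used in `run`, a `tokio` runtime, and an illustrative target URL.

```rust
use ollama_rs::generation::functions::tools::{Browserless, Tool};
use serde_json::json;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Browserless is a unit struct, so it is constructed directly.
    let browserless = Browserless {};

    // `run` posts the URL to the Browserless /content endpoint, keeps the headings
    // and paragraphs from the returned HTML, and yields the text in ~1000-char chunks.
    let chunks = browserless
        .run(json!({ "website": "https://example.com" })) // illustrative URL
        .await?;
    println!("{}", chunks);
    Ok(())
}
```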
2 changes: 1 addition & 1 deletion src/generation/functions/tools/finance.rs
@@ -64,7 +64,7 @@ impl StockScraper {
#[async_trait]
impl Tool for StockScraper {
    fn name(&self) -> String {
-        "Stock Scraper".to_string()
+        "stock_scraper".to_string()
    }

    fn description(&self) -> String {
4 changes: 4 additions & 0 deletions src/generation/functions/tools/mod.rs
@@ -1,10 +1,14 @@
+pub mod browserless;
pub mod finance;
pub mod scraper;
pub mod search_ddg;
+pub mod serper;

+pub use self::browserless::Browserless;
pub use self::finance::StockScraper;
pub use self::scraper::Scraper;
pub use self::search_ddg::DDGSearcher;
+pub use self::serper::SerperSearchTool;

use async_trait::async_trait;
use serde_json::{json, Value};
2 changes: 1 addition & 1 deletion src/generation/functions/tools/scraper.rs
@@ -22,7 +22,7 @@ impl Scraper {
#[async_trait]
impl Tool for Scraper {
    fn name(&self) -> String {
-        "Website Scraper".to_string()
+        "website_scraper".to_string()
    }

    fn description(&self) -> String {
2 changes: 1 addition & 1 deletion src/generation/functions/tools/search_ddg.rs
@@ -88,7 +88,7 @@ impl DDGSearcher {
#[async_trait]
impl Tool for DDGSearcher {
    fn name(&self) -> String {
-        "DDG Searcher".to_string()
+        "ddg_searcher".to_string()
    }

    fn description(&self) -> String {