-
Notifications
You must be signed in to change notification settings - Fork 89
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #59 from andthattoo/master
Llama3.1 Function Calling + New Tools
- Loading branch information
Showing
15 changed files
with
598 additions
and
12 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
/target | ||
.vscode/settings.json | ||
shell.nix | ||
.idea |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
pub mod prompts; | ||
pub mod request; | ||
|
||
pub use prompts::DEFAULT_SYSTEM_TEMPLATE; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
/// System prompt template for Llama 3.1 style function calling.
///
/// The `{tools}` placeholder is filled in with plain string replacement (not
/// `format!`), so every other character in this raw string reaches the model
/// verbatim. The example call therefore uses single braces and unescaped
/// quotes: doubled braces or backslash-escaped quotes would be shown to the
/// model literally and would teach it to emit arguments that are not valid JSON.
pub const DEFAULT_SYSTEM_TEMPLATE: &str = r#"
You have access to the following functions:
{tools}
If you choose to call a function ONLY reply in the following format with no prefix or suffix:
<function=example_function_name>{"example_name": "example_value"}</function>
Reminder:
- Function calls MUST follow the specified format, start with <function= and end with </function>
- Required parameters MUST be specified
- Only call one function at a time
- Put the entire function call reply on one line
- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls
"#;
136 changes: 136 additions & 0 deletions
136
src/generation/functions/pipelines/meta_llama/request.rs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,136 @@ | ||
use crate::error::OllamaError; | ||
use crate::generation::chat::{ChatMessage, ChatMessageResponse}; | ||
use crate::generation::functions::pipelines::meta_llama::DEFAULT_SYSTEM_TEMPLATE; | ||
use crate::generation::functions::pipelines::RequestParserBase; | ||
use crate::generation::functions::tools::Tool; | ||
use async_trait::async_trait; | ||
use regex::Regex; | ||
use serde::{Deserialize, Serialize}; | ||
use serde_json::{json, Value}; | ||
use std::collections::HashMap; | ||
use std::sync::Arc; | ||
|
||
pub fn convert_to_llama_tool(tool: &Arc<dyn Tool>) -> Value { | ||
let mut function = HashMap::new(); | ||
function.insert("name".to_string(), Value::String(tool.name())); | ||
function.insert("description".to_string(), Value::String(tool.description())); | ||
function.insert("parameters".to_string(), tool.parameters()); | ||
json!(format!( | ||
"Use the function '{name}' to '{description}': {json}", | ||
name = tool.name(), | ||
description = tool.description(), | ||
json = json!(function) | ||
)) | ||
} | ||
|
||
#[derive(Debug, Deserialize, Serialize)] | ||
pub struct LlamaFunctionCallSignature { | ||
pub function: String, //name of the tool | ||
pub arguments: Value, | ||
} | ||
|
||
/// Stateless request parser for Llama-style `<function=...>` tool calls.
pub struct LlamaFunctionCall {}
|
||
impl LlamaFunctionCall { | ||
pub async fn function_call_with_history( | ||
&self, | ||
model_name: String, | ||
tool_params: Value, | ||
tool: Arc<dyn Tool>, | ||
) -> Result<ChatMessageResponse, ChatMessageResponse> { | ||
let result = tool.run(tool_params).await; | ||
match result { | ||
Ok(result) => Ok(ChatMessageResponse { | ||
model: model_name.clone(), | ||
created_at: "".to_string(), | ||
message: Some(ChatMessage::assistant(result.to_string())), | ||
done: true, | ||
final_data: None, | ||
}), | ||
Err(e) => Err(self.error_handler(OllamaError::from(e))), | ||
} | ||
} | ||
|
||
fn clean_tool_call(&self, json_str: &str) -> String { | ||
json_str | ||
.trim() | ||
.trim_start_matches("```json") | ||
.trim_end_matches("```") | ||
.trim() | ||
.to_string() | ||
.replace("{{", "{") | ||
.replace("}}", "}") | ||
} | ||
|
||
fn parse_tool_response(&self, response: &str) -> Option<LlamaFunctionCallSignature> { | ||
let function_regex = Regex::new(r"<function=(\w+)>(.*?)</function>").unwrap(); | ||
println!("Response: {}", response); | ||
if let Some(caps) = function_regex.captures(response) { | ||
let function_name = caps.get(1).unwrap().as_str().to_string(); | ||
let args_string = caps.get(2).unwrap().as_str(); | ||
|
||
match serde_json::from_str(args_string) { | ||
Ok(arguments) => Some(LlamaFunctionCallSignature { | ||
function: function_name, | ||
arguments, | ||
}), | ||
Err(error) => { | ||
println!("Error parsing function arguments: {}", error); | ||
None | ||
} | ||
} | ||
} else { | ||
None | ||
} | ||
} | ||
} | ||
|
||
#[async_trait] | ||
impl RequestParserBase for LlamaFunctionCall { | ||
async fn parse( | ||
&self, | ||
input: &str, | ||
model_name: String, | ||
tools: Vec<Arc<dyn Tool>>, | ||
) -> Result<ChatMessageResponse, ChatMessageResponse> { | ||
let response_value = self.parse_tool_response(&self.clean_tool_call(input)); | ||
match response_value { | ||
Some(response) => { | ||
if let Some(tool) = tools.iter().find(|t| t.name() == response.function) { | ||
let tool_params = response.arguments; | ||
let result = self | ||
.function_call_with_history( | ||
model_name.clone(), | ||
tool_params.clone(), | ||
tool.clone(), | ||
) | ||
.await?; | ||
return Ok(result); | ||
} else { | ||
return Err(self.error_handler(OllamaError::from("Tool not found".to_string()))); | ||
} | ||
} | ||
None => { | ||
return Err(self | ||
.error_handler(OllamaError::from("Error parsing function call".to_string()))); | ||
} | ||
} | ||
} | ||
|
||
async fn get_system_message(&self, tools: &[Arc<dyn Tool>]) -> ChatMessage { | ||
let tools_info: Vec<Value> = tools.iter().map(convert_to_llama_tool).collect(); | ||
let tools_json = serde_json::to_string(&tools_info).unwrap(); | ||
let system_message_content = DEFAULT_SYSTEM_TEMPLATE.replace("{tools}", &tools_json); | ||
ChatMessage::system(system_message_content) | ||
} | ||
|
||
fn error_handler(&self, error: OllamaError) -> ChatMessageResponse { | ||
ChatMessageResponse { | ||
model: "".to_string(), | ||
created_at: "".to_string(), | ||
message: Some(ChatMessage::assistant(error.to_string())), | ||
done: true, | ||
final_data: None, | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
use reqwest::Client; | ||
use scraper::{Html, Selector}; | ||
use std::env; | ||
use text_splitter::TextSplitter; | ||
|
||
use crate::generation::functions::tools::Tool; | ||
use async_trait::async_trait; | ||
use serde_json::{json, Value}; | ||
use std::error::Error; | ||
|
||
/// Web-scraping tool that fetches page content through a Browserless service.
pub struct Browserless {}
// Add headless utilities
#[async_trait] | ||
impl Tool for Browserless { | ||
fn name(&self) -> String { | ||
"browserless_web_scraper".to_string() | ||
} | ||
|
||
fn description(&self) -> String { | ||
"Scrapes text content from websites and splits it into manageable chunks.".to_string() | ||
} | ||
|
||
fn parameters(&self) -> Value { | ||
json!({ | ||
"type": "object", | ||
"properties": { | ||
"website": { | ||
"type": "string", | ||
"description": "The URL of the website to scrape" | ||
} | ||
}, | ||
"required": ["website"] | ||
}) | ||
} | ||
|
||
async fn run(&self, input: Value) -> Result<String, Box<dyn Error>> { | ||
let website = input["website"].as_str().ok_or("Website URL is required")?; | ||
let browserless_token = | ||
env::var("BROWSERLESS_TOKEN").expect("BROWSERLESS_TOKEN must be set"); | ||
let url = format!("http://0.0.0.0:3000/content?token={}", browserless_token); | ||
let payload = json!({ | ||
"url": website | ||
}); | ||
let client = Client::new(); | ||
let response = client | ||
.post(&url) | ||
.header("cache-control", "no-cache") | ||
.header("content-type", "application/json") | ||
.json(&payload) | ||
.send() | ||
.await?; | ||
|
||
let response_text = response.text().await?; | ||
let document = Html::parse_document(&response_text); | ||
let selector = Selector::parse("p, h1, h2, h3, h4, h5, h6").unwrap(); | ||
let elements: Vec<String> = document | ||
.select(&selector) | ||
.map(|el| el.text().collect::<String>()) | ||
.collect(); | ||
let body = elements.join(" "); | ||
|
||
let splitter = TextSplitter::new(1000); | ||
let chunks = splitter.chunks(&body); | ||
let sentences: Vec<String> = chunks.map(|s| s.to_string()).collect(); | ||
let sentences = sentences.join("\n \n"); | ||
Ok(sentences) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.