Merge pull request #59 from andthattoo/master
Llama3.1 Function Calling + New Tools
pepperoni21 authored Aug 3, 2024
2 parents c5f6928 + dcdfc07 commit d014cda
Showing 15 changed files with 598 additions and 12 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -1,3 +1,4 @@
/target
.vscode/settings.json
shell.nix
.idea
2 changes: 1 addition & 1 deletion README.md
@@ -194,7 +194,7 @@ _Returns a `GenerateEmbeddingsResponse` struct containing the embeddings (a vect
### Make a function call

```rust
-let tools = vec![Arc::new(Scraper::new())];
+let tools = vec![Arc::new(Scraper::new()), Arc::new(DDGSearcher::new())];
let parser = Arc::new(NousFunctionCall::new());
let message = ChatMessage::user("What is the current oil price?".to_string());
let res = ollama.send_function_call(
```
13 changes: 7 additions & 6 deletions src/generation/chat/mod.rs
@@ -196,14 +196,15 @@ impl Ollama {
/// Without impact for existing history
/// Used to prepare history for request
    fn get_chat_messages_by_id(&mut self, history_id: impl ToString) -> Vec<ChatMessage> {
+        let mut binding = {
+            let new_history =
+                std::sync::Arc::new(std::sync::RwLock::new(MessagesHistory::default()));
+            self.messages_history = Some(new_history);
+            self.messages_history.clone().unwrap()
+        };
        let chat_history = match self.messages_history.as_mut() {
            Some(history) => history,
-            None => &mut {
-                let new_history =
-                    std::sync::Arc::new(std::sync::RwLock::new(MessagesHistory::default()));
-                self.messages_history = Some(new_history);
-                self.messages_history.clone().unwrap()
-            },
+            None => &mut binding,
        };
// Clone the current chat messages to avoid borrowing issues
// And not to add message to the history if the request fails
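The `binding` variable hoisted above the `match` looks like the standard workaround for borrowing a temporary inside a match arm, where the temporary would be dropped at the end of the statement. A minimal standalone illustration, not project code:

```rust
fn main() {
    let mut stored: Option<Vec<i32>> = None;
    // Writing `None => &mut Vec::new()` below would borrow a temporary that does
    // not live long enough once `target` is used afterwards, so the fallback is
    // hoisted into a named binding first, mirroring the change above.
    let mut binding = Vec::new();
    let target: &mut Vec<i32> = match stored.as_mut() {
        Some(v) => v,
        None => &mut binding,
    };
    target.push(42);
    println!("{target:?}");
}
```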
3 changes: 3 additions & 0 deletions src/generation/functions/mod.rs
@@ -2,11 +2,14 @@ pub mod pipelines;
pub mod request;
pub mod tools;

+pub use crate::generation::functions::pipelines::meta_llama::request::LlamaFunctionCall;
pub use crate::generation::functions::pipelines::nous_hermes::request::NousFunctionCall;
pub use crate::generation::functions::pipelines::openai::request::OpenAIFunctionCall;
pub use crate::generation::functions::request::FunctionCallRequest;
+pub use tools::Browserless;
pub use tools::DDGSearcher;
pub use tools::Scraper;
+pub use tools::SerperSearchTool;
pub use tools::StockScraper;

use crate::error::OllamaError;
4 changes: 4 additions & 0 deletions src/generation/functions/pipelines/meta_llama/mod.rs
@@ -0,0 +1,4 @@
pub mod prompts;
pub mod request;

pub use prompts::DEFAULT_SYSTEM_TEMPLATE;
14 changes: 14 additions & 0 deletions src/generation/functions/pipelines/meta_llama/prompts.rs
@@ -0,0 +1,14 @@
pub const DEFAULT_SYSTEM_TEMPLATE: &str = r#"
You have access to the following functions:
{tools}
If you choose to call a function ONLY reply in the following format with no prefix or suffix:
<function=example_function_name>{{\"example_name\": \"example_value\"}}</function>
Reminder:
- Function calls MUST follow the specified format, start with <function= and end with </function>
- Required parameters MUST be specified
- Only call one function at a time
- Put the entire function call reply on one line
- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls
"#;
136 changes: 136 additions & 0 deletions src/generation/functions/pipelines/meta_llama/request.rs
@@ -0,0 +1,136 @@
use crate::error::OllamaError;
use crate::generation::chat::{ChatMessage, ChatMessageResponse};
use crate::generation::functions::pipelines::meta_llama::DEFAULT_SYSTEM_TEMPLATE;
use crate::generation::functions::pipelines::RequestParserBase;
use crate::generation::functions::tools::Tool;
use async_trait::async_trait;
use regex::Regex;
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use std::collections::HashMap;
use std::sync::Arc;

pub fn convert_to_llama_tool(tool: &Arc<dyn Tool>) -> Value {
let mut function = HashMap::new();
function.insert("name".to_string(), Value::String(tool.name()));
function.insert("description".to_string(), Value::String(tool.description()));
function.insert("parameters".to_string(), tool.parameters());
json!(format!(
"Use the function '{name}' to '{description}': {json}",
name = tool.name(),
description = tool.description(),
json = json!(function)
))
}

#[derive(Debug, Deserialize, Serialize)]
pub struct LlamaFunctionCallSignature {
pub function: String, //name of the tool
pub arguments: Value,
}

pub struct LlamaFunctionCall {}

impl LlamaFunctionCall {
pub async fn function_call_with_history(
&self,
model_name: String,
tool_params: Value,
tool: Arc<dyn Tool>,
) -> Result<ChatMessageResponse, ChatMessageResponse> {
let result = tool.run(tool_params).await;
match result {
Ok(result) => Ok(ChatMessageResponse {
model: model_name.clone(),
created_at: "".to_string(),
message: Some(ChatMessage::assistant(result.to_string())),
done: true,
final_data: None,
}),
Err(e) => Err(self.error_handler(OllamaError::from(e))),
}
}

fn clean_tool_call(&self, json_str: &str) -> String {
json_str
.trim()
.trim_start_matches("```json")
.trim_end_matches("```")
.trim()
.to_string()
.replace("{{", "{")
.replace("}}", "}")
}

fn parse_tool_response(&self, response: &str) -> Option<LlamaFunctionCallSignature> {
let function_regex = Regex::new(r"<function=(\w+)>(.*?)</function>").unwrap();
println!("Response: {}", response);
if let Some(caps) = function_regex.captures(response) {
let function_name = caps.get(1).unwrap().as_str().to_string();
let args_string = caps.get(2).unwrap().as_str();

match serde_json::from_str(args_string) {
Ok(arguments) => Some(LlamaFunctionCallSignature {
function: function_name,
arguments,
}),
Err(error) => {
println!("Error parsing function arguments: {}", error);
None
}
}
} else {
None
}
}
}

#[async_trait]
impl RequestParserBase for LlamaFunctionCall {
async fn parse(
&self,
input: &str,
model_name: String,
tools: Vec<Arc<dyn Tool>>,
) -> Result<ChatMessageResponse, ChatMessageResponse> {
let response_value = self.parse_tool_response(&self.clean_tool_call(input));
match response_value {
Some(response) => {
if let Some(tool) = tools.iter().find(|t| t.name() == response.function) {
let tool_params = response.arguments;
let result = self
.function_call_with_history(
model_name.clone(),
tool_params.clone(),
tool.clone(),
)
.await?;
return Ok(result);
} else {
return Err(self.error_handler(OllamaError::from("Tool not found".to_string())));
}
}
None => {
return Err(self
.error_handler(OllamaError::from("Error parsing function call".to_string())));
}
}
}

async fn get_system_message(&self, tools: &[Arc<dyn Tool>]) -> ChatMessage {
let tools_info: Vec<Value> = tools.iter().map(convert_to_llama_tool).collect();
let tools_json = serde_json::to_string(&tools_info).unwrap();
let system_message_content = DEFAULT_SYSTEM_TEMPLATE.replace("{tools}", &tools_json);
ChatMessage::system(system_message_content)
}

fn error_handler(&self, error: OllamaError) -> ChatMessageResponse {
ChatMessageResponse {
model: "".to_string(),
created_at: "".to_string(),
message: Some(ChatMessage::assistant(error.to_string())),
done: true,
final_data: None,
}
}
}
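A hedged sketch of how this new pipeline could be plugged into the README example above. The `FunctionCallRequest::new(model, tools, messages)` shape, the second `parser` argument to `send_function_call`, and the model tag are assumptions, not taken from this diff:

```rust
use std::sync::Arc;

use ollama_rs::generation::chat::ChatMessage;
use ollama_rs::generation::functions::tools::Tool;
use ollama_rs::generation::functions::{DDGSearcher, FunctionCallRequest, LlamaFunctionCall, Scraper};
use ollama_rs::Ollama;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let mut ollama = Ollama::default();
    let tools: Vec<Arc<dyn Tool>> = vec![Arc::new(Scraper::new()), Arc::new(DDGSearcher::new())];
    // Swap the Nous parser from the README example for the Llama 3.1 pipeline added here.
    let parser = Arc::new(LlamaFunctionCall {});
    let message = ChatMessage::user("What is the current oil price?".to_string());
    let res = ollama
        .send_function_call(
            FunctionCallRequest::new("llama3.1:8b".to_string(), tools, vec![message]),
            parser,
        )
        .await?;
    println!("{:?}", res.message);
    Ok(())
}
```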
3 changes: 2 additions & 1 deletion src/generation/functions/pipelines/mod.rs
@@ -4,11 +4,12 @@ use crate::generation::functions::tools::Tool;
use async_trait::async_trait;
use std::sync::Arc;

+pub mod meta_llama;
pub mod nous_hermes;
pub mod openai;

#[async_trait]
-pub trait RequestParserBase {
+pub trait RequestParserBase: Send + Sync {
async fn parse(
&self,
input: &str,
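The new `Send + Sync` bound matters once a parser sits behind an `Arc` that has to cross threads. A generic illustration of the mechanism, not project code, assuming a tokio runtime:

```rust
use std::sync::Arc;

// Stand-in trait: the Send + Sync supertraits are what let Arc<dyn Parser>
// move into another task, just as the bound above does for RequestParserBase.
trait Parser: Send + Sync {
    fn parse(&self, input: &str) -> String;
}

struct Upper;

impl Parser for Upper {
    fn parse(&self, input: &str) -> String {
        input.to_uppercase()
    }
}

#[tokio::main]
async fn main() {
    let parser: Arc<dyn Parser> = Arc::new(Upper);
    // Without Send + Sync on the trait, this Arc could not be moved into the task.
    let handle = tokio::spawn(async move { parser.parse("hello") });
    println!("{}", handle.await.unwrap());
}
```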
68 changes: 68 additions & 0 deletions src/generation/functions/tools/browserless.rs
@@ -0,0 +1,68 @@
use reqwest::Client;
use scraper::{Html, Selector};
use std::env;
use text_splitter::TextSplitter;

use crate::generation::functions::tools::Tool;
use async_trait::async_trait;
use serde_json::{json, Value};
use std::error::Error;

pub struct Browserless {}
// Add headless utilities
#[async_trait]
impl Tool for Browserless {
fn name(&self) -> String {
"browserless_web_scraper".to_string()
}

fn description(&self) -> String {
"Scrapes text content from websites and splits it into manageable chunks.".to_string()
}

fn parameters(&self) -> Value {
json!({
"type": "object",
"properties": {
"website": {
"type": "string",
"description": "The URL of the website to scrape"
}
},
"required": ["website"]
})
}

async fn run(&self, input: Value) -> Result<String, Box<dyn Error>> {
let website = input["website"].as_str().ok_or("Website URL is required")?;
let browserless_token =
env::var("BROWSERLESS_TOKEN").expect("BROWSERLESS_TOKEN must be set");
let url = format!("http://0.0.0.0:3000/content?token={}", browserless_token);
let payload = json!({
"url": website
});
let client = Client::new();
let response = client
.post(&url)
.header("cache-control", "no-cache")
.header("content-type", "application/json")
.json(&payload)
.send()
.await?;

let response_text = response.text().await?;
let document = Html::parse_document(&response_text);
let selector = Selector::parse("p, h1, h2, h3, h4, h5, h6").unwrap();
let elements: Vec<String> = document
.select(&selector)
.map(|el| el.text().collect::<String>())
.collect();
let body = elements.join(" ");

let splitter = TextSplitter::new(1000);
let chunks = splitter.chunks(&body);
let sentences: Vec<String> = chunks.map(|s| s.to_string()).collect();
let sentences = sentences.join("\n \n");
Ok(sentences)
}
}
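A hypothetical usage sketch for the new tool. It assumes `BROWSERLESS_TOKEN` is already exported in the environment, a browserless instance is listening on port 3000 (the URL is hard-coded above), and that the import paths follow the re-exports added in this commit:

```rust
use ollama_rs::generation::functions::tools::Tool;
use ollama_rs::generation::functions::Browserless;
use serde_json::json;

#[tokio::main]
async fn main() {
    // Requires BROWSERLESS_TOKEN in the environment and a local browserless
    // instance on port 3000; "website" is the only parameter the tool declares.
    let scraper = Browserless {};
    match scraper.run(json!({ "website": "https://example.com" })).await {
        Ok(chunks) => println!("{chunks}"),
        Err(err) => eprintln!("scrape failed: {err}"),
    }
}
```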
2 changes: 1 addition & 1 deletion src/generation/functions/tools/finance.rs
@@ -64,7 +64,7 @@ impl StockScraper {
#[async_trait]
impl Tool for StockScraper {
fn name(&self) -> String {
"Stock Scraper".to_string()
"stock_scraper".to_string()
}

fn description(&self) -> String {
4 changes: 4 additions & 0 deletions src/generation/functions/tools/mod.rs
@@ -1,10 +1,14 @@
+pub mod browserless;
pub mod finance;
pub mod scraper;
pub mod search_ddg;
+pub mod serper;

+pub use self::browserless::Browserless;
pub use self::finance::StockScraper;
pub use self::scraper::Scraper;
pub use self::search_ddg::DDGSearcher;
+pub use self::serper::SerperSearchTool;

use async_trait::async_trait;
use serde_json::{json, Value};
2 changes: 1 addition & 1 deletion src/generation/functions/tools/scraper.rs
@@ -22,7 +22,7 @@ impl Scraper {
#[async_trait]
impl Tool for Scraper {
fn name(&self) -> String {
"Website Scraper".to_string()
"website_scraper".to_string()
}

fn description(&self) -> String {
2 changes: 1 addition & 1 deletion src/generation/functions/tools/search_ddg.rs
@@ -88,7 +88,7 @@ impl DDGSearcher {
#[async_trait]
impl Tool for DDGSearcher {
fn name(&self) -> String {
"DDG Searcher".to_string()
"ddg_searcher".to_string()
}

fn description(&self) -> String {
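One practical effect of the snake_case renames in this file, scraper.rs, and finance.rs: the Llama pipeline's `<function=(\w+)>` pattern only captures word characters, so a tool name containing spaces could never be matched back to its tool. A small illustrative check:

```rust
use regex::Regex;

fn main() {
    let re = Regex::new(r"<function=(\w+)>(.*?)</function>").unwrap();
    // The renamed snake_case name round-trips through the function-call format...
    assert!(re.captures(r#"<function=ddg_searcher>{"query": "oil"}</function>"#).is_some());
    // ...while the old space-separated name can never match.
    assert!(re.captures(r#"<function=DDG Searcher>{"query": "oil"}</function>"#).is_none());
    println!("ok");
}
```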