diff --git a/.gitignore b/.gitignore
index 3e9f11a..61dfbb6 100644
--- a/.gitignore
+++ b/.gitignore
@@ -2,3 +2,4 @@
 .vscode/settings.json
 shell.nix
 .idea
+.DS_Store
\ No newline at end of file
diff --git a/src/generation/functions/mod.rs b/src/generation/functions/mod.rs
index 2e761c2..52d83b9 100644
--- a/src/generation/functions/mod.rs
+++ b/src/generation/functions/mod.rs
@@ -59,6 +59,10 @@ impl crate::Ollama {
             )
             .await?;
 
+        if request.raw_mode {
+            return Ok(tool_call_result);
+        }
+
         let tool_call_content: String = tool_call_result.message.clone().unwrap().content;
         let result = parser
             .parse(
@@ -96,8 +100,12 @@ impl crate::Ollama {
             request.chat.messages.insert(0, system_prompt);
         }
         let result = self.send_chat_messages(request.chat).await?;
-        let response_content: String = result.message.clone().unwrap().content;
 
+        if request.raw_mode {
+            return Ok(result);
+        }
+
+        let response_content: String = result.message.clone().unwrap().content;
         let result = parser
             .parse(&response_content, model_name, request.tools)
             .await;
diff --git a/src/generation/functions/pipelines/meta_llama/request.rs b/src/generation/functions/pipelines/meta_llama/request.rs
index 8ed2ebd..cd841b4 100644
--- a/src/generation/functions/pipelines/meta_llama/request.rs
+++ b/src/generation/functions/pipelines/meta_llama/request.rs
@@ -62,26 +62,28 @@ impl LlamaFunctionCall {
             .replace("}}", "}")
     }
 
-    fn parse_tool_response(&self, response: &str) -> Option<LlamaFunctionCallSignature> {
+    fn parse_tool_response(&self, response: &str) -> Vec<LlamaFunctionCallSignature> {
         let function_regex = Regex::new(r"<function=(\w+)>(.*?)</function>").unwrap();
         println!("Response: {}", response);
-        if let Some(caps) = function_regex.captures(response) {
-            let function_name = caps.get(1).unwrap().as_str().to_string();
-            let args_string = caps.get(2).unwrap().as_str();
-
-            match serde_json::from_str(args_string) {
-                Ok(arguments) => Some(LlamaFunctionCallSignature {
-                    function: function_name,
-                    arguments,
-                }),
-                Err(error) => {
-                    println!("Error parsing function arguments: {}", error);
-                    None
+
+        function_regex
+            .captures_iter(response)
+            .filter_map(|caps| {
+                let function_name = caps.get(1)?.as_str().to_string();
+                let args_string = caps.get(2)?.as_str();
+
+                match serde_json::from_str(args_string) {
+                    Ok(arguments) => Some(LlamaFunctionCallSignature {
+                        function: function_name,
+                        arguments,
+                    }),
+                    Err(error) => {
+                        println!("Error parsing function arguments: {}", error);
+                        None
+                    }
                 }
-            }
-        } else {
-            None
-        }
+            })
+            .collect()
     }
 }
 
@@ -93,28 +95,47 @@ impl RequestParserBase for LlamaFunctionCall {
         model_name: String,
         tools: Vec<Arc<dyn Tool>>,
     ) -> Result<ChatMessageResponse, OllamaError> {
-        let response_value = self.parse_tool_response(&self.clean_tool_call(input));
-        match response_value {
-            Some(response) => {
-                if let Some(tool) = tools.iter().find(|t| t.name() == response.function) {
-                    let tool_params = response.arguments;
-                    let result = self
-                        .function_call_with_history(
-                            model_name.clone(),
-                            tool_params.clone(),
-                            tool.clone(),
-                        )
-                        .await?;
-                    return Ok(result);
-                } else {
-                    return Err(self.error_handler(OllamaError::from("Tool not found".to_string())));
+        let function_calls = self.parse_tool_response(&self.clean_tool_call(input));
+
+        if function_calls.is_empty() {
+            return Err(self.error_handler(OllamaError::from(
+                "No valid function calls found".to_string(),
+            )));
+        }
+
+        let mut results = Vec::new();
+
+        for call in function_calls {
+            if let Some(tool) = tools.iter().find(|t| t.name() == call.function) {
+                let tool_params = call.arguments;
+                match self
+                    .function_call_with_history(model_name.clone(), tool_params, tool.clone())
+                    .await
+                {
+                    Ok(result) => results.push(result),
+                    Err(e) => results.push(e),
                 }
-            }
-            None => {
-                return Err(self
-                    .error_handler(OllamaError::from("Error parsing function call".to_string())));
+            } else {
+                results.push(self.error_handler(OllamaError::from(format!(
+                    "Tool '{}' not found",
+                    call.function
+                ))));
             }
         }
+
+        let combined_message = results
+            .into_iter()
+            .map(|r| r.message.map_or_else(String::new, |m| m.content))
+            .collect::<Vec<String>>()
+            .join("\n\n");
+
+        Ok(ChatMessageResponse {
+            model: model_name,
+            created_at: "".to_string(),
+            message: Some(ChatMessage::assistant(combined_message)),
+            done: true,
+            final_data: None,
+        })
     }
 
     async fn get_system_message(&self, tools: &[Arc<dyn Tool>]) -> ChatMessage {
diff --git a/src/generation/functions/request.rs b/src/generation/functions/request.rs
index 44da42e..d2cb306 100644
--- a/src/generation/functions/request.rs
+++ b/src/generation/functions/request.rs
@@ -8,12 +8,17 @@ use std::sync::Arc;
 pub struct FunctionCallRequest {
     pub chat: ChatMessageRequest,
     pub tools: Vec<Arc<dyn Tool>>,
+    pub raw_mode: bool,
 }
 
 impl FunctionCallRequest {
     pub fn new(model_name: String, tools: Vec<Arc<dyn Tool>>, messages: Vec<ChatMessage>) -> Self {
         let chat = ChatMessageRequest::new(model_name, messages);
-        Self { chat, tools }
+        Self {
+            chat,
+            tools,
+            raw_mode: false,
+        }
     }
 
     /// Additional model parameters listed in the documentation for the Modelfile
@@ -33,4 +38,9 @@ impl FunctionCallRequest {
         self.chat.format = Some(format);
         self
     }
+
+    pub fn raw_mode(mut self) -> Self {
+        self.raw_mode = true;
+        self
+    }
 }
diff --git a/tests/function_call.rs b/tests/function_call.rs
index 7f965fc..5511086 100644
--- a/tests/function_call.rs
+++ b/tests/function_call.rs
@@ -47,14 +47,14 @@ async fn test_send_function_call_with_history() {
     /// - OpenAIFunctionCall: not model specific, degraded performance
     /// - NousFunctionCall: adrienbrault/nous-hermes2pro:Q8_0
     /// - LlamaFunctionCall: llama3.1:latest
-    const MODEL: &str = "adrienbrault/nous-hermes2pro:Q8_0";
+    const MODEL: &str = "phi3:14b-medium-4k-instruct-q4_1";
 
     const PROMPT: &str = "Aside from the Apple Remote, what other device can control the program Apple Remote was originally designed to interact with?";
     let user_message = ChatMessage::user(PROMPT.to_string());
 
     let scraper_tool = Arc::new(Scraper::new());
     let ddg_search_tool = Arc::new(DDGSearcher::new());
-    let parser = Arc::new(NousFunctionCall::new());
+    let parser = Arc::new(OpenAIFunctionCall {});
 
     let mut ollama = Ollama::new_default_with_history(30);
     let result = ollama
@@ -126,27 +126,27 @@
 }
 
 #[tokio::test]
-async fn test_send_function_call_phi3_medium() {
+async fn test_send_function_call_llama_raw() {
     /// Model to be used, make sure it is tailored towards "function calling", such as:
     /// - OpenAIFunctionCall: not model specific, degraded performance
     /// - NousFunctionCall: adrienbrault/nous-hermes2pro:Q8_0
     /// - LlamaFunctionCall: llama3.1:latest
-    const MODEL: &str = "phi3:14b-medium-4k-instruct-q4_1";
+    const MODEL: &str = "llama3.1:latest";
 
     const PROMPT: &str = "What are the current risk factors to Apple Inc?";
     let user_message = ChatMessage::user(PROMPT.to_string());
 
     let search = Arc::new(DDGSearcher::new());
-    let parser = Arc::new(OpenAIFunctionCall {});
+    let parser = Arc::new(LlamaFunctionCall {});
 
     let ollama = Ollama::default();
     let result = ollama
         .send_function_call(
-            FunctionCallRequest::new(MODEL.to_string(), vec![search], vec![user_message]),
+            FunctionCallRequest::new(MODEL.to_string(), vec![search], vec![user_message])
+                .raw_mode(),
             parser,
         )
         .await
         .unwrap();
 
-    assert!(result.done);
 }
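
Usage note (not part of the diff): a minimal sketch of how the raw_mode() builder introduced above would be called from application code, mirroring the new test_send_function_call_llama_raw test. Import paths and the function-calling feature flag are assumed to follow the ollama-rs crate layout used by tests/function_call.rs and may differ slightly.

    use std::sync::Arc;

    use ollama_rs::generation::chat::ChatMessage;
    use ollama_rs::generation::functions::tools::DDGSearcher;
    use ollama_rs::generation::functions::{FunctionCallRequest, LlamaFunctionCall};
    use ollama_rs::Ollama;

    #[tokio::main]
    async fn main() {
        let ollama = Ollama::default();
        let search = Arc::new(DDGSearcher::new());
        let parser = Arc::new(LlamaFunctionCall {});
        let question = ChatMessage::user("What are the current risk factors to Apple Inc?".to_string());

        // With raw_mode(), send_function_call returns the model's response as-is,
        // skipping the tool-call parsing pipeline changed in this patch.
        let response = ollama
            .send_function_call(
                FunctionCallRequest::new("llama3.1:latest".to_string(), vec![search], vec![question])
                    .raw_mode(),
                parser,
            )
            .await
            .unwrap();

        println!("{}", response.message.map(|m| m.content).unwrap_or_default());
    }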