diff --git a/src/generation/functions/mod.rs b/src/generation/functions/mod.rs
index 73b1206..967a634 100644
--- a/src/generation/functions/mod.rs
+++ b/src/generation/functions/mod.rs
@@ -65,7 +65,7 @@ impl crate::Ollama {
             return Ok(tool_call_result);
         }
 
-        let tool_call_content: String = tool_call_result.message.clone().unwrap().content;
+        let tool_call_content: String = tool_call_result.message.unwrap().content;
         let result = parser
             .parse(
                 &tool_call_content,
@@ -80,8 +80,8 @@ impl crate::Ollama {
                 Ok(r)
             }
             Err(e) => {
-                self.add_assistant_response(id.clone(), e.message.clone().unwrap().content);
-                Err(OllamaError::from(e.message.unwrap().content))
+                self.add_assistant_response(id.clone(), e.message.clone());
+                Err(e)
             }
         }
     }
@@ -108,12 +108,8 @@ impl crate::Ollama {
         }
 
         let response_content: String = result.message.clone().unwrap().content;
-        let result = parser
+        return parser
             .parse(&response_content, model_name, request.tools)
             .await;
-        match result {
-            Ok(r) => Ok(r),
-            Err(e) => Err(OllamaError::from(e.message.unwrap().content)),
-        }
     }
 }
diff --git a/src/generation/functions/pipelines/meta_llama/request.rs b/src/generation/functions/pipelines/meta_llama/request.rs
index fa18a0c..270276d 100644
--- a/src/generation/functions/pipelines/meta_llama/request.rs
+++ b/src/generation/functions/pipelines/meta_llama/request.rs
@@ -37,7 +37,7 @@ impl LlamaFunctionCall {
         model_name: String,
         tool_params: Value,
         tool: Arc<dyn Tool>,
-    ) -> Result<ChatMessageResponse, ChatMessageResponse> {
+    ) -> Result<ChatMessageResponse, OllamaError> {
         let result = tool.run(tool_params).await;
         match result {
             Ok(result) => Ok(ChatMessageResponse {
@@ -47,7 +47,7 @@ impl LlamaFunctionCall {
                 done: true,
                 final_data: None,
             }),
-            Err(e) => Err(self.error_handler(OllamaError::from(e))),
+            Err(e) => Err(OllamaError::from(e)),
         }
     }
 
@@ -95,13 +95,13 @@ impl RequestParserBase for LlamaFunctionCall {
         input: &str,
         model_name: String,
         tools: Vec<Arc<dyn Tool>>,
-    ) -> Result<ChatMessageResponse, ChatMessageResponse> {
+    ) -> Result<ChatMessageResponse, OllamaError> {
         let function_calls = self.parse_tool_response(&self.clean_tool_call(input));
 
         if function_calls.is_empty() {
-            return Err(self.error_handler(OllamaError::from(
+            return Err(OllamaError::from(
                 "No valid function calls found".to_string(),
-            )));
+            ));
         }
 
         let mut results = Vec::new();
@@ -109,18 +109,15 @@ impl RequestParserBase for LlamaFunctionCall {
         for call in function_calls {
             if let Some(tool) = tools.iter().find(|t| t.name() == call.function) {
                 let tool_params = call.arguments;
-                match self
+                let result = self
                     .function_call_with_history(model_name.clone(), tool_params, tool.clone())
-                    .await
-                {
-                    Ok(result) => results.push(result),
-                    Err(e) => results.push(e),
-                }
+                    .await?;
+                results.push(result);
             } else {
-                results.push(self.error_handler(OllamaError::from(format!(
+                return Err(OllamaError::from(format!(
                     "Tool '{}' not found",
                     call.function
-                ))));
+                )));
             }
         }
 
@@ -145,14 +142,4 @@ impl RequestParserBase for LlamaFunctionCall {
         let system_message_content = DEFAULT_SYSTEM_TEMPLATE.replace("{tools}", &tools_json);
         ChatMessage::system(system_message_content)
     }
-
-    fn error_handler(&self, error: OllamaError) -> ChatMessageResponse {
-        ChatMessageResponse {
-            model: "".to_string(),
-            created_at: "".to_string(),
-            message: Some(ChatMessage::assistant(error.to_string())),
-            done: true,
-            final_data: None,
-        }
-    }
 }
diff --git a/src/generation/functions/pipelines/mod.rs b/src/generation/functions/pipelines/mod.rs
index aa7fdfb..6fc3ce7 100644
--- a/src/generation/functions/pipelines/mod.rs
+++ b/src/generation/functions/pipelines/mod.rs
@@ -15,7 +15,7 @@ pub trait RequestParserBase: Send + Sync {
         input: &str,
         model_name: String,
         tools: Vec<Arc<dyn Tool>>,
-    ) -> Result<ChatMessageResponse, ChatMessageResponse>;
+    ) -> Result<ChatMessageResponse, OllamaError>;
     fn format_query(&self, input: &str) -> String {
         input.to_string()
     }
@@ -23,5 +23,4 @@ pub trait RequestParserBase: Send + Sync {
         response.to_string()
     }
     async fn get_system_message(&self, tools: &[Arc<dyn Tool>]) -> ChatMessage;
-    fn error_handler(&self, error: OllamaError) -> ChatMessageResponse;
 }
diff --git a/src/generation/functions/pipelines/nous_hermes/request.rs b/src/generation/functions/pipelines/nous_hermes/request.rs
index f1c30a8..c6d6cb7 100644
--- a/src/generation/functions/pipelines/nous_hermes/request.rs
+++ b/src/generation/functions/pipelines/nous_hermes/request.rs
@@ -49,7 +49,7 @@ impl NousFunctionCall {
         model_name: String,
         tool_params: Value,
         tool: Arc<dyn Tool>,
-    ) -> Result<ChatMessageResponse, ChatMessageResponse> {
+    ) -> Result<ChatMessageResponse, OllamaError> {
         let result = tool.run(tool_params).await;
         match result {
             Ok(result) => Ok(ChatMessageResponse {
@@ -59,7 +59,7 @@ impl NousFunctionCall {
                 done: true,
                 final_data: None,
             }),
-            Err(e) => Err(self.error_handler(OllamaError::from(e))),
+            Err(e) => Err(OllamaError::from(e)),
         }
     }
 
@@ -90,7 +90,7 @@ impl RequestParserBase for NousFunctionCall {
         input: &str,
         model_name: String,
         tools: Vec<Arc<dyn Tool>>,
-    ) -> Result<ChatMessageResponse, ChatMessageResponse> {
+    ) -> Result<ChatMessageResponse, OllamaError> {
        //Extract between <tool_call> and </tool_call>
         let tool_response = self.extract_tool_call(input);
         match tool_response {
@@ -110,18 +110,16 @@ impl RequestParserBase for NousFunctionCall {
                                 .await?; //Error is also returned as String for LLM feedback
                             return Ok(result);
                         } else {
-                            return Err(self.error_handler(OllamaError::from(
-                                "Tool name not found".to_string(),
-                            )));
+                            return Err(OllamaError::from("Tool name not found".to_string()));
                         }
                     }
-                    Err(e) => return Err(self.error_handler(OllamaError::from(e))),
+                    Err(e) => return Err(OllamaError::from(e)),
                 }
             }
             None => {
-                return Err(self.error_handler(OllamaError::from(
+                return Err(OllamaError::from(
                     "Error while extracting tags.".to_string(),
-                )))
+                ))
             }
         }
     }
@@ -143,19 +141,4 @@ impl RequestParserBase for NousFunctionCall {
         let system_message_content = DEFAULT_SYSTEM_TEMPLATE.replace("{tools}", &tools_json);
         ChatMessage::system(system_message_content)
     }
-
-    fn error_handler(&self, error: OllamaError) -> ChatMessageResponse {
-        let error_message = format!(
-            "\nThere was an error parsing function calls\n Here's the error stack trace: {}\nPlease call the function again with correct syntax",
-            error
-        );
-
-        ChatMessageResponse {
-            model: "".to_string(),
-            created_at: "".to_string(),
-            message: Some(ChatMessage::assistant(error_message)),
-            done: true,
-            final_data: None,
-        }
-    }
 }
diff --git a/src/generation/functions/pipelines/openai/request.rs b/src/generation/functions/pipelines/openai/request.rs
index 4a42cd3..b2870d9 100644
--- a/src/generation/functions/pipelines/openai/request.rs
+++ b/src/generation/functions/pipelines/openai/request.rs
@@ -38,7 +38,7 @@ impl OpenAIFunctionCall {
         model_name: String,
         tool_params: Value,
         tool: Arc<dyn Tool>,
-    ) -> Result<ChatMessageResponse, ChatMessageResponse> {
+    ) -> Result<ChatMessageResponse, OllamaError> {
         let result = tool.run(tool_params).await;
         match result {
             Ok(result) => Ok(ChatMessageResponse {
@@ -48,7 +48,7 @@ impl OpenAIFunctionCall {
                 done: true,
                 final_data: None,
             }),
-            Err(e) => Err(self.error_handler(OllamaError::from(e))),
+            Err(e) => Err(OllamaError::from(e)),
         }
     }
 
@@ -69,7 +69,7 @@ impl RequestParserBase for OpenAIFunctionCall {
         input: &str,
         model_name: String,
         tools: Vec<Arc<dyn Tool>>,
-    ) -> Result<ChatMessageResponse, ChatMessageResponse> {
+    ) -> Result<ChatMessageResponse, OllamaError> {
         let response_value: Result<Value, serde_json::Error> =
             serde_json::from_str(&self.clean_tool_call(input));
         match response_value {
@@ -85,11 +85,11 @@ impl RequestParserBase for OpenAIFunctionCall {
                         .await?;
                     return Ok(result);
                 } else {
-                    return Err(self.error_handler(OllamaError::from("Tool not found".to_string())));
+                    return Err(OllamaError::from("Tool not found".to_string()));
                 }
             }
             Err(e) => {
-                return Err(self.error_handler(OllamaError::from(e)));
+                return Err(OllamaError::from(e));
             }
         }
     }
@@ -100,14 +100,4 @@ impl RequestParserBase for OpenAIFunctionCall {
         let system_message_content = DEFAULT_SYSTEM_TEMPLATE.replace("{tools}", &tools_json);
         ChatMessage::system(system_message_content)
     }
-
-    fn error_handler(&self, error: OllamaError) -> ChatMessageResponse {
-        ChatMessageResponse {
-            model: "".to_string(),
-            created_at: "".to_string(),
-            message: Some(ChatMessage::assistant(error.to_string())),
-            done: true,
-            final_data: None,
-        }
-    }
 }