-
Notifications
You must be signed in to change notification settings - Fork 89
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
18 changed files
with
764 additions
and
32 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
/target | ||
.vscode/settings.json | ||
shell.nix | ||
.idea |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
use base64::Engine; | ||
use ollama_rs::{ | ||
generation::{ | ||
completion::{request::GenerationRequest, GenerationResponse}, | ||
images::Image, | ||
}, | ||
Ollama, | ||
}; | ||
use reqwest::get; | ||
use tokio::runtime::Runtime; | ||
|
||
const IMAGE_URL: &str = "https://images.pexels.com/photos/1054655/pexels-photo-1054655.jpeg"; | ||
const PROMPT: &str = "Describe this image"; | ||
|
||
fn main() { | ||
let rt = Runtime::new().unwrap(); | ||
rt.block_on(async { | ||
// Download the image and encode it to base64 | ||
let bytes = match download_image(IMAGE_URL).await { | ||
Ok(b) => b, | ||
Err(e) => { | ||
eprintln!("Failed to download image: {}", e); | ||
return; | ||
} | ||
}; | ||
let base64_image = base64::engine::general_purpose::STANDARD.encode(&bytes); | ||
|
||
// Create an Image struct from the base64 string | ||
let image = Image::from_base64(&base64_image); | ||
|
||
// Create a GenerationRequest with the model and prompt, adding the image | ||
let request = | ||
GenerationRequest::new("llava:latest".to_string(), PROMPT.to_string()).add_image(image); | ||
|
||
// Send the request to the model and get the response | ||
let response = match send_request(request).await { | ||
Ok(r) => r, | ||
Err(e) => { | ||
eprintln!("Failed to get response: {}", e); | ||
return; | ||
} | ||
}; | ||
|
||
// Print the response | ||
println!("{}", response.response); | ||
}); | ||
} | ||
|
||
// Function to download the image | ||
async fn download_image(url: &str) -> Result<Vec<u8>, reqwest::Error> { | ||
let response = get(url).await?; | ||
let bytes = response.bytes().await?; | ||
Ok(bytes.to_vec()) | ||
} | ||
|
||
// Function to send the request to the model | ||
async fn send_request( | ||
request: GenerationRequest, | ||
) -> Result<GenerationResponse, Box<dyn std::error::Error>> { | ||
let ollama = Ollama::default(); | ||
let response = ollama.generate(request).await?; | ||
Ok(response) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
use serde::{Serialize, Serializer}; | ||
|
||
use crate::generation::{options::GenerationOptions, parameters::KeepAlive}; | ||
|
||
/// Text to embed: either one string or a batch of strings.
///
/// The `From` conversions in this module build the matching variant from
/// `String`, `&str`, `Vec<String>` and `Vec<&str>`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum EmbeddingsInput {
    /// A single piece of text.
    Single(String),
    /// Several pieces of text submitted together.
    Multiple(Vec<String>),
}
|
||
impl Default for EmbeddingsInput { | ||
fn default() -> Self { | ||
Self::Single(String::default()) | ||
} | ||
} | ||
|
||
impl From<String> for EmbeddingsInput { | ||
fn from(s: String) -> Self { | ||
Self::Single(s) | ||
} | ||
} | ||
|
||
impl From<&str> for EmbeddingsInput { | ||
fn from(s: &str) -> Self { | ||
Self::Single(s.to_string()) | ||
} | ||
} | ||
|
||
impl From<Vec<String>> for EmbeddingsInput { | ||
fn from(v: Vec<String>) -> Self { | ||
Self::Multiple(v) | ||
} | ||
} | ||
|
||
impl From<Vec<&str>> for EmbeddingsInput { | ||
fn from(v: Vec<&str>) -> Self { | ||
Self::Multiple(v.iter().map(|s| s.to_string()).collect()) | ||
} | ||
} | ||
|
||
impl Serialize for EmbeddingsInput { | ||
fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> { | ||
match self { | ||
EmbeddingsInput::Single(s) => s.serialize(serializer), | ||
EmbeddingsInput::Multiple(v) => v.serialize(serializer), | ||
} | ||
} | ||
} | ||
|
||
/// An embeddings generation request to Ollama. | ||
#[derive(Debug, Serialize, Default)] | ||
pub struct GenerateEmbeddingsRequest { | ||
#[serde(rename = "model")] | ||
model_name: String, | ||
input: EmbeddingsInput, | ||
truncate: Option<bool>, | ||
options: Option<GenerationOptions>, | ||
keep_alive: Option<KeepAlive>, | ||
} | ||
|
||
impl GenerateEmbeddingsRequest { | ||
pub fn new(model_name: String, input: EmbeddingsInput) -> Self { | ||
Self { | ||
model_name, | ||
input, | ||
..Default::default() | ||
} | ||
} | ||
|
||
pub fn options(mut self, options: GenerationOptions) -> Self { | ||
self.options = Some(options); | ||
self | ||
} | ||
|
||
pub fn keep_alive(mut self, keep_alive: KeepAlive) -> Self { | ||
self.keep_alive = Some(keep_alive); | ||
self | ||
} | ||
|
||
pub fn truncate(mut self, truncate: bool) -> Self { | ||
self.truncate = Some(truncate); | ||
self | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
pub mod prompts; | ||
pub mod request; | ||
|
||
pub use prompts::DEFAULT_SYSTEM_TEMPLATE; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
/// Default system prompt that teaches the model how to request a tool call.
///
/// The `{tools}` placeholder is meant to be substituted with the tool
/// descriptions by plain string replacement. This is a raw string and not a
/// `format!` template, so braces and quotes are written exactly as the model
/// should see them: doubling the braces or backslash-escaping the quotes would
/// show the model an example that is not valid JSON.
pub const DEFAULT_SYSTEM_TEMPLATE: &str = r#"
You have access to the following functions:
{tools}
If you choose to call a function ONLY reply in the following format with no prefix or suffix:
<function=example_function_name>{"example_name": "example_value"}</function>
Reminder:
- Function calls MUST follow the specified format, start with <function= and end with </function>
- Required parameters MUST be specified
- Only call one function at a time
- Put the entire function call reply on one line
- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls
"#;
136 changes: 136 additions & 0 deletions
136
src/generation/functions/pipelines/meta_llama/request.rs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,136 @@ | ||
use crate::error::OllamaError; | ||
use crate::generation::chat::{ChatMessage, ChatMessageResponse}; | ||
use crate::generation::functions::pipelines::meta_llama::DEFAULT_SYSTEM_TEMPLATE; | ||
use crate::generation::functions::pipelines::RequestParserBase; | ||
use crate::generation::functions::tools::Tool; | ||
use async_trait::async_trait; | ||
use regex::Regex; | ||
use serde::{Deserialize, Serialize}; | ||
use serde_json::{json, Value}; | ||
use std::collections::HashMap; | ||
use std::sync::Arc; | ||
|
||
pub fn convert_to_llama_tool(tool: &Arc<dyn Tool>) -> Value { | ||
let mut function = HashMap::new(); | ||
function.insert("name".to_string(), Value::String(tool.name())); | ||
function.insert("description".to_string(), Value::String(tool.description())); | ||
function.insert("parameters".to_string(), tool.parameters()); | ||
json!(format!( | ||
"Use the function '{name}' to '{description}': {json}", | ||
name = tool.name(), | ||
description = tool.description(), | ||
json = json!(function) | ||
)) | ||
} | ||
|
||
#[derive(Debug, Deserialize, Serialize)] | ||
pub struct LlamaFunctionCallSignature { | ||
pub function: String, //name of the tool | ||
pub arguments: Value, | ||
} | ||
|
||
/// Stateless parser/dispatcher for tool calls written in the
/// `<function=name>{...}</function>` reply format.
#[derive(Debug, Default)]
pub struct LlamaFunctionCall {}
|
||
impl LlamaFunctionCall { | ||
pub async fn function_call_with_history( | ||
&self, | ||
model_name: String, | ||
tool_params: Value, | ||
tool: Arc<dyn Tool>, | ||
) -> Result<ChatMessageResponse, ChatMessageResponse> { | ||
let result = tool.run(tool_params).await; | ||
match result { | ||
Ok(result) => Ok(ChatMessageResponse { | ||
model: model_name.clone(), | ||
created_at: "".to_string(), | ||
message: Some(ChatMessage::assistant(result.to_string())), | ||
done: true, | ||
final_data: None, | ||
}), | ||
Err(e) => Err(self.error_handler(OllamaError::from(e))), | ||
} | ||
} | ||
|
||
fn clean_tool_call(&self, json_str: &str) -> String { | ||
json_str | ||
.trim() | ||
.trim_start_matches("```json") | ||
.trim_end_matches("```") | ||
.trim() | ||
.to_string() | ||
.replace("{{", "{") | ||
.replace("}}", "}") | ||
} | ||
|
||
fn parse_tool_response(&self, response: &str) -> Option<LlamaFunctionCallSignature> { | ||
let function_regex = Regex::new(r"<function=(\w+)>(.*?)</function>").unwrap(); | ||
println!("Response: {}", response); | ||
if let Some(caps) = function_regex.captures(response) { | ||
let function_name = caps.get(1).unwrap().as_str().to_string(); | ||
let args_string = caps.get(2).unwrap().as_str(); | ||
|
||
match serde_json::from_str(args_string) { | ||
Ok(arguments) => Some(LlamaFunctionCallSignature { | ||
function: function_name, | ||
arguments, | ||
}), | ||
Err(error) => { | ||
println!("Error parsing function arguments: {}", error); | ||
None | ||
} | ||
} | ||
} else { | ||
None | ||
} | ||
} | ||
} | ||
|
||
#[async_trait] | ||
impl RequestParserBase for LlamaFunctionCall { | ||
async fn parse( | ||
&self, | ||
input: &str, | ||
model_name: String, | ||
tools: Vec<Arc<dyn Tool>>, | ||
) -> Result<ChatMessageResponse, ChatMessageResponse> { | ||
let response_value = self.parse_tool_response(&self.clean_tool_call(input)); | ||
match response_value { | ||
Some(response) => { | ||
if let Some(tool) = tools.iter().find(|t| t.name() == response.function) { | ||
let tool_params = response.arguments; | ||
let result = self | ||
.function_call_with_history( | ||
model_name.clone(), | ||
tool_params.clone(), | ||
tool.clone(), | ||
) | ||
.await?; | ||
return Ok(result); | ||
} else { | ||
return Err(self.error_handler(OllamaError::from("Tool not found".to_string()))); | ||
} | ||
} | ||
None => { | ||
return Err(self | ||
.error_handler(OllamaError::from("Error parsing function call".to_string()))); | ||
} | ||
} | ||
} | ||
|
||
async fn get_system_message(&self, tools: &[Arc<dyn Tool>]) -> ChatMessage { | ||
let tools_info: Vec<Value> = tools.iter().map(convert_to_llama_tool).collect(); | ||
let tools_json = serde_json::to_string(&tools_info).unwrap(); | ||
let system_message_content = DEFAULT_SYSTEM_TEMPLATE.replace("{tools}", &tools_json); | ||
ChatMessage::system(system_message_content) | ||
} | ||
|
||
fn error_handler(&self, error: OllamaError) -> ChatMessageResponse { | ||
ChatMessageResponse { | ||
model: "".to_string(), | ||
created_at: "".to_string(), | ||
message: Some(ChatMessage::assistant(error.to_string())), | ||
done: true, | ||
final_data: None, | ||
} | ||
} | ||
} |
Oops, something went wrong.