
Commit 7254c8e: Merge branch 'master' into patch-1
pepperoni21 authored Sep 2, 2024
2 parents 3c74dcd + 1900774
Showing 18 changed files with 764 additions and 32 deletions.
1 change: 1 addition & 0 deletions .gitignore
```diff
@@ -1,3 +1,4 @@
 /target
 .vscode/settings.json
 shell.nix
+.idea
```
2 changes: 1 addition & 1 deletion README.md
````diff
@@ -194,7 +194,7 @@ _Returns a `GenerateEmbeddingsResponse` struct containing the embeddings (a vect
 ### Make a function call
 
 ```rust
-let tools = vec![Arc::new(Scraper::new())];
+let tools = vec![Arc::new(Scraper::new()), Arc::new(DDGSearcher::new())];
 let parser = Arc::new(NousFunctionCall::new());
 let message = ChatMessage::user("What is the current oil price?".to_string());
 let res = ollama.send_function_call(
````
63 changes: 63 additions & 0 deletions examples/images_to_ollama.rs
@@ -0,0 +1,63 @@
```rust
use base64::Engine;
use ollama_rs::{
    generation::{
        completion::{request::GenerationRequest, GenerationResponse},
        images::Image,
    },
    Ollama,
};
use reqwest::get;
use tokio::runtime::Runtime;

const IMAGE_URL: &str = "https://images.pexels.com/photos/1054655/pexels-photo-1054655.jpeg";
const PROMPT: &str = "Describe this image";

fn main() {
    let rt = Runtime::new().unwrap();
    rt.block_on(async {
        // Download the image and encode it to base64
        let bytes = match download_image(IMAGE_URL).await {
            Ok(b) => b,
            Err(e) => {
                eprintln!("Failed to download image: {}", e);
                return;
            }
        };
        let base64_image = base64::engine::general_purpose::STANDARD.encode(&bytes);

        // Create an Image struct from the base64 string
        let image = Image::from_base64(&base64_image);

        // Create a GenerationRequest with the model and prompt, adding the image
        let request =
            GenerationRequest::new("llava:latest".to_string(), PROMPT.to_string()).add_image(image);

        // Send the request to the model and get the response
        let response = match send_request(request).await {
            Ok(r) => r,
            Err(e) => {
                eprintln!("Failed to get response: {}", e);
                return;
            }
        };

        // Print the response
        println!("{}", response.response);
    });
}

// Function to download the image
async fn download_image(url: &str) -> Result<Vec<u8>, reqwest::Error> {
    let response = get(url).await?;
    let bytes = response.bytes().await?;
    Ok(bytes.to_vec())
}

// Function to send the request to the model
async fn send_request(
    request: GenerationRequest,
) -> Result<GenerationResponse, Box<dyn std::error::Error>> {
    let ollama = Ollama::default();
    let response = ollama.generate(request).await?;
    Ok(response)
}
```
30 changes: 7 additions & 23 deletions src/generation/embeddings.rs → src/generation/embeddings/mod.rs
```diff
@@ -1,26 +1,20 @@
-use serde::{Deserialize, Serialize};
+use serde::Deserialize;
 
 use crate::Ollama;
 
-use super::options::GenerationOptions;
+use self::request::GenerateEmbeddingsRequest;
 
+pub mod request;
+
 impl Ollama {
     /// Generate embeddings from a model
     /// * `model_name` - Name of model to generate embeddings from
     /// * `prompt` - Prompt to generate embeddings for
     pub async fn generate_embeddings(
         &self,
-        model_name: String,
-        prompt: String,
-        options: Option<GenerationOptions>,
+        request: GenerateEmbeddingsRequest,
     ) -> crate::error::Result<GenerateEmbeddingsResponse> {
-        let request = GenerateEmbeddingsRequest {
-            model_name,
-            prompt,
-            options,
-        };
-
-        let url = format!("{}api/embeddings", self.url_str());
+        let url = format!("{}api/embed", self.url_str());
         let serialized = serde_json::to_string(&request).map_err(|e| e.to_string())?;
         let res = self
             .reqwest_client
@@ -42,19 +36,9 @@ impl Ollama {
     }
 }
 
-/// An embeddings generation request to Ollama.
-#[derive(Debug, Serialize)]
-struct GenerateEmbeddingsRequest {
-    #[serde(rename = "model")]
-    model_name: String,
-    prompt: String,
-    options: Option<GenerationOptions>,
-}
-
 /// An embeddings generation response from Ollama.
#[derive(Debug, Deserialize, Clone)]
 pub struct GenerateEmbeddingsResponse {
-    #[serde(rename = "embedding")]
     #[allow(dead_code)]
-    pub embeddings: Vec<f64>,
+    pub embeddings: Vec<Vec<f64>>,
 }
```
84 changes: 84 additions & 0 deletions src/generation/embeddings/request.rs
@@ -0,0 +1,84 @@
```rust
use serde::{Serialize, Serializer};

use crate::generation::{options::GenerationOptions, parameters::KeepAlive};

#[derive(Debug)]
pub enum EmbeddingsInput {
    Single(String),
    Multiple(Vec<String>),
}

impl Default for EmbeddingsInput {
    fn default() -> Self {
        Self::Single(String::default())
    }
}

impl From<String> for EmbeddingsInput {
    fn from(s: String) -> Self {
        Self::Single(s)
    }
}

impl From<&str> for EmbeddingsInput {
    fn from(s: &str) -> Self {
        Self::Single(s.to_string())
    }
}

impl From<Vec<String>> for EmbeddingsInput {
    fn from(v: Vec<String>) -> Self {
        Self::Multiple(v)
    }
}

impl From<Vec<&str>> for EmbeddingsInput {
    fn from(v: Vec<&str>) -> Self {
        Self::Multiple(v.iter().map(|s| s.to_string()).collect())
    }
}

impl Serialize for EmbeddingsInput {
    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        match self {
            EmbeddingsInput::Single(s) => s.serialize(serializer),
            EmbeddingsInput::Multiple(v) => v.serialize(serializer),
        }
    }
}

/// An embeddings generation request to Ollama.
#[derive(Debug, Serialize, Default)]
pub struct GenerateEmbeddingsRequest {
    #[serde(rename = "model")]
    model_name: String,
    input: EmbeddingsInput,
    truncate: Option<bool>,
    options: Option<GenerationOptions>,
    keep_alive: Option<KeepAlive>,
}

impl GenerateEmbeddingsRequest {
    pub fn new(model_name: String, input: EmbeddingsInput) -> Self {
        Self {
            model_name,
            input,
            ..Default::default()
        }
    }

    pub fn options(mut self, options: GenerationOptions) -> Self {
        self.options = Some(options);
        self
    }

    pub fn keep_alive(mut self, keep_alive: KeepAlive) -> Self {
        self.keep_alive = Some(keep_alive);
        self
    }

    pub fn truncate(mut self, truncate: bool) -> Self {
        self.truncate = Some(truncate);
        self
    }
}
```
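Thanks to the `From` impls above, plain strings and vectors convert into `EmbeddingsInput` at the call site. A small sketch (the model name is illustrative):

```rust
use ollama_rs::generation::embeddings::request::GenerateEmbeddingsRequest;

fn build_requests() -> (GenerateEmbeddingsRequest, GenerateEmbeddingsRequest) {
    // Single input via From<&str>
    let single = GenerateEmbeddingsRequest::new("all-minilm".to_string(), "hello world".into());
    // Batched input via From<Vec<&str>>, opting in to truncation of over-long inputs
    let batch =
        GenerateEmbeddingsRequest::new("all-minilm".to_string(), vec!["first", "second"].into())
            .truncate(true);
    (single, batch)
}
```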
3 changes: 3 additions & 0 deletions src/generation/functions/mod.rs
@@ -2,11 +2,14 @@ pub mod pipelines
```rust
pub mod request;
pub mod tools;

pub use crate::generation::functions::pipelines::meta_llama::request::LlamaFunctionCall;
pub use crate::generation::functions::pipelines::nous_hermes::request::NousFunctionCall;
pub use crate::generation::functions::pipelines::openai::request::OpenAIFunctionCall;
pub use crate::generation::functions::request::FunctionCallRequest;
pub use tools::Browserless;
pub use tools::DDGSearcher;
pub use tools::Scraper;
pub use tools::SerperSearchTool;
pub use tools::StockScraper;

use crate::error::OllamaError;
```
4 changes: 4 additions & 0 deletions src/generation/functions/pipelines/meta_llama/mod.rs
@@ -0,0 +1,4 @@
```rust
pub mod prompts;
pub mod request;

pub use prompts::DEFAULT_SYSTEM_TEMPLATE;
```
14 changes: 14 additions & 0 deletions src/generation/functions/pipelines/meta_llama/prompts.rs
@@ -0,0 +1,14 @@
```rust
pub const DEFAULT_SYSTEM_TEMPLATE: &str = r#"
You have access to the following functions:
{tools}
If you choose to call a function ONLY reply in the following format with no prefix or suffix:
<function=example_function_name>{{\"example_name\": \"example_value\"}}</function>
Reminder:
- Function calls MUST follow the specified format, start with <function= and end with </function>
- Required parameters MUST be specified
- Only call one function at a time
- Put the entire function call reply on one line
- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls
"#;
```
136 changes: 136 additions & 0 deletions src/generation/functions/pipelines/meta_llama/request.rs
@@ -0,0 +1,136 @@
````rust
use crate::error::OllamaError;
use crate::generation::chat::{ChatMessage, ChatMessageResponse};
use crate::generation::functions::pipelines::meta_llama::DEFAULT_SYSTEM_TEMPLATE;
use crate::generation::functions::pipelines::RequestParserBase;
use crate::generation::functions::tools::Tool;
use async_trait::async_trait;
use regex::Regex;
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use std::collections::HashMap;
use std::sync::Arc;

pub fn convert_to_llama_tool(tool: &Arc<dyn Tool>) -> Value {
    let mut function = HashMap::new();
    function.insert("name".to_string(), Value::String(tool.name()));
    function.insert("description".to_string(), Value::String(tool.description()));
    function.insert("parameters".to_string(), tool.parameters());
    json!(format!(
        "Use the function '{name}' to '{description}': {json}",
        name = tool.name(),
        description = tool.description(),
        json = json!(function)
    ))
}

#[derive(Debug, Deserialize, Serialize)]
pub struct LlamaFunctionCallSignature {
    pub function: String, // name of the tool
    pub arguments: Value,
}

pub struct LlamaFunctionCall {}

impl LlamaFunctionCall {
    pub async fn function_call_with_history(
        &self,
        model_name: String,
        tool_params: Value,
        tool: Arc<dyn Tool>,
    ) -> Result<ChatMessageResponse, ChatMessageResponse> {
        let result = tool.run(tool_params).await;
        match result {
            Ok(result) => Ok(ChatMessageResponse {
                model: model_name.clone(),
                created_at: "".to_string(),
                message: Some(ChatMessage::assistant(result.to_string())),
                done: true,
                final_data: None,
            }),
            Err(e) => Err(self.error_handler(OllamaError::from(e))),
        }
    }

    fn clean_tool_call(&self, json_str: &str) -> String {
        json_str
            .trim()
            .trim_start_matches("```json")
            .trim_end_matches("```")
            .trim()
            .to_string()
            .replace("{{", "{")
            .replace("}}", "}")
    }

    fn parse_tool_response(&self, response: &str) -> Option<LlamaFunctionCallSignature> {
        let function_regex = Regex::new(r"<function=(\w+)>(.*?)</function>").unwrap();
        println!("Response: {}", response);
        if let Some(caps) = function_regex.captures(response) {
            let function_name = caps.get(1).unwrap().as_str().to_string();
            let args_string = caps.get(2).unwrap().as_str();

            match serde_json::from_str(args_string) {
                Ok(arguments) => Some(LlamaFunctionCallSignature {
                    function: function_name,
                    arguments,
                }),
                Err(error) => {
                    println!("Error parsing function arguments: {}", error);
                    None
                }
            }
        } else {
            None
        }
    }
}

#[async_trait]
impl RequestParserBase for LlamaFunctionCall {
    async fn parse(
        &self,
        input: &str,
        model_name: String,
        tools: Vec<Arc<dyn Tool>>,
    ) -> Result<ChatMessageResponse, ChatMessageResponse> {
        let response_value = self.parse_tool_response(&self.clean_tool_call(input));
        match response_value {
            Some(response) => {
                if let Some(tool) = tools.iter().find(|t| t.name() == response.function) {
                    let tool_params = response.arguments;
                    let result = self
                        .function_call_with_history(
                            model_name.clone(),
                            tool_params.clone(),
                            tool.clone(),
                        )
                        .await?;
                    return Ok(result);
                } else {
                    return Err(self.error_handler(OllamaError::from("Tool not found".to_string())));
                }
            }
            None => {
                return Err(self
                    .error_handler(OllamaError::from("Error parsing function call".to_string())));
            }
        }
    }

    async fn get_system_message(&self, tools: &[Arc<dyn Tool>]) -> ChatMessage {
        let tools_info: Vec<Value> = tools.iter().map(convert_to_llama_tool).collect();
        let tools_json = serde_json::to_string(&tools_info).unwrap();
        let system_message_content = DEFAULT_SYSTEM_TEMPLATE.replace("{tools}", &tools_json);
        ChatMessage::system(system_message_content)
    }

    fn error_handler(&self, error: OllamaError) -> ChatMessageResponse {
        ChatMessageResponse {
            model: "".to_string(),
            created_at: "".to_string(),
            message: Some(ChatMessage::assistant(error.to_string())),
            done: true,
            final_data: None,
        }
    }
}
````
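A hedged end-to-end sketch, following the `send_function_call` pattern from the README excerpt above but swapping in the new Llama parser. The model name is illustrative, and the exact `FunctionCallRequest` argument shape is assumed from the crate's existing function-call example:

```rust
use std::sync::Arc;

use ollama_rs::{
    generation::{
        chat::ChatMessage,
        functions::{DDGSearcher, FunctionCallRequest, LlamaFunctionCall},
    },
    Ollama,
};

#[tokio::main]
async fn main() {
    let ollama = Ollama::default();
    let tools = vec![Arc::new(DDGSearcher::new())];
    // LlamaFunctionCall parses the <function=...>...</function> format that
    // DEFAULT_SYSTEM_TEMPLATE instructs the model to emit.
    let parser = Arc::new(LlamaFunctionCall {});
    let message = ChatMessage::user("What is the current oil price?".to_string());
    let res = ollama
        .send_function_call(
            FunctionCallRequest::new("llama3.1:latest".to_string(), tools, vec![message]),
            parser,
        )
        .await;
    match res {
        Ok(r) => println!("{:?}", r.message),
        Err(e) => eprintln!("function call failed: {}", e),
    }
}
```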