From b994be08ea8a9b94e6b0f8da02c4c8c72fa9dc09 Mon Sep 17 00:00:00 2001 From: pepperoni21 Date: Sat, 3 Aug 2024 14:47:50 +0200 Subject: [PATCH 1/3] Updated embeddings generation to /api/embed endpoint --- src/generation/embeddings.rs | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/src/generation/embeddings.rs b/src/generation/embeddings.rs index f06eeb9..b6e1e3b 100644 --- a/src/generation/embeddings.rs +++ b/src/generation/embeddings.rs @@ -2,7 +2,7 @@ use serde::{Deserialize, Serialize}; use crate::Ollama; -use super::options::GenerationOptions; +use super::{options::GenerationOptions, parameters::KeepAlive}; impl Ollama { /// Generate embeddings from a model @@ -16,11 +16,12 @@ impl Ollama { ) -> crate::error::Result { let request = GenerateEmbeddingsRequest { model_name, - prompt, + input: prompt, options, + ..Default::default() }; - let url = format!("{}api/embeddings", self.url_str()); + let url = format!("{}api/embed", self.url_str()); let serialized = serde_json::to_string(&request).map_err(|e| e.to_string())?; let res = self .reqwest_client @@ -43,18 +44,19 @@ impl Ollama { } /// An embeddings generation request to Ollama. -#[derive(Debug, Serialize)] +#[derive(Debug, Serialize, Default)] struct GenerateEmbeddingsRequest { #[serde(rename = "model")] model_name: String, - prompt: String, + input: String, + truncate: Option, options: Option, + keep_alive: Option, } /// An embeddings generation response from Ollama. #[derive(Debug, Deserialize, Clone)] pub struct GenerateEmbeddingsResponse { - #[serde(rename = "embedding")] #[allow(dead_code)] - pub embeddings: Vec, + pub embeddings: Vec>, } From 9c7cef5ac8c16585e9c4d7ba6eefe8849c0bde48 Mon Sep 17 00:00:00 2001 From: pepperoni21 Date: Sat, 3 Aug 2024 15:03:39 +0200 Subject: [PATCH 2/3] Separated request into the request module and added support for multiple inputs --- .../{embeddings.rs => embeddings/mod.rs} | 28 ++----- src/generation/embeddings/request.rs | 78 +++++++++++++++++++ tests/embeddings_generation.rs | 9 ++- 3 files changed, 88 insertions(+), 27 deletions(-) rename src/generation/{embeddings.rs => embeddings/mod.rs} (63%) create mode 100644 src/generation/embeddings/request.rs diff --git a/src/generation/embeddings.rs b/src/generation/embeddings/mod.rs similarity index 63% rename from src/generation/embeddings.rs rename to src/generation/embeddings/mod.rs index b6e1e3b..6c10b1a 100644 --- a/src/generation/embeddings.rs +++ b/src/generation/embeddings/mod.rs @@ -1,8 +1,10 @@ -use serde::{Deserialize, Serialize}; +use serde::Deserialize; use crate::Ollama; -use super::{options::GenerationOptions, parameters::KeepAlive}; +use self::request::GenerateEmbeddingsRequest; + +pub mod request; impl Ollama { /// Generate embeddings from a model @@ -10,17 +12,8 @@ impl Ollama { /// * `prompt` - Prompt to generate embeddings for pub async fn generate_embeddings( &self, - model_name: String, - prompt: String, - options: Option, + request: GenerateEmbeddingsRequest, ) -> crate::error::Result { - let request = GenerateEmbeddingsRequest { - model_name, - input: prompt, - options, - ..Default::default() - }; - let url = format!("{}api/embed", self.url_str()); let serialized = serde_json::to_string(&request).map_err(|e| e.to_string())?; let res = self @@ -43,17 +36,6 @@ impl Ollama { } } -/// An embeddings generation request to Ollama. -#[derive(Debug, Serialize, Default)] -struct GenerateEmbeddingsRequest { - #[serde(rename = "model")] - model_name: String, - input: String, - truncate: Option, - options: Option, - keep_alive: Option, -} - /// An embeddings generation response from Ollama. #[derive(Debug, Deserialize, Clone)] pub struct GenerateEmbeddingsResponse { diff --git a/src/generation/embeddings/request.rs b/src/generation/embeddings/request.rs new file mode 100644 index 0000000..1de4f38 --- /dev/null +++ b/src/generation/embeddings/request.rs @@ -0,0 +1,78 @@ +use serde::{Serialize, Serializer}; + +use crate::generation::{options::GenerationOptions, parameters::KeepAlive}; + +#[derive(Debug)] +pub enum EmbeddingsInput { + Single(String), + Multiple(Vec), +} + +impl Default for EmbeddingsInput { + fn default() -> Self { + Self::Single(String::default()) + } +} + +impl From for EmbeddingsInput { + fn from(s: String) -> Self { + Self::Single(s) + } +} + +impl From<&str> for EmbeddingsInput { + fn from(s: &str) -> Self { + Self::Single(s.to_string()) + } +} + +impl From> for EmbeddingsInput { + fn from(v: Vec) -> Self { + Self::Multiple(v) + } +} + +impl Serialize for EmbeddingsInput { + fn serialize(&self, serializer: S) -> Result { + match self { + EmbeddingsInput::Single(s) => s.serialize(serializer), + EmbeddingsInput::Multiple(v) => v.serialize(serializer), + } + } +} + +/// An embeddings generation request to Ollama. +#[derive(Debug, Serialize, Default)] +pub struct GenerateEmbeddingsRequest { + #[serde(rename = "model")] + model_name: String, + input: EmbeddingsInput, + truncate: Option, + options: Option, + keep_alive: Option, +} + +impl GenerateEmbeddingsRequest { + pub fn new(model_name: String, input: EmbeddingsInput) -> Self { + Self { + model_name, + input, + ..Default::default() + } + } + + pub fn options(mut self, options: GenerationOptions) -> Self { + self.options = Some(options); + self + } + + pub fn keep_alive(mut self, keep_alive: KeepAlive) -> Self { + self.keep_alive = Some(keep_alive); + self + } + + pub fn truncate(mut self, truncate: bool) -> Self { + self.truncate = Some(truncate); + self + } +} diff --git a/tests/embeddings_generation.rs b/tests/embeddings_generation.rs index d71546a..242e7ee 100644 --- a/tests/embeddings_generation.rs +++ b/tests/embeddings_generation.rs @@ -1,13 +1,14 @@ -use ollama_rs::Ollama; +use ollama_rs::{generation::embeddings::request::GenerateEmbeddingsRequest, Ollama}; #[tokio::test] async fn test_embeddings_generation() { let ollama = Ollama::default(); - let prompt = "Why is the sky blue?".to_string(); - let res = ollama - .generate_embeddings("llama2:latest".to_string(), prompt, None) + .generate_embeddings(GenerateEmbeddingsRequest::new( + "llama2:latest".to_string(), + "Why is the sky blue".into(), + )) .await .unwrap(); From 3443e963797c4f996e63e9608797e36da8e9695d Mon Sep 17 00:00:00 2001 From: pepperoni21 Date: Sat, 3 Aug 2024 15:15:51 +0200 Subject: [PATCH 3/3] Fixed batch embeddings and added test --- src/generation/embeddings/request.rs | 6 ++++++ tests/embeddings_generation.rs | 15 +++++++++++++++ 2 files changed, 21 insertions(+) diff --git a/src/generation/embeddings/request.rs b/src/generation/embeddings/request.rs index 1de4f38..e44a0c5 100644 --- a/src/generation/embeddings/request.rs +++ b/src/generation/embeddings/request.rs @@ -32,6 +32,12 @@ impl From> for EmbeddingsInput { } } +impl From> for EmbeddingsInput { + fn from(v: Vec<&str>) -> Self { + Self::Multiple(v.iter().map(|s| s.to_string()).collect()) + } +} + impl Serialize for EmbeddingsInput { fn serialize(&self, serializer: S) -> Result { match self { diff --git a/tests/embeddings_generation.rs b/tests/embeddings_generation.rs index 242e7ee..ffe7b62 100644 --- a/tests/embeddings_generation.rs +++ b/tests/embeddings_generation.rs @@ -14,3 +14,18 @@ async fn test_embeddings_generation() { dbg!(res); } + +#[tokio::test] +async fn test_batch_embeddings_generation() { + let ollama = Ollama::default(); + + let res = ollama + .generate_embeddings(GenerateEmbeddingsRequest::new( + "llama2:latest".to_string(), + vec!["Why is the sky blue?", "Why is the sky red?"].into(), + )) + .await + .unwrap(); + + dbg!(res); +}