diff --git a/Cargo.lock b/Cargo.lock
index 0672563e..89a80318 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -5418,10 +5418,11 @@ dependencies = [
 
 [[package]]
 name = "ollama-rs"
-version = "0.2.0"
+version = "0.2.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "255252ec57e13d2d6ae074c7b7cd8c004d17dafb1e03f954ba2fd5cc226f8f49"
+checksum = "46483ac9e1f9e93da045b5875837ca3c9cf014fd6ab89b4d9736580ddefc4759"
 dependencies = [
+ "async-stream",
  "async-trait",
  "log",
  "reqwest",
diff --git a/swiftide-integrations/Cargo.toml b/swiftide-integrations/Cargo.toml
index 1f202cc2..786d381e 100644
--- a/swiftide-integrations/Cargo.toml
+++ b/swiftide-integrations/Cargo.toml
@@ -60,7 +60,7 @@ aws-sdk-bedrockruntime = { version = "1.37", features = [
 ], optional = true }
 secrecy = { version = "0.8.0", optional = true }
 reqwest = { version = "0.12.5", optional = true, default-features = false }
-ollama-rs = { version = "0.2.0", optional = true }
+ollama-rs = { version = "0.2.1", optional = true }
 deadpool = { version = "0.12", optional = true, features = [
   "managed",
   "rt_tokio_1",
diff --git a/swiftide-integrations/src/ollama/embed.rs b/swiftide-integrations/src/ollama/embed.rs
new file mode 100644
index 00000000..c9d71bd9
--- /dev/null
+++ b/swiftide-integrations/src/ollama/embed.rs
@@ -0,0 +1,33 @@
+use anyhow::{Context as _, Result};
+use async_trait::async_trait;
+
+use ollama_rs::generation::embeddings::request::GenerateEmbeddingsRequest;
+use swiftide_core::{EmbeddingModel, Embeddings};
+
+use super::Ollama;
+
+#[async_trait]
+impl EmbeddingModel for Ollama {
+    async fn embed(&self, input: Vec<String>) -> Result<Embeddings> {
+        let model = self
+            .default_options
+            .embed_model
+            .as_ref()
+            .context("Model not set")?;
+
+        let request = GenerateEmbeddingsRequest::new(model.to_string(), input.into());
+        tracing::debug!(
+            messages = serde_json::to_string_pretty(&request)?,
+            "[Embed] Request to ollama"
+        );
+        let response = self
+            .client
+            .generate_embeddings(request)
+            .await
+            .context("Request to Ollama Failed")?;
+
+        tracing::debug!("[Embed] Response ollama");
+
+        Ok(response.embeddings)
+    }
+}
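Review note: a minimal sketch of how the new `EmbeddingModel` impl might be exercised directly. It assumes a running local Ollama server, that an embedding model such as `mxbai-embed-large` has been pulled, and the crate paths below (`swiftide_core`, `swiftide_integrations::ollama`); it is not part of this diff:

```rust
use swiftide_core::{EmbeddingModel, Embeddings};
use swiftide_integrations::ollama::Ollama;

// Batch semantics: one embedding vector per input string, in order.
// This fails with "Model not set" unless a default embed model is configured.
async fn embed_docs(ollama: &Ollama, docs: Vec<String>) -> anyhow::Result<Embeddings> {
    ollama.embed(docs).await
}

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    let mut ollama = Ollama::default();
    ollama.with_default_embed_model("mxbai-embed-large"); // model name is illustrative

    let embeddings = embed_docs(&ollama, vec!["hello".into(), "world".into()]).await?;
    assert_eq!(embeddings.len(), 2);
    Ok(())
}
```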
diff --git a/swiftide-integrations/src/ollama/mod.rs b/swiftide-integrations/src/ollama/mod.rs
index 98847e1a..5dd48786 100644
--- a/swiftide-integrations/src/ollama/mod.rs
+++ b/swiftide-integrations/src/ollama/mod.rs
@@ -1,18 +1,18 @@
-//! This module provides integration with `Ollama`'s API, enabling the use of language models within the Swiftide project.
-//! It includes the `Ollama` struct for managing API clients and default options for prompt models.
+//! This module provides integration with `Ollama`'s API, enabling the use of language models and embeddings within the Swiftide project.
+//! It includes the `Ollama` struct for managing API clients and default options for embedding and prompt models.
 //! The module is conditionally compiled based on the "ollama" feature flag.
 
 use derive_builder::Builder;
 use std::sync::Arc;
 
+mod embed;
 mod simple_prompt;
 
-/// The `Ollama` struct encapsulates a `Ollama` client that implements [`swiftide::traits::SimplePrompt`]
+/// The `Ollama` struct encapsulates an `Ollama` client and default options for embedding and prompt models.
+/// It uses the `Builder` pattern for flexible and customizable instantiation.
 ///
-/// There is also a builder available.
-///
-/// By default it will look for a `OLLAMA_API_KEY` environment variable. Note that a model
-/// always needs to be set, either with [`Ollama::with_default_prompt_model`] or via the builder.
+/// By default it will look for a `OLLAMA_API_KEY` environment variable. Note that either a prompt model or an embedding
+/// model always needs to be set, either with [`Ollama::with_default_prompt_model`] or [`Ollama::with_default_embed_model`], or via the builder.
 /// You can find available models in the Ollama documentation.
 ///
 /// Under the hood it uses [`async_openai`], with the Ollama openai mapping. This means
@@ -23,7 +23,7 @@ pub struct Ollama {
     /// The `Ollama` client, wrapped in an `Arc` for thread-safe reference counting.
     #[builder(default = "default_client()", setter(custom))]
     client: Arc<ollama_rs::Ollama>,
-    /// Default options for prompt models.
+    /// Default options for the embedding and prompt models.
     #[builder(default)]
     default_options: Options,
 }
@@ -38,10 +38,14 @@ impl Default for Ollama {
 }
 
 /// The `Options` struct holds configuration options for the `Ollama` client.
-/// It includes optional fields for specifying the prompt model.
+/// It includes optional fields for specifying the embedding and prompt models.
 #[derive(Debug, Default, Clone, Builder)]
 #[builder(setter(into, strip_option))]
 pub struct Options {
+    /// The default embedding model to use, if specified.
+    #[builder(default)]
+    pub embed_model: Option<String>,
+
     /// The default prompt model to use, if specified.
     #[builder(default)]
     pub prompt_model: Option<String>,
@@ -64,6 +68,16 @@ impl Ollama {
     pub fn with_default_prompt_model(&mut self, model: impl Into<String>) -> &mut Self {
         self.default_options = Options {
             prompt_model: Some(model.into()),
+            embed_model: self.default_options.embed_model.clone(),
+        };
+        self
+    }
+
+    /// Sets a default embedding model to use when embedding
+    pub fn with_default_embed_model(&mut self, model: impl Into<String>) -> &mut Self {
+        self.default_options = Options {
+            prompt_model: self.default_options.prompt_model.clone(),
+            embed_model: Some(model.into()),
         };
         self
     }
@@ -82,6 +96,25 @@ impl OllamaBuilder {
         self
     }
 
+    /// Sets the default embedding model for the `Ollama` instance.
+    ///
+    /// # Parameters
+    /// - `model`: The embedding model to set.
+    ///
+    /// # Returns
+    /// A mutable reference to the `OllamaBuilder`.
+    pub fn default_embed_model(&mut self, model: impl Into<String>) -> &mut Self {
+        if let Some(options) = self.default_options.as_mut() {
+            options.embed_model = Some(model.into());
+        } else {
+            self.default_options = Some(Options {
+                embed_model: Some(model.into()),
+                ..Default::default()
+            });
+        }
+        self
+    }
+
     /// Sets the default prompt model for the `Ollama` instance.
     ///
     /// # Parameters
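Design note on the setters above: `with_default_prompt_model` and `with_default_embed_model` each rebuild `Options` but clone the other model field, so they can be chained in either order without clobbering each other. A small sketch using only APIs from this diff (model names are placeholders):

```rust
use swiftide_integrations::ollama::Ollama;

fn main() {
    let mut ollama = Ollama::default();
    // Each setter preserves the other model field, so order doesn't matter.
    ollama
        .with_default_prompt_model("llama3.1")
        .with_default_embed_model("mxbai-embed-large");
}
```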
@@ -95,6 +128,7 @@
         } else {
             self.default_options = Some(Options {
                 prompt_model: Some(model.into()),
+                ..Default::default()
             });
         }
         self
@@ -122,7 +156,36 @@ mod test {
     }
 
     #[test]
-    fn test_building_via_default() {
+    fn test_default_embed_model() {
+        let ollama = Ollama::builder()
+            .default_embed_model("mxbai-embed-large")
+            .build()
+            .unwrap();
+        assert_eq!(
+            ollama.default_options.embed_model,
+            Some("mxbai-embed-large".to_string())
+        );
+    }
+
+    #[test]
+    fn test_default_models() {
+        let ollama = Ollama::builder()
+            .default_embed_model("mxbai-embed-large")
+            .default_prompt_model("llama3.1")
+            .build()
+            .unwrap();
+        assert_eq!(
+            ollama.default_options.embed_model,
+            Some("mxbai-embed-large".to_string())
+        );
+        assert_eq!(
+            ollama.default_options.prompt_model,
+            Some("llama3.1".to_string())
+        );
+    }
+
+    #[test]
+    fn test_building_via_default_prompt_model() {
         let mut client = Ollama::default();
 
         assert!(client.default_options.prompt_model.is_none());
@@ -133,4 +196,35 @@
             Some("llama3.1".to_string())
         );
     }
+
+    #[test]
+    fn test_building_via_default_embed_model() {
+        let mut client = Ollama::default();
+
+        assert!(client.default_options.embed_model.is_none());
+
+        client.with_default_embed_model("mxbai-embed-large");
+        assert_eq!(
+            client.default_options.embed_model,
+            Some("mxbai-embed-large".to_string())
+        );
+    }
+
+    #[test]
+    fn test_building_via_default_models() {
+        let mut client = Ollama::default();
+
+        assert!(client.default_options.embed_model.is_none());
+
+        client.with_default_prompt_model("llama3.1");
+        client.with_default_embed_model("mxbai-embed-large");
+        assert_eq!(
+            client.default_options.prompt_model,
+            Some("llama3.1".to_string())
+        );
+        assert_eq!(
+            client.default_options.embed_model,
+            Some("mxbai-embed-large".to_string())
+        );
+    }
 }
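Rounding this off, a hedged end-to-end sketch combining prompting and embedding on one client. It assumes a running Ollama server with both (illustrative) models pulled, and the `SimplePrompt` signature from `swiftide_core` (`prompt(&self, Prompt) -> Result<String>`, with `Prompt: From<&str>`); none of this is part of the diff itself:

```rust
use swiftide_core::{EmbeddingModel, SimplePrompt};
use swiftide_integrations::ollama::Ollama;

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    // One client serves both roles and shares its connection settings.
    let ollama = Ollama::builder()
        .default_prompt_model("llama3.1")
        .default_embed_model("mxbai-embed-large")
        .build()?;

    let answer = ollama.prompt("Why is the sky blue?".into()).await?;
    let vectors = ollama.embed(vec![answer]).await?;
    println!("got {} embedding vector(s)", vectors.len());
    Ok(())
}
```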