Skip to content

Commit

Permalink
feat(integrations): Add ollama embeddings support (#278)
Browse files Browse the repository at this point in the history
Update to the most recent ollama-rs, which exposes the batch embedding
API Ollama exposes (pepperoni21/ollama-rs#61).
This allows the Ollama struct in Swiftide to implement `EmbeddingModel`.

Use the same pattern that the OpenAI struct uses to manage separate
embedding and prompt models.

---------

Co-authored-by: Timon Vonk <[email protected]>
  • Loading branch information
ephraimkunz and timonv authored Sep 8, 2024
1 parent bdf17ad commit a98dbcb
Show file tree
Hide file tree
Showing 4 changed files with 141 additions and 13 deletions.
5 changes: 3 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion swiftide-integrations/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ aws-sdk-bedrockruntime = { version = "1.37", features = [
], optional = true }
secrecy = { version = "0.8.0", optional = true }
reqwest = { version = "0.12.5", optional = true, default-features = false }
ollama-rs = { version = "0.2.0", optional = true }
ollama-rs = { version = "0.2.1", optional = true }
deadpool = { version = "0.12", optional = true, features = [
"managed",
"rt_tokio_1",
Expand Down
33 changes: 33 additions & 0 deletions swiftide-integrations/src/ollama/embed.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
use anyhow::{Context as _, Result};
use async_trait::async_trait;

use ollama_rs::generation::embeddings::request::GenerateEmbeddingsRequest;
use swiftide_core::{EmbeddingModel, Embeddings};

use super::Ollama;

#[async_trait]
impl EmbeddingModel for Ollama {
async fn embed(&self, input: Vec<String>) -> Result<Embeddings> {
let model = self
.default_options
.embed_model
.as_ref()
.context("Model not set")?;

let request = GenerateEmbeddingsRequest::new(model.to_string(), input.into());
tracing::debug!(
messages = serde_json::to_string_pretty(&request)?,
"[Embed] Request to ollama"
);
let response = self
.client
.generate_embeddings(request)
.await
.context("Request to Ollama Failed")?;

tracing::debug!("[Embed] Response ollama");

Ok(response.embeddings)
}
}
114 changes: 104 additions & 10 deletions swiftide-integrations/src/ollama/mod.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
//! This module provides integration with `Ollama`'s API, enabling the use of language models within the Swiftide project.
//! It includes the `Ollama` struct for managing API clients and default options for prompt models.
//! This module provides integration with `Ollama`'s API, enabling the use of language models and embeddings within the Swiftide project.
//! It includes the `Ollama` struct for managing API clients and default options for embedding and prompt models.
//! The module is conditionally compiled based on the "ollama" feature flag.
use derive_builder::Builder;
use std::sync::Arc;

mod embed;
mod simple_prompt;

/// The `Ollama` struct encapsulates a `Ollama` client that implements [`swiftide::traits::SimplePrompt`]
/// The `Ollama` struct encapsulates an `Ollama` client and default options for embedding and prompt models.
/// It uses the `Builder` pattern for flexible and customizable instantiation.
///
/// There is also a builder available.
///
/// By default it will look for a `OLLAMA_API_KEY` environment variable. Note that a model
/// always needs to be set, either with [`Ollama::with_default_prompt_model`] or via the builder.
/// By default it will look for a `OLLAMA_API_KEY` environment variable. Note that either a prompt model or embedding model
/// always need to be set, either with [`Ollama::with_default_prompt_model`] or [`Ollama::with_default_embed_model`] or via the builder.
/// You can find available models in the Ollama documentation.
///
/// Under the hood it uses [`async_openai`], with the Ollama openai mapping. This means
Expand All @@ -23,7 +23,7 @@ pub struct Ollama {
/// The `Ollama` client, wrapped in an `Arc` for thread-safe reference counting.
#[builder(default = "default_client()", setter(custom))]
client: Arc<ollama_rs::Ollama>,
/// Default options for prompt models.
/// Default options for the embedding and prompt models.
#[builder(default)]
default_options: Options,
}
Expand All @@ -38,10 +38,14 @@ impl Default for Ollama {
}

/// The `Options` struct holds configuration options for the `Ollama` client.
/// It includes optional fields for specifying the prompt model.
/// It includes optional fields for specifying the embedding and prompt models.
#[derive(Debug, Default, Clone, Builder)]
#[builder(setter(into, strip_option))]
pub struct Options {
/// The default embedding model to use, if specified.
#[builder(default)]
pub embed_model: Option<String>,

/// The default prompt model to use, if specified.
#[builder(default)]
pub prompt_model: Option<String>,
Expand All @@ -64,6 +68,16 @@ impl Ollama {
pub fn with_default_prompt_model(&mut self, model: impl Into<String>) -> &mut Self {
self.default_options = Options {
prompt_model: Some(model.into()),
embed_model: self.default_options.embed_model.clone(),
};
self
}

/// Sets a default embedding model to use when embedding
pub fn with_default_embed_model(&mut self, model: impl Into<String>) -> &mut Self {
self.default_options = Options {
prompt_model: self.default_options.prompt_model.clone(),
embed_model: Some(model.into()),
};
self
}
Expand All @@ -82,6 +96,25 @@ impl OllamaBuilder {
self
}

/// Sets the default embedding model for the `Ollama` instance.
///
/// # Parameters
/// - `model`: The embedding model to set.
///
/// # Returns
/// A mutable reference to the `OllamaBuilder`.
pub fn default_embed_model(&mut self, model: impl Into<String>) -> &mut Self {
if let Some(options) = self.default_options.as_mut() {
options.embed_model = Some(model.into());
} else {
self.default_options = Some(Options {
embed_model: Some(model.into()),
..Default::default()
});
}
self
}

/// Sets the default prompt model for the `Ollama` instance.
///
/// # Parameters
Expand All @@ -95,6 +128,7 @@ impl OllamaBuilder {
} else {
self.default_options = Some(Options {
prompt_model: Some(model.into()),
..Default::default()
});
}
self
Expand Down Expand Up @@ -122,7 +156,36 @@ mod test {
}

#[test]
fn test_building_via_default() {
fn test_default_embed_model() {
let ollama = Ollama::builder()
.default_embed_model("mxbai-embed-large")
.build()
.unwrap();
assert_eq!(
ollama.default_options.embed_model,
Some("mxbai-embed-large".to_string())
);
}

#[test]
fn test_default_models() {
let ollama = Ollama::builder()
.default_embed_model("mxbai-embed-large")
.default_prompt_model("llama3.1")
.build()
.unwrap();
assert_eq!(
ollama.default_options.embed_model,
Some("mxbai-embed-large".to_string())
);
assert_eq!(
ollama.default_options.prompt_model,
Some("llama3.1".to_string())
);
}

#[test]
fn test_building_via_default_prompt_model() {
let mut client = Ollama::default();

assert!(client.default_options.prompt_model.is_none());
Expand All @@ -133,4 +196,35 @@ mod test {
Some("llama3.1".to_string())
);
}

#[test]
fn test_building_via_default_embed_model() {
let mut client = Ollama::default();

assert!(client.default_options.embed_model.is_none());

client.with_default_embed_model("mxbai-embed-large");
assert_eq!(
client.default_options.embed_model,
Some("mxbai-embed-large".to_string())
);
}

#[test]
fn test_building_via_default_models() {
let mut client = Ollama::default();

assert!(client.default_options.embed_model.is_none());

client.with_default_prompt_model("llama3.1");
client.with_default_embed_model("mxbai-embed-large");
assert_eq!(
client.default_options.prompt_model,
Some("llama3.1".to_string())
);
assert_eq!(
client.default_options.embed_model,
Some("mxbai-embed-large".to_string())
);
}
}

0 comments on commit a98dbcb

Please sign in to comment.