Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add utility methods to simplify InMemoryVectorStore creation #32

Merged
merged 2 commits into from
Sep 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 2 additions & 6 deletions rig-core/examples/vector_search.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::env;
use rig::{
embeddings::EmbeddingsBuilder,
providers::openai::Client,
vector_store::{in_memory_store::InMemoryVectorStore, VectorStore, VectorStoreIndex},
vector_store::{in_memory_store::InMemoryVectorIndex, VectorStoreIndex},
};

#[tokio::main]
Expand All @@ -14,18 +14,14 @@ async fn main() -> Result<(), anyhow::Error> {

let model = openai_client.embedding_model("text-embedding-ada-002");

let mut vector_store = InMemoryVectorStore::default();

let embeddings = EmbeddingsBuilder::new(model.clone())
.simple_document("doc0", "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets")
.simple_document("doc1", "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.")
.simple_document("doc2", "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.")
.build()
.await?;

vector_store.add_documents(embeddings).await?;

let index = vector_store.index(model);
let index = InMemoryVectorIndex::from_embeddings(model, embeddings).await?;

let results = index
.top_n_from_query("What is a linglingdong?", 1)
Expand Down
40 changes: 38 additions & 2 deletions rig-core/src/vector_store/in_memory_store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,41 @@ impl InMemoryVectorStore {
pub fn is_empty(&self) -> bool {
self.embeddings.is_empty()
}

/// Uitilty method to create an InMemoryVectorStore from a list of embeddings.
pub async fn from_embeddings(
embeddings: Vec<DocumentEmbeddings>,
) -> Result<Self, VectorStoreError> {
let mut store = Self::default();
store.add_documents(embeddings).await?;
Ok(store)
}

/// Create an InMemoryVectorStore from a list of documents.
/// The documents are serialized to JSON and embedded using the provided embedding model.
/// The resulting embeddings are stored in an InMemoryVectorStore created by the method.
pub async fn from_documents<M: EmbeddingModel, T: Serialize>(
embedding_model: M,
documents: &[(String, T)],
) -> Result<Self, VectorStoreError> {
let embeddings = documents
.iter()
.fold(
EmbeddingsBuilder::new(embedding_model),
|builder, (id, doc)| {
builder.json_document(
id,
serde_json::to_value(doc).expect("Document should be serializable"),
vec![serde_json::to_string(doc).expect("Document should be serializable")],
)
},
)
.build()
.await?;

let store = Self::from_embeddings(embeddings).await?;
Ok(store)
}
}

pub struct InMemoryVectorIndex<M: EmbeddingModel> {
Expand Down Expand Up @@ -151,12 +186,13 @@ impl<M: EmbeddingModel> InMemoryVectorIndex<M> {
Ok(store.index(query_model))
}

/// Utility method to create an InMemoryVectorIndex from a list of embeddings
/// and an embedding model.
pub async fn from_embeddings(
query_model: M,
embeddings: Vec<DocumentEmbeddings>,
) -> Result<Self, VectorStoreError> {
let mut store = InMemoryVectorStore::default();
store.add_documents(embeddings).await?;
let store = InMemoryVectorStore::from_embeddings(embeddings).await?;
Ok(store.index(query_model))
}
}
Expand Down