Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Improve InMemoryVectorStore API #130

Merged
merged 3 commits into from
Nov 29, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
style: clippy+fmt
  • Loading branch information
cvauclair committed Nov 29, 2024
commit 564b9a63b311e4642d3497c2bdd338d24b1877a8
3 changes: 2 additions & 1 deletion rig-core/examples/calculator_chatbot.rs
Original file line number Diff line number Diff line change
@@ -251,7 +251,8 @@ async fn main() -> Result<(), anyhow::Error> {
.build()
.await?;

let vector_store = InMemoryVectorStore::from_documents_with_id_f(embeddings, |tool| tool.name.clone());
let vector_store =
InMemoryVectorStore::from_documents_with_id_f(embeddings, |tool| tool.name.clone());
let index = vector_store.index(embedding_model);

// Create RAG agent with a single context prompt and a dynamic tool source
3 changes: 2 additions & 1 deletion rig-core/examples/rag_dynamic_tools.rs
Original file line number Diff line number Diff line change
@@ -156,7 +156,8 @@ async fn main() -> Result<(), anyhow::Error> {
.await?;

// Create vector store with the embeddings
let vector_store = InMemoryVectorStore::from_documents_with_id_f(embeddings, |tool| tool.name.clone());
let vector_store =
InMemoryVectorStore::from_documents_with_id_f(embeddings, |tool| tool.name.clone());

// Create vector store index
let index = vector_store.index(embedding_model);
6 changes: 2 additions & 4 deletions rig-core/examples/vector_search.rs
Original file line number Diff line number Diff line change
@@ -57,10 +57,8 @@ async fn main() -> Result<(), anyhow::Error> {
.await?;

// Create vector store with the embeddings
let vector_store = InMemoryVectorStore::from_documents_with_id_f(
embeddings,
|doc| doc.id.clone(),
);
let vector_store =
InMemoryVectorStore::from_documents_with_id_f(embeddings, |doc| doc.id.clone());

// Create vector store index
let index = vector_store.index(embedding_model);
6 changes: 2 additions & 4 deletions rig-core/examples/vector_search_cohere.rs
Original file line number Diff line number Diff line change
@@ -58,10 +58,8 @@ async fn main() -> Result<(), anyhow::Error> {
.await?;

// Create vector store with the embeddings
let vector_store = InMemoryVectorStore::from_documents_with_id_f(
embeddings,
|doc| doc.id.clone()
);
let vector_store =
InMemoryVectorStore::from_documents_with_id_f(embeddings, |doc| doc.id.clone());

// Create vector store index
let index = vector_store.index(search_model);
137 changes: 69 additions & 68 deletions rig-core/src/vector_store/in_memory_store.rs
Original file line number Diff line number Diff line change
@@ -25,11 +25,12 @@

impl<D: Serialize + Eq> InMemoryVectorStore<D> {
/// Create a new [InMemoryVectorStore] from documents and their corresponding embeddings.
/// Ids are automatically generated and will have the form `"doc{n}"` where `n`
/// Ids are automatically generated and will have the form `"doc{n}"` where `n`
/// is the index of the document.
pub fn from_documents(documents: impl IntoIterator<Item = (D, OneOrMany<Embedding>)>) -> Self {
let mut store = HashMap::new();
documents.into_iter()
documents
.into_iter()
.enumerate()
.for_each(|(i, (doc, embeddings))| {
store.insert(format!("doc{i}"), (doc, embeddings));
@@ -39,12 +40,13 @@
}

/// Create a new [InMemoryVectorStore] from documents and their corresponding embeddings with ids.
pub fn from_documents_with_ids(documents: impl IntoIterator<Item = (impl ToString, D, OneOrMany<Embedding>)>) -> Self {
pub fn from_documents_with_ids(
documents: impl IntoIterator<Item = (impl ToString, D, OneOrMany<Embedding>)>,
) -> Self {
let mut store = HashMap::new();
documents.into_iter()
.for_each(|(i, doc, embeddings)| {
store.insert(i.to_string(), (doc, embeddings));
});
documents.into_iter().for_each(|(i, doc, embeddings)| {
store.insert(i.to_string(), (doc, embeddings));
});

Self { embeddings: store }
}
@@ -56,10 +58,9 @@
f: fn(&D) -> String,
) -> Self {
let mut store = HashMap::new();
documents.into_iter()
.for_each(|(doc, embeddings)| {
store.insert(f(&doc), (doc, embeddings));
});
documents.into_iter().for_each(|(doc, embeddings)| {
store.insert(f(&doc), (doc, embeddings));
});

Self { embeddings: store }
}
@@ -109,24 +110,25 @@
pub fn add_documents(
&mut self,
documents: impl IntoIterator<Item = (D, OneOrMany<Embedding>)>,
) -> () {
) {
let current_index = self.embeddings.len();
documents.into_iter()
documents
.into_iter()
.enumerate()
.for_each(|(index, (doc, embeddings))| {
self.embeddings.insert(format!("doc{}", index + current_index), (doc, embeddings));
self.embeddings
.insert(format!("doc{}", index + current_index), (doc, embeddings));
});
}

/// Add documents and their corresponding embeddings to the store with ids.
pub fn add_documents_with_ids(
&mut self,
documents: impl IntoIterator<Item = (impl ToString, D, OneOrMany<Embedding>)>,
) -> () {
documents.into_iter()
.for_each(|(id, doc, embeddings)| {
self.embeddings.insert(id.to_string(), (doc, embeddings));
});
) {
documents.into_iter().for_each(|(id, doc, embeddings)| {
self.embeddings.insert(id.to_string(), (doc, embeddings));
});
}

/// Add documents and their corresponding embeddings to the store.
@@ -135,7 +137,7 @@
&mut self,
documents: Vec<(D, OneOrMany<Embedding>)>,
f: fn(&D) -> String,
) -> () {
) {
for (doc, embeddings) in documents {
let id = f(&doc);
self.embeddings.insert(id, (doc, embeddings));
@@ -261,7 +263,7 @@
mod tests {
use std::cmp::Reverse;

use crate::{embeddings::embedding::Embedding, vector_store, OneOrMany};

Check failure on line 266 in rig-core/src/vector_store/in_memory_store.rs

GitHub Actions / stable / test

unused import: `vector_store`

use super::{InMemoryVectorStore, RankingItem};

@@ -308,8 +310,7 @@
),
]);

let mut store = vector_store.embeddings.into_iter()
.collect::<Vec<_>>();
let mut store = vector_store.embeddings.into_iter().collect::<Vec<_>>();
store.sort_by_key(|(id, _)| id.clone());

assert_eq!(
@@ -428,52 +429,52 @@
#[test]
fn test_multiple_embeddings() {
let vector_store = InMemoryVectorStore::from_documents_with_ids(vec![
(
"doc1",
"glarb-garb",
OneOrMany::many(vec![
Embedding {
document: "glarb-garb".to_string(),
vec: vec![0.1, 0.1, 0.5],
},
Embedding {
document: "don't-choose-me".to_string(),
vec: vec![-0.5, 0.9, 0.1],
},
])
.unwrap(),
),
(
"doc2",
"marble-marble",
OneOrMany::many(vec![
Embedding {
document: "marble-marble".to_string(),
vec: vec![0.7, -0.3, 0.0],
},
Embedding {
document: "sandwich".to_string(),
vec: vec![0.5, 0.5, -0.7],
},
])
.unwrap(),
),
(
"doc3",
"flumb-flumb",
OneOrMany::many(vec![
Embedding {
document: "flumb-flumb".to_string(),
vec: vec![0.3, 0.7, 0.1],
},
Embedding {
document: "banana".to_string(),
vec: vec![0.1, -0.5, -0.5],
},
])
.unwrap(),
),
]);
(
"doc1",
"glarb-garb",
OneOrMany::many(vec![
Embedding {
document: "glarb-garb".to_string(),
vec: vec![0.1, 0.1, 0.5],
},
Embedding {
document: "don't-choose-me".to_string(),
vec: vec![-0.5, 0.9, 0.1],
},
])
.unwrap(),
),
(
"doc2",
"marble-marble",
OneOrMany::many(vec![
Embedding {
document: "marble-marble".to_string(),
vec: vec![0.7, -0.3, 0.0],
},
Embedding {
document: "sandwich".to_string(),
vec: vec![0.5, 0.5, -0.7],
},
])
.unwrap(),
),
(
"doc3",
"flumb-flumb",
OneOrMany::many(vec![
Embedding {
document: "flumb-flumb".to_string(),
vec: vec![0.3, 0.7, 0.1],
},
Embedding {
document: "banana".to_string(),
vec: vec![0.1, -0.5, -0.5],
},
])
.unwrap(),
),
]);

let ranking = vector_store.vector_search(
&Embedding {
Loading