Skip to content

Commit

Permalink
refactor: remove DocumentEmbeddings from in memory vector store
Browse files Browse the repository at this point in the history
  • Loading branch information
marieaurore123 committed Oct 9, 2024
1 parent f5441db commit 5407772
Show file tree
Hide file tree
Showing 9 changed files with 134 additions and 270 deletions.
20 changes: 15 additions & 5 deletions rig-core/examples/calculator_chatbot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@ use anyhow::Result;
use rig::{
cli_chatbot::cli_chatbot,
completion::ToolDefinition,
embeddings::EmbeddingsBuilder,
embeddings::{DocumentEmbeddings, EmbeddingsBuilder},
providers::openai::{Client, TEXT_EMBEDDING_ADA_002},
tool::{Tool, ToolEmbedding, ToolSet},
vector_store::{in_memory_store::InMemoryVectorStore, VectorStore},
vector_store::in_memory_store::InMemoryVectorStore,
};
use serde::{Deserialize, Serialize};
use serde_json::json;
Expand Down Expand Up @@ -251,9 +251,19 @@ async fn main() -> Result<(), anyhow::Error> {
.build()
.await?;

let mut store = InMemoryVectorStore::default();
store.add_documents(embeddings).await?;
let index = store.index(embedding_model);
let vector_store = InMemoryVectorStore::default().add_documents(
embeddings
.into_iter()
.map(
|DocumentEmbeddings {
id,
document,
embeddings,
}| { (id, document, embeddings) },
)
.collect(),
)?;
let index = vector_store.index(embedding_model);

// Create RAG agent with a single context prompt and a dynamic tool source
let calculator_rag = openai_client
Expand Down
17 changes: 14 additions & 3 deletions rig-core/examples/rag.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@ use std::env;

use rig::{
completion::Prompt,
embeddings::EmbeddingsBuilder,
embeddings::{DocumentEmbeddings, EmbeddingsBuilder},
providers::openai::{Client, TEXT_EMBEDDING_ADA_002},
vector_store::{in_memory_store::InMemoryVectorStore, VectorStore},
vector_store::in_memory_store::InMemoryVectorStore,
};

#[tokio::main]
Expand All @@ -25,7 +25,18 @@ async fn main() -> Result<(), anyhow::Error> {
.build()
.await?;

vector_store.add_documents(embeddings).await?;
let vector_store = vector_store.add_documents(
embeddings
.into_iter()
.map(
|DocumentEmbeddings {
id,
document,
embeddings,
}| { (id, document, embeddings) },
)
.collect(),
)?;

// Create vector store index
let index = vector_store.index(embedding_model);
Expand Down
20 changes: 14 additions & 6 deletions rig-core/examples/rag_dynamic_tools.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
use anyhow::Result;
use rig::{
completion::{Prompt, ToolDefinition},
embeddings::EmbeddingsBuilder,
embeddings::{DocumentEmbeddings, EmbeddingsBuilder},
providers::openai::{Client, TEXT_EMBEDDING_ADA_002},
tool::{Tool, ToolEmbedding, ToolSet},
vector_store::{in_memory_store::InMemoryVectorStore, VectorStore},
vector_store::in_memory_store::InMemoryVectorStore,
};
use serde::{Deserialize, Serialize};
use serde_json::json;
Expand Down Expand Up @@ -150,9 +150,6 @@ async fn main() -> Result<(), anyhow::Error> {

let embedding_model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002);

// Create vector store, compute tool embeddings and load them in the store
let mut vector_store = InMemoryVectorStore::default();

let toolset = ToolSet::builder()
.dynamic_tool(Add)
.dynamic_tool(Subtract)
Expand All @@ -163,7 +160,18 @@ async fn main() -> Result<(), anyhow::Error> {
.build()
.await?;

vector_store.add_documents(embeddings).await?;
let vector_store = InMemoryVectorStore::default().add_documents(
embeddings
.into_iter()
.map(
|DocumentEmbeddings {
id,
document,
embeddings,
}| { (id, document, embeddings) },
)
.collect(),
)?;

// Create vector store index
let index = vector_store.index(embedding_model);
Expand Down
20 changes: 18 additions & 2 deletions rig-core/examples/vector_search.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@ use std::env;
use rig::{
embeddings::{DocumentEmbeddings, EmbeddingsBuilder},
providers::openai::{Client, TEXT_EMBEDDING_ADA_002},
vector_store::{in_memory_store::InMemoryVectorIndex, VectorStoreIndex},
vector_store::{
in_memory_store::{InMemoryVectorIndex, InMemoryVectorStore},

Check failure on line 7 in rig-core/examples/vector_search.rs

View workflow job for this annotation

GitHub Actions / stable / test

unused import: `InMemoryVectorIndex`
VectorStoreIndex,
},
};

#[tokio::main]
Expand All @@ -21,7 +24,20 @@ async fn main() -> Result<(), anyhow::Error> {
.build()
.await?;

let index = InMemoryVectorIndex::from_embeddings(model, embeddings).await?;
let vector_store = InMemoryVectorStore::default().add_documents(
embeddings
.into_iter()
.map(
|DocumentEmbeddings {
id,
document,
embeddings,
}| { (id, document, embeddings) },
)
.collect(),
)?;

let index = vector_store.index(model);

let results = index
.top_n::<DocumentEmbeddings>("What is a linglingdong?", 1)
Expand Down
17 changes: 13 additions & 4 deletions rig-core/examples/vector_search_cohere.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::env;
use rig::{
embeddings::{DocumentEmbeddings, EmbeddingsBuilder},
providers::cohere::{Client, EMBED_ENGLISH_V3},
vector_store::{in_memory_store::InMemoryVectorStore, VectorStore, VectorStoreIndex},
vector_store::{in_memory_store::InMemoryVectorStore, VectorStoreIndex},
};

#[tokio::main]
Expand All @@ -15,16 +15,25 @@ async fn main() -> Result<(), anyhow::Error> {
let document_model = cohere_client.embedding_model(EMBED_ENGLISH_V3, "search_document");
let search_model = cohere_client.embedding_model(EMBED_ENGLISH_V3, "search_query");

let mut vector_store = InMemoryVectorStore::default();

let embeddings = EmbeddingsBuilder::new(document_model)
.simple_document("doc0", "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets")
.simple_document("doc1", "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.")
.simple_document("doc2", "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.")
.build()
.await?;

vector_store.add_documents(embeddings).await?;
let vector_store = InMemoryVectorStore::default().add_documents(
embeddings
.into_iter()
.map(
|DocumentEmbeddings {
id,
document,
embeddings,
}| { (id, document, embeddings) },
)
.collect(),
)?;

let index = vector_store.index(search_model);

Expand Down
Loading

0 comments on commit 5407772

Please sign in to comment.