Commit ef00b38

Merge pull request #53 from 0xPlaygrounds/refactor(vector-store)/in-memeory-vector-store

refactor: remove DocumentEmbeddings from in memory vector store

marieaurore123 authored Oct 15, 2024
2 parents 1353be3 + 8c993dd, commit ef00b38

Showing 11 changed files with 291 additions and 292 deletions.
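For orientation, here is a minimal before/after sketch of the usage change that the example diffs below apply. It is wrapped in a hypothetical build_index helper; the helper name, the &Client argument, and the single placeholder document are illustrative and not part of the commit.

use rig::{
    embeddings::{DocumentEmbeddings, EmbeddingsBuilder},
    providers::openai::{Client, TEXT_EMBEDDING_ADA_002},
    vector_store::in_memory_store::InMemoryVectorStore,
};

// Sketch only: mirrors the transformation made in each example in this commit.
async fn build_index(openai_client: &Client) -> Result<(), anyhow::Error> {
    let embedding_model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002);

    let embeddings = EmbeddingsBuilder::new(embedding_model.clone())
        .simple_document("doc0", "An illustrative document")
        .build()
        .await?;

    // Before this commit: a mutable store filled via the VectorStore trait.
    //     let mut store = InMemoryVectorStore::default();
    //     store.add_documents(embeddings).await?;
    //     let index = store.index(embedding_model);

    // After this commit: add_documents takes (id, document, embeddings) tuples,
    // so DocumentEmbeddings is destructured at the call site and the store is
    // built and indexed in one expression.
    let _index = InMemoryVectorStore::default()
        .add_documents(
            embeddings
                .into_iter()
                .map(|DocumentEmbeddings { id, document, embeddings }| {
                    (id, document, embeddings)
                })
                .collect(),
        )?
        .index(embedding_model);

    Ok(())
}

The file diffs below apply exactly this transformation in each example.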
.github/workflows/ci.yaml (2 changes: 1 addition & 1 deletion)

@@ -4,7 +4,7 @@ name: Lint & Test
 on:
   pull_request:
     branches:
-      - main
+      - "**"
   push:
     branches:
       - main

Cargo.lock (8 changes: 4 additions & 4 deletions)

Generated file; diff not rendered by default.

rig-core/examples/calculator_chatbot.rs (21 changes: 16 additions & 5 deletions)

@@ -2,10 +2,10 @@ use anyhow::Result;
 use rig::{
     cli_chatbot::cli_chatbot,
     completion::ToolDefinition,
-    embeddings::EmbeddingsBuilder,
+    embeddings::{DocumentEmbeddings, EmbeddingsBuilder},
     providers::openai::{Client, TEXT_EMBEDDING_ADA_002},
     tool::{Tool, ToolEmbedding, ToolSet},
-    vector_store::{in_memory_store::InMemoryVectorStore, VectorStore},
+    vector_store::in_memory_store::InMemoryVectorStore,
 };
 use serde::{Deserialize, Serialize};
 use serde_json::json;
@@ -251,9 +251,20 @@ async fn main() -> Result<(), anyhow::Error> {
         .build()
         .await?;
 
-    let mut store = InMemoryVectorStore::default();
-    store.add_documents(embeddings).await?;
-    let index = store.index(embedding_model);
+    let index = InMemoryVectorStore::default()
+        .add_documents(
+            embeddings
+                .into_iter()
+                .map(
+                    |DocumentEmbeddings {
+                         id,
+                         document,
+                         embeddings,
+                     }| { (id, document, embeddings) },
+                )
+                .collect(),
+        )?
+        .index(embedding_model);
 
     // Create RAG agent with a single context prompt and a dynamic tool source
     let calculator_rag = openai_client

rig-core/examples/rag.rs (25 changes: 16 additions & 9 deletions)

@@ -2,9 +2,9 @@ use std::env;
 
 use rig::{
     completion::Prompt,
-    embeddings::EmbeddingsBuilder,
+    embeddings::{DocumentEmbeddings, EmbeddingsBuilder},
     providers::openai::{Client, TEXT_EMBEDDING_ADA_002},
-    vector_store::{in_memory_store::InMemoryVectorStore, VectorStore},
+    vector_store::in_memory_store::InMemoryVectorStore,
 };
 
 #[tokio::main]
@@ -15,20 +15,27 @@ async fn main() -> Result<(), anyhow::Error> {
 
     let embedding_model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002);
 
-    // Create vector store, compute embeddings and load them in the store
-    let mut vector_store = InMemoryVectorStore::default();
-
     let embeddings = EmbeddingsBuilder::new(embedding_model.clone())
         .simple_document("doc0", "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets")
         .simple_document("doc1", "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.")
         .simple_document("doc2", "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.")
         .build()
         .await?;
 
-    vector_store.add_documents(embeddings).await?;
-
-    // Create vector store index
-    let index = vector_store.index(embedding_model);
+    let index = InMemoryVectorStore::default()
+        .add_documents(
+            embeddings
+                .into_iter()
+                .map(
+                    |DocumentEmbeddings {
+                         id,
+                         document,
+                         embeddings,
+                     }| { (id, document, embeddings) },
+                )
+                .collect(),
+        )?
+        .index(embedding_model);
 
     let rag_agent = openai_client.agent("gpt-4")
         .preamble("

rig-core/examples/rag_dynamic_tools.rs (25 changes: 16 additions & 9 deletions)

@@ -1,10 +1,10 @@
 use anyhow::Result;
 use rig::{
     completion::{Prompt, ToolDefinition},
-    embeddings::EmbeddingsBuilder,
+    embeddings::{DocumentEmbeddings, EmbeddingsBuilder},
     providers::openai::{Client, TEXT_EMBEDDING_ADA_002},
     tool::{Tool, ToolEmbedding, ToolSet},
-    vector_store::{in_memory_store::InMemoryVectorStore, VectorStore},
+    vector_store::in_memory_store::InMemoryVectorStore,
 };
 use serde::{Deserialize, Serialize};
 use serde_json::json;
@@ -150,9 +150,6 @@ async fn main() -> Result<(), anyhow::Error> {
 
     let embedding_model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002);
 
-    // Create vector store, compute tool embeddings and load them in the store
-    let mut vector_store = InMemoryVectorStore::default();
-
     let toolset = ToolSet::builder()
         .dynamic_tool(Add)
         .dynamic_tool(Subtract)
@@ -163,10 +160,20 @@ async fn main() -> Result<(), anyhow::Error> {
         .build()
         .await?;
 
-    vector_store.add_documents(embeddings).await?;
-
-    // Create vector store index
-    let index = vector_store.index(embedding_model);
+    let index = InMemoryVectorStore::default()
+        .add_documents(
+            embeddings
+                .into_iter()
+                .map(
+                    |DocumentEmbeddings {
+                         id,
+                         document,
+                         embeddings,
+                     }| { (id, document, embeddings) },
+                )
+                .collect(),
+        )?
+        .index(embedding_model);
 
     // Create RAG agent with a single context prompt and a dynamic tool source
     let calculator_rag = openai_client

rig-core/examples/vector_search.rs (21 changes: 17 additions & 4 deletions)

@@ -3,7 +3,7 @@ use std::env;
 use rig::{
     embeddings::{DocumentEmbeddings, EmbeddingsBuilder},
     providers::openai::{Client, TEXT_EMBEDDING_ADA_002},
-    vector_store::{in_memory_store::InMemoryVectorIndex, VectorStoreIndex},
+    vector_store::{in_memory_store::InMemoryVectorStore, VectorStoreIndex},
 };
 
 #[tokio::main]
@@ -21,13 +21,26 @@ async fn main() -> Result<(), anyhow::Error> {
         .build()
         .await?;
 
-    let index = InMemoryVectorIndex::from_embeddings(model, embeddings).await?;
+    let index = InMemoryVectorStore::default()
+        .add_documents(
+            embeddings
+                .into_iter()
+                .map(
+                    |DocumentEmbeddings {
+                         id,
+                         document,
+                         embeddings,
+                     }| { (id, document, embeddings) },
+                )
+                .collect(),
+        )?
+        .index(model);
 
     let results = index
-        .top_n::<DocumentEmbeddings>("What is a linglingdong?", 1)
+        .top_n::<String>("What is a linglingdong?", 1)
         .await?
         .into_iter()
-        .map(|(score, id, doc)| (score, id, doc.document))
+        .map(|(score, id, doc)| (score, id, doc))
         .collect::<Vec<_>>();
 
     println!("Results: {:?}", results);

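A small side effect visible in the hunk above: once top_n is asked for String, the remaining .map(|(score, id, doc)| (score, id, doc)) call is an identity mapping. A hedged sketch of the shorter equivalent, assuming (as the hunk suggests) that top_n already returns the collected results:

// Each entry is a (score, id, document) tuple; no per-item mapping is needed.
let results = index
    .top_n::<String>("What is a linglingdong?", 1)
    .await?;
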
rig-core/examples/vector_search_cohere.rs (25 changes: 17 additions & 8 deletions)

@@ -3,7 +3,7 @@ use std::env;
 use rig::{
     embeddings::{DocumentEmbeddings, EmbeddingsBuilder},
     providers::cohere::{Client, EMBED_ENGLISH_V3},
-    vector_store::{in_memory_store::InMemoryVectorStore, VectorStore, VectorStoreIndex},
+    vector_store::{in_memory_store::InMemoryVectorStore, VectorStoreIndex},
 };
 
 #[tokio::main]
@@ -15,24 +15,33 @@ async fn main() -> Result<(), anyhow::Error> {
     let document_model = cohere_client.embedding_model(EMBED_ENGLISH_V3, "search_document");
     let search_model = cohere_client.embedding_model(EMBED_ENGLISH_V3, "search_query");
 
-    let mut vector_store = InMemoryVectorStore::default();
-
     let embeddings = EmbeddingsBuilder::new(document_model)
         .simple_document("doc0", "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets")
         .simple_document("doc1", "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.")
         .simple_document("doc2", "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.")
         .build()
         .await?;
 
-    vector_store.add_documents(embeddings).await?;
-
-    let index = vector_store.index(search_model);
+    let index = InMemoryVectorStore::default()
+        .add_documents(
+            embeddings
+                .into_iter()
+                .map(
+                    |DocumentEmbeddings {
+                         id,
+                         document,
+                         embeddings,
+                     }| { (id, document, embeddings) },
+                )
+                .collect(),
+        )?
+        .index(search_model);
 
     let results = index
-        .top_n::<DocumentEmbeddings>("What is a linglingdong?", 1)
+        .top_n::<String>("What is a linglingdong?", 1)
         .await?
         .into_iter()
-        .map(|(score, id, doc)| (score, id, doc.document))
+        .map(|(score, id, doc)| (score, id, doc))
         .collect::<Vec<_>>();
 
     println!("Results: {:?}", results);

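As a usage note, after this refactor the third element of each result tuple is the stored document text itself, so a caller can work with matches directly. A small sketch; the query string and output format are illustrative only:

for (score, id, doc) in index
    .top_n::<String>("What is a glarb-glarb?", 2)
    .await?
{
    // doc is the plain document string held by the in-memory vector store.
    println!("{} (score: {}): {}", id, score, doc);
}
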
(Diffs for the remaining changed files did not load.)
