Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: remove DocumentEmbeddings from in memory vector store #53

Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ name: Lint & Test
on:
pull_request:
branches:
- main
- "**"
push:
branches:
- main
Expand Down
21 changes: 16 additions & 5 deletions rig-core/examples/calculator_chatbot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@ use anyhow::Result;
use rig::{
cli_chatbot::cli_chatbot,
completion::ToolDefinition,
embeddings::EmbeddingsBuilder,
embeddings::{DocumentEmbeddings, EmbeddingsBuilder},
providers::openai::{Client, TEXT_EMBEDDING_ADA_002},
tool::{Tool, ToolEmbedding, ToolSet},
vector_store::{in_memory_store::InMemoryVectorStore, VectorStore},
vector_store::in_memory_store::InMemoryVectorStore,
};
use serde::{Deserialize, Serialize};
use serde_json::json;
Expand Down Expand Up @@ -251,9 +251,20 @@ async fn main() -> Result<(), anyhow::Error> {
.build()
.await?;

let mut store = InMemoryVectorStore::default();
store.add_documents(embeddings).await?;
let index = store.index(embedding_model);
let index = InMemoryVectorStore::default()
.add_documents(
embeddings
.into_iter()
.map(
|DocumentEmbeddings {
id,
document,
embeddings,
}| { (id, document, embeddings) },
)
.collect(),
)?
.index(embedding_model);

// Create RAG agent with a single context prompt and a dynamic tool source
let calculator_rag = openai_client
Expand Down
25 changes: 16 additions & 9 deletions rig-core/examples/rag.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@ use std::env;

use rig::{
completion::Prompt,
embeddings::EmbeddingsBuilder,
embeddings::{DocumentEmbeddings, EmbeddingsBuilder},
providers::openai::{Client, TEXT_EMBEDDING_ADA_002},
vector_store::{in_memory_store::InMemoryVectorStore, VectorStore},
vector_store::in_memory_store::InMemoryVectorStore,
};

#[tokio::main]
Expand All @@ -15,20 +15,27 @@ async fn main() -> Result<(), anyhow::Error> {

let embedding_model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002);

// Create vector store, compute embeddings and load them in the store
let mut vector_store = InMemoryVectorStore::default();

let embeddings = EmbeddingsBuilder::new(embedding_model.clone())
.simple_document("doc0", "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets")
.simple_document("doc1", "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.")
.simple_document("doc2", "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.")
.build()
.await?;

vector_store.add_documents(embeddings).await?;

// Create vector store index
let index = vector_store.index(embedding_model);
let index = InMemoryVectorStore::default()
.add_documents(
embeddings
.into_iter()
.map(
|DocumentEmbeddings {
id,
document,
embeddings,
}| { (id, document, embeddings) },
)
.collect(),
)?
.index(embedding_model);

let rag_agent = openai_client.agent("gpt-4")
.preamble("
Expand Down
25 changes: 16 additions & 9 deletions rig-core/examples/rag_dynamic_tools.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
use anyhow::Result;
use rig::{
completion::{Prompt, ToolDefinition},
embeddings::EmbeddingsBuilder,
embeddings::{DocumentEmbeddings, EmbeddingsBuilder},
providers::openai::{Client, TEXT_EMBEDDING_ADA_002},
tool::{Tool, ToolEmbedding, ToolSet},
vector_store::{in_memory_store::InMemoryVectorStore, VectorStore},
vector_store::in_memory_store::InMemoryVectorStore,
};
use serde::{Deserialize, Serialize};
use serde_json::json;
Expand Down Expand Up @@ -150,9 +150,6 @@ async fn main() -> Result<(), anyhow::Error> {

let embedding_model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002);

// Create vector store, compute tool embeddings and load them in the store
let mut vector_store = InMemoryVectorStore::default();

let toolset = ToolSet::builder()
.dynamic_tool(Add)
.dynamic_tool(Subtract)
Expand All @@ -163,10 +160,20 @@ async fn main() -> Result<(), anyhow::Error> {
.build()
.await?;

vector_store.add_documents(embeddings).await?;

// Create vector store index
let index = vector_store.index(embedding_model);
let index = InMemoryVectorStore::default()
.add_documents(
embeddings
.into_iter()
.map(
|DocumentEmbeddings {
id,
document,
embeddings,
}| { (id, document, embeddings) },
)
.collect(),
)?
.index(embedding_model);
cvauclair marked this conversation as resolved.
Show resolved Hide resolved

// Create RAG agent with a single context prompt and a dynamic tool source
let calculator_rag = openai_client
Expand Down
21 changes: 17 additions & 4 deletions rig-core/examples/vector_search.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::env;
use rig::{
embeddings::{DocumentEmbeddings, EmbeddingsBuilder},
providers::openai::{Client, TEXT_EMBEDDING_ADA_002},
vector_store::{in_memory_store::InMemoryVectorIndex, VectorStoreIndex},
vector_store::{in_memory_store::InMemoryVectorStore, VectorStoreIndex},
};

#[tokio::main]
Expand All @@ -21,13 +21,26 @@ async fn main() -> Result<(), anyhow::Error> {
.build()
.await?;

let index = InMemoryVectorIndex::from_embeddings(model, embeddings).await?;
let index = InMemoryVectorStore::default()
.add_documents(
embeddings
.into_iter()
.map(
|DocumentEmbeddings {
id,
document,
embeddings,
}| { (id, document, embeddings) },
)
.collect(),
)?
.index(model);

let results = index
.top_n::<DocumentEmbeddings>("What is a linglingdong?", 1)
.top_n::<String>("What is a linglingdong?", 1)
.await?
.into_iter()
.map(|(score, id, doc)| (score, id, doc.document))
.map(|(score, id, doc)| (score, id, doc))
.collect::<Vec<_>>();

println!("Results: {:?}", results);
Expand Down
25 changes: 17 additions & 8 deletions rig-core/examples/vector_search_cohere.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::env;
use rig::{
embeddings::{DocumentEmbeddings, EmbeddingsBuilder},
providers::cohere::{Client, EMBED_ENGLISH_V3},
vector_store::{in_memory_store::InMemoryVectorStore, VectorStore, VectorStoreIndex},
vector_store::{in_memory_store::InMemoryVectorStore, VectorStoreIndex},
};

#[tokio::main]
Expand All @@ -15,24 +15,33 @@ async fn main() -> Result<(), anyhow::Error> {
let document_model = cohere_client.embedding_model(EMBED_ENGLISH_V3, "search_document");
let search_model = cohere_client.embedding_model(EMBED_ENGLISH_V3, "search_query");

let mut vector_store = InMemoryVectorStore::default();

let embeddings = EmbeddingsBuilder::new(document_model)
.simple_document("doc0", "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets")
.simple_document("doc1", "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.")
.simple_document("doc2", "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.")
.build()
.await?;

vector_store.add_documents(embeddings).await?;

let index = vector_store.index(search_model);
let index = InMemoryVectorStore::default()
.add_documents(
embeddings
.into_iter()
.map(
|DocumentEmbeddings {
id,
document,
embeddings,
}| { (id, document, embeddings) },
)
.collect(),
)?
.index(search_model);

let results = index
.top_n::<DocumentEmbeddings>("What is a linglingdong?", 1)
.top_n::<String>("What is a linglingdong?", 1)
.await?
.into_iter()
.map(|(score, id, doc)| (score, id, doc.document))
.map(|(score, id, doc)| (score, id, doc))
.collect::<Vec<_>>();

println!("Results: {:?}", results);
Expand Down
Loading
Loading