Skip to content

Commit

Permalink
refactor: remove associated type on VectorStoreIndex trait
Browse files Browse the repository at this point in the history
  • Loading branch information
marieaurore123 committed Sep 24, 2024
1 parent e63d5a1 commit 4d28f61
Show file tree
Hide file tree
Showing 17 changed files with 221 additions and 228 deletions.
2 changes: 1 addition & 1 deletion rig-core/examples/calculator_chatbot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ async fn main() -> Result<(), anyhow::Error> {
)
// Add a dynamic tool source with a sample rate of 4 (i.e., up to
// 4 additional tools will be added to prompts)
.dynamic_tools(4, index, toolset, "".to_string())
.dynamic_tools(4, index, toolset)
.build();

// Prompt the agent and print the response
Expand Down
2 changes: 1 addition & 1 deletion rig-core/examples/rag.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ async fn main() -> Result<(), anyhow::Error> {
You are a dictionary assistant here to assist the user in understanding the meaning of words.
You will find additional non-standard word definitions that could be useful below.
")
.dynamic_context(1, index, "".to_string())
.dynamic_context(1, index)
.build();

// Prompt the agent and print the response
Expand Down
2 changes: 1 addition & 1 deletion rig-core/examples/rag_dynamic_tools.rs
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ async fn main() -> Result<(), anyhow::Error> {
.preamble("You are a calculator here to help the user perform arithmetic operations.")
// Add a dynamic tool source with a sample rate of 1 (i.e.: only
// 1 additional tool will be added to prompts)
.dynamic_tools(1, index, toolset, "".to_string())
.dynamic_tools(1, index, toolset)
.build();

// Prompt the agent and print the response
Expand Down
2 changes: 1 addition & 1 deletion rig-core/examples/vector_search.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ async fn main() -> Result<(), anyhow::Error> {
let index = InMemoryVectorIndex::from_embeddings(model, embeddings).await?;

let results = index
.top_n_from_query("What is a linglingdong?", 1, ())
.top_n_from_query("What is a linglingdong?", 1)
.await?
.into_iter()
.map(|(score, doc)| (score, doc.id, doc.document))
Expand Down
2 changes: 1 addition & 1 deletion rig-core/examples/vector_search_cohere.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ async fn main() -> Result<(), anyhow::Error> {
let index = vector_store.index(search_model);

let results = index
.top_n_from_query("What is a linglingdong?", 1, ())
.top_n_from_query("What is a linglingdong?", 1)
.await?
.into_iter()
.map(|(score, doc)| (score, doc.id, doc.document))
Expand Down
23 changes: 10 additions & 13 deletions rig-core/src/agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -153,9 +153,9 @@ pub struct Agent<M: CompletionModel> {
/// Additional parameters to be passed to the model
additional_params: Option<serde_json::Value>,
/// List of vector store indexes, each paired with its sample number
dynamic_context: Vec<(usize, Box<dyn VectorStoreIndexDyn>, String)>,
dynamic_context: Vec<(usize, Box<dyn VectorStoreIndexDyn>)>,
/// Dynamic tools
dynamic_tools: Vec<(usize, Box<dyn VectorStoreIndexDyn>, String)>,
dynamic_tools: Vec<(usize, Box<dyn VectorStoreIndexDyn>)>,
/// Actual tool implementations
pub tools: ToolSet,
}
Expand All @@ -167,10 +167,10 @@ impl<M: CompletionModel> Completion<M> for Agent<M> {
chat_history: Vec<Message>,
) -> Result<CompletionRequestBuilder<M>, CompletionError> {
let dynamic_context = stream::iter(self.dynamic_context.iter())
.then(|(num_sample, index, search_params)| async {
.then(|(num_sample, index)| async {
Ok::<_, VectorStoreError>(
index
.top_n_from_query(prompt, *num_sample, search_params)
.top_n_from_query(prompt, *num_sample)
.await?
.into_iter()
.map(|(_, doc)| {
Expand All @@ -195,10 +195,10 @@ impl<M: CompletionModel> Completion<M> for Agent<M> {
.map_err(|e| CompletionError::RequestError(Box::new(e)))?;

let dynamic_tools = stream::iter(self.dynamic_tools.iter())
.then(|(num_sample, index, search_params)| async {
.then(|(num_sample, index)| async {
Ok::<_, VectorStoreError>(
index
.top_n_ids_from_query(prompt, *num_sample, search_params)
.top_n_ids_from_query(prompt, *num_sample)
.await?
.into_iter()
.map(|(_, doc)| doc)
Expand Down Expand Up @@ -296,9 +296,9 @@ pub struct AgentBuilder<M: CompletionModel> {
/// Additional parameters to be passed to the model
additional_params: Option<serde_json::Value>,
/// List of vector store indexes, each paired with its sample number
dynamic_context: Vec<(usize, Box<dyn VectorStoreIndexDyn>, String)>,
dynamic_context: Vec<(usize, Box<dyn VectorStoreIndexDyn>)>,
/// Dynamic tools
dynamic_tools: Vec<(usize, Box<dyn VectorStoreIndexDyn>, String)>,
dynamic_tools: Vec<(usize, Box<dyn VectorStoreIndexDyn>)>,
/// Temperature of the model
temperature: Option<f64>,
/// Actual tool implementations
Expand Down Expand Up @@ -360,10 +360,9 @@ impl<M: CompletionModel> AgentBuilder<M> {
mut self,
sample: usize,
dynamic_context: impl VectorStoreIndexDyn + 'static,
search_params: String,
) -> Self {
self.dynamic_context
.push((sample, Box::new(dynamic_context), search_params));
.push((sample, Box::new(dynamic_context)));
self
}

Expand All @@ -374,10 +373,8 @@ impl<M: CompletionModel> AgentBuilder<M> {
sample: usize,
dynamic_tools: impl VectorStoreIndexDyn + 'static,
toolset: ToolSet,
search_params: String,
) -> Self {
self.dynamic_tools
.push((sample, Box::new(dynamic_tools), search_params));
self.dynamic_tools.push((sample, Box::new(dynamic_tools)));
self.tools.add_tools(toolset);
self
}
Expand Down
7 changes: 1 addition & 6 deletions rig-core/src/vector_store/in_memory_store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -198,24 +198,19 @@ impl<M: EmbeddingModel> InMemoryVectorIndex<M> {
}

impl<M: EmbeddingModel + std::marker::Sync> VectorStoreIndex for InMemoryVectorIndex<M> {
type SearchParams = ();

async fn top_n_from_query(
&self,
query: &str,
n: usize,
search_params: Self::SearchParams,
) -> Result<Vec<(f64, DocumentEmbeddings)>, VectorStoreError> {
let prompt_embedding = self.model.embed_document(query).await?;
self.top_n_from_embedding(&prompt_embedding, n, search_params)
.await
self.top_n_from_embedding(&prompt_embedding, n).await
}

async fn top_n_from_embedding(
&self,
query_embedding: &Embedding,
n: usize,
_search_params: Self::SearchParams,
) -> Result<Vec<(f64, DocumentEmbeddings)>, VectorStoreError> {
// Sort documents by best embedding distance
let mut docs: EmbeddingRanking = BinaryHeap::new();
Expand Down
53 changes: 8 additions & 45 deletions rig-core/src/vector_store/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,6 @@ pub trait VectorStore: Send + Sync {

/// Trait for vector store indexes
pub trait VectorStoreIndex: Send + Sync {
type SearchParams: for<'a> Deserialize<'a> + Send + Sync;

/// Get the top n documents based on the distance to the given embedding.
/// The distance is calculated as the cosine distance between the prompt and
/// the document embedding.
Expand All @@ -60,7 +58,6 @@ pub trait VectorStoreIndex: Send + Sync {
&self,
query: &str,
n: usize,
search_params: Self::SearchParams,
) -> impl std::future::Future<Output = Result<Vec<(f64, DocumentEmbeddings)>, VectorStoreError>> + Send;

/// Same as `top_n_from_query` but returns the documents without their embeddings.
Expand All @@ -69,10 +66,9 @@ pub trait VectorStoreIndex: Send + Sync {
&self,
query: &str,
n: usize,
search_params: Self::SearchParams,
) -> impl std::future::Future<Output = Result<Vec<(f64, T)>, VectorStoreError>> + Send {
async move {
let documents = self.top_n_from_query(query, n, search_params).await?;
let documents = self.top_n_from_query(query, n).await?;
Ok(documents
.into_iter()
.map(|(distance, doc)| (distance, serde_json::from_value(doc.document).unwrap()))
Expand All @@ -85,11 +81,10 @@ pub trait VectorStoreIndex: Send + Sync {
&self,
query: &str,
n: usize,
search_params: Self::SearchParams,
) -> impl std::future::Future<Output = Result<Vec<(f64, String)>, VectorStoreError>> + Send
{
async move {
let documents = self.top_n_from_query(query, n, search_params).await?;
let documents = self.top_n_from_query(query, n).await?;
Ok(documents
.into_iter()
.map(|(distance, doc)| (distance, doc.id))
Expand All @@ -105,7 +100,6 @@ pub trait VectorStoreIndex: Send + Sync {
&self,
prompt_embedding: &Embedding,
n: usize,
search_params: Self::SearchParams,
) -> impl std::future::Future<Output = Result<Vec<(f64, DocumentEmbeddings)>, VectorStoreError>> + Send;

/// Same as `top_n_from_embedding` but returns the documents without their embeddings.
Expand All @@ -114,12 +108,9 @@ pub trait VectorStoreIndex: Send + Sync {
&self,
prompt_embedding: &Embedding,
n: usize,
search_params: Self::SearchParams,
) -> impl std::future::Future<Output = Result<Vec<(f64, T)>, VectorStoreError>> + Send {
async move {
let documents = self
.top_n_from_embedding(prompt_embedding, n, search_params)
.await?;
let documents = self.top_n_from_embedding(prompt_embedding, n).await?;
Ok(documents
.into_iter()
.map(|(distance, doc)| (distance, serde_json::from_value(doc.document).unwrap()))
Expand All @@ -132,13 +123,10 @@ pub trait VectorStoreIndex: Send + Sync {
&self,
prompt_embedding: &Embedding,
n: usize,
search_params: Self::SearchParams,
) -> impl std::future::Future<Output = Result<Vec<(f64, String)>, VectorStoreError>> + Send
{
async move {
let documents = self
.top_n_from_embedding(prompt_embedding, n, search_params)
.await?;
let documents = self.top_n_from_embedding(prompt_embedding, n).await?;
Ok(documents
.into_iter()
.map(|(distance, doc)| (distance, doc.id))
Expand All @@ -152,17 +140,15 @@ pub trait VectorStoreIndexDyn: Send + Sync {
&'a self,
query: &'a str,
n: usize,
search_params: &'a str,
) -> BoxFuture<'a, Result<Vec<(f64, DocumentEmbeddings)>, VectorStoreError>>;

fn top_n_ids_from_query<'a>(
&'a self,
query: &'a str,
n: usize,
search_params: &'a str,
) -> BoxFuture<'a, Result<Vec<(f64, String)>, VectorStoreError>> {
Box::pin(async move {
let documents = self.top_n_from_query(query, n, search_params).await?;
let documents = self.top_n_from_query(query, n).await?;
Ok(documents
.into_iter()
.map(|(distance, doc)| (distance, doc.id))
Expand All @@ -174,19 +160,15 @@ pub trait VectorStoreIndexDyn: Send + Sync {
&'a self,
prompt_embedding: &'a Embedding,
n: usize,
search_params: &'a str,
) -> BoxFuture<'a, Result<Vec<(f64, DocumentEmbeddings)>, VectorStoreError>>;

fn top_n_ids_from_embedding<'a>(
&'a self,
prompt_embedding: &'a Embedding,
n: usize,
search_params: &'a str,
) -> BoxFuture<'a, Result<Vec<(f64, String)>, VectorStoreError>> {
Box::pin(async move {
let documents = self
.top_n_from_embedding(prompt_embedding, n, search_params)
.await?;
let documents = self.top_n_from_embedding(prompt_embedding, n).await?;
Ok(documents
.into_iter()
.map(|(distance, doc)| (distance, doc.id))
Expand All @@ -200,44 +182,26 @@ impl<I: VectorStoreIndex> VectorStoreIndexDyn for I {
&'a self,
query: &'a str,
n: usize,
search_params: &'a str,
) -> BoxFuture<'a, Result<Vec<(f64, DocumentEmbeddings)>, VectorStoreError>> {
Box::pin(async move {
match serde_json::from_str(search_params) {
Ok(search_params) => self.top_n_from_query(query, n, search_params).await,
Err(e) => Err(VectorStoreError::JsonError(e)),
}
})
Box::pin(async move { self.top_n_from_query(query, n).await })
}

fn top_n_from_embedding<'a>(
&'a self,
prompt_embedding: &'a Embedding,
n: usize,
search_params: &'a str,
) -> BoxFuture<'a, Result<Vec<(f64, DocumentEmbeddings)>, VectorStoreError>> {
Box::pin(async move {
match serde_json::from_str(search_params) {
Ok(search_params) => {
self.top_n_from_embedding(prompt_embedding, n, search_params)
.await
}
Err(e) => Err(VectorStoreError::JsonError(e)),
}
})
Box::pin(async move { self.top_n_from_embedding(prompt_embedding, n).await })
}
}

pub struct NoIndex;

impl VectorStoreIndex for NoIndex {
type SearchParams = ();

async fn top_n_from_query(
&self,
_query: &str,
_n: usize,
_search_params: Self::SearchParams,
) -> Result<Vec<(f64, DocumentEmbeddings)>, VectorStoreError> {
Ok(vec![])
}
Expand All @@ -246,7 +210,6 @@ impl VectorStoreIndex for NoIndex {
&self,
_prompt_embedding: &Embedding,
_n: usize,
_search_params: Self::SearchParams,
) -> Result<Vec<(f64, DocumentEmbeddings)>, VectorStoreError> {
Ok(vec![])
}
Expand Down
19 changes: 6 additions & 13 deletions rig-lancedb/examples/vector_search_local_ann.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,11 @@ async fn main() -> Result<(), anyhow::Error> {
// Select the embedding model and generate our embeddings
let model = openai_client.embedding_model(&OpenAIEmbeddingModel::TextEmbeddingAda002);

let search_params = SearchParams::default().distance_type(DistanceType::Cosine);

// Initialize LanceDB locally.
let db = lancedb::connect("data/lancedb-store").execute().await?;
let mut vector_store = LanceDbVectorStore::new(&db, &model).await?;
let mut vector_store = LanceDbVectorStore::new(&db, &model, &search_params).await?;

// Generate test data for RAG demo
let agent = openai_client
Expand Down Expand Up @@ -52,24 +54,15 @@ async fn main() -> Result<(), anyhow::Error> {
vector_store
.create_index(lancedb::index::Index::IvfPq(
IvfPqIndexBuilder::default()
// This overrides the default distance type of L2
// This overrides the default distance type of L2.
// Needs to be the same distance type as the one used in search params.
.distance_type(DistanceType::Cosine),
))
.await?;

// Query the index
let results = vector_store
.top_n_from_query(
"My boss says I zindle too much, what does that mean?",
1,
&serde_json::to_string(&SearchParams::new(
Some(DistanceType::Cosine),
None,
None,
None,
None,
))?,
)
.top_n_from_query("My boss says I zindle too much, what does that mean?", 1)
.await?
.into_iter()
.map(|(score, doc)| (score, doc.id, doc.document))
Expand Down
8 changes: 2 additions & 6 deletions rig-lancedb/examples/vector_search_local_enn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ async fn main() -> Result<(), anyhow::Error> {

// Initialize LanceDB locally.
let db = lancedb::connect("data/lancedb-store").execute().await?;
let mut vector_store = LanceDbVectorStore::new(&db, &model).await?;
let mut vector_store = LanceDbVectorStore::new(&db, &model, &SearchParams::default()).await?;

let embeddings = EmbeddingsBuilder::new(model.clone())
.simple_document("doc0", "Definition of *flumbrel (noun)*: a small, seemingly insignificant item that you constantly lose or misplace, such as a pen, hair tie, or remote control.")
Expand All @@ -32,11 +32,7 @@ async fn main() -> Result<(), anyhow::Error> {

// Query the index
let results = vector_store
.top_n_from_query(
"My boss says I zindle too much, what does that mean?",
1,
&serde_json::to_string(&SearchParams::new(None, None, None, None, None))?,
)
.top_n_from_query("My boss says I zindle too much, what does that mean?", 1)
.await?
.into_iter()
.map(|(score, doc)| (score, doc.id, doc.document))
Expand Down
Loading

0 comments on commit 4d28f61

Please sign in to comment.