diff --git a/rig-core/examples/calculator_chatbot.rs b/rig-core/examples/calculator_chatbot.rs index bba21914..c096857d 100644 --- a/rig-core/examples/calculator_chatbot.rs +++ b/rig-core/examples/calculator_chatbot.rs @@ -272,7 +272,7 @@ async fn main() -> Result<(), anyhow::Error> { ) // Add a dynamic tool source with a sample rate of 1 (i.e.: only // 1 additional tool will be added to prompts) - .dynamic_tools(4, index, toolset, "".to_string()) + .dynamic_tools(4, index, toolset) .build(); // Prompt the agent and print the response diff --git a/rig-core/examples/rag.rs b/rig-core/examples/rag.rs index 3a390b2a..fae0b91d 100644 --- a/rig-core/examples/rag.rs +++ b/rig-core/examples/rag.rs @@ -35,7 +35,7 @@ async fn main() -> Result<(), anyhow::Error> { You are a dictionary assistant here to assist the user in understanding the meaning of words. You will find additional non-standard word definitions that could be useful below. ") - .dynamic_context(1, index, "".to_string()) + .dynamic_context(1, index) .build(); // Prompt the agent and print the response diff --git a/rig-core/examples/rag_dynamic_tools.rs b/rig-core/examples/rag_dynamic_tools.rs index 777c75ce..cb7a955d 100644 --- a/rig-core/examples/rag_dynamic_tools.rs +++ b/rig-core/examples/rag_dynamic_tools.rs @@ -174,7 +174,7 @@ async fn main() -> Result<(), anyhow::Error> { .preamble("You are a calculator here to help the user perform arithmetic operations.") // Add a dynamic tool source with a sample rate of 1 (i.e.: only // 1 additional tool will be added to prompts) - .dynamic_tools(1, index, toolset, "".to_string()) + .dynamic_tools(1, index, toolset) .build(); // Prompt the agent and print the response diff --git a/rig-core/examples/vector_search.rs b/rig-core/examples/vector_search.rs index 4e2e5870..04ba2ab3 100644 --- a/rig-core/examples/vector_search.rs +++ b/rig-core/examples/vector_search.rs @@ -24,7 +24,7 @@ async fn main() -> Result<(), anyhow::Error> { let index = InMemoryVectorIndex::from_embeddings(model, embeddings).await?; let results = index - .top_n_from_query("What is a linglingdong?", 1, ()) + .top_n_from_query("What is a linglingdong?", 1) .await? .into_iter() .map(|(score, doc)| (score, doc.id, doc.document)) diff --git a/rig-core/examples/vector_search_cohere.rs b/rig-core/examples/vector_search_cohere.rs index 579a298d..a8c2d163 100644 --- a/rig-core/examples/vector_search_cohere.rs +++ b/rig-core/examples/vector_search_cohere.rs @@ -31,7 +31,7 @@ async fn main() -> Result<(), anyhow::Error> { let index = vector_store.index(search_model); let results = index - .top_n_from_query("What is a linglingdong?", 1, ()) + .top_n_from_query("What is a linglingdong?", 1) .await? .into_iter() .map(|(score, doc)| (score, doc.id, doc.document)) diff --git a/rig-core/src/agent.rs b/rig-core/src/agent.rs index 0647815a..8648a9eb 100644 --- a/rig-core/src/agent.rs +++ b/rig-core/src/agent.rs @@ -153,9 +153,9 @@ pub struct Agent { /// Additional parameters to be passed to the model additional_params: Option, /// List of vector store, with the sample number - dynamic_context: Vec<(usize, Box, String)>, + dynamic_context: Vec<(usize, Box)>, /// Dynamic tools - dynamic_tools: Vec<(usize, Box, String)>, + dynamic_tools: Vec<(usize, Box)>, /// Actual tool implementations pub tools: ToolSet, } @@ -167,10 +167,10 @@ impl Completion for Agent { chat_history: Vec, ) -> Result, CompletionError> { let dynamic_context = stream::iter(self.dynamic_context.iter()) - .then(|(num_sample, index, search_params)| async { + .then(|(num_sample, index)| async { Ok::<_, VectorStoreError>( index - .top_n_from_query(prompt, *num_sample, search_params) + .top_n_from_query(prompt, *num_sample) .await? .into_iter() .map(|(_, doc)| { @@ -195,10 +195,10 @@ impl Completion for Agent { .map_err(|e| CompletionError::RequestError(Box::new(e)))?; let dynamic_tools = stream::iter(self.dynamic_tools.iter()) - .then(|(num_sample, index, search_params)| async { + .then(|(num_sample, index)| async { Ok::<_, VectorStoreError>( index - .top_n_ids_from_query(prompt, *num_sample, search_params) + .top_n_ids_from_query(prompt, *num_sample) .await? .into_iter() .map(|(_, doc)| doc) @@ -296,9 +296,9 @@ pub struct AgentBuilder { /// Additional parameters to be passed to the model additional_params: Option, /// List of vector store, with the sample number - dynamic_context: Vec<(usize, Box, String)>, + dynamic_context: Vec<(usize, Box)>, /// Dynamic tools - dynamic_tools: Vec<(usize, Box, String)>, + dynamic_tools: Vec<(usize, Box)>, /// Temperature of the model temperature: Option, /// Actual tool implementations @@ -360,10 +360,9 @@ impl AgentBuilder { mut self, sample: usize, dynamic_context: impl VectorStoreIndexDyn + 'static, - search_params: String, ) -> Self { self.dynamic_context - .push((sample, Box::new(dynamic_context), search_params)); + .push((sample, Box::new(dynamic_context))); self } @@ -374,10 +373,8 @@ impl AgentBuilder { sample: usize, dynamic_tools: impl VectorStoreIndexDyn + 'static, toolset: ToolSet, - search_params: String, ) -> Self { - self.dynamic_tools - .push((sample, Box::new(dynamic_tools), search_params)); + self.dynamic_tools.push((sample, Box::new(dynamic_tools))); self.tools.add_tools(toolset); self } diff --git a/rig-core/src/vector_store/in_memory_store.rs b/rig-core/src/vector_store/in_memory_store.rs index c2239a23..02c19cf8 100644 --- a/rig-core/src/vector_store/in_memory_store.rs +++ b/rig-core/src/vector_store/in_memory_store.rs @@ -198,24 +198,19 @@ impl InMemoryVectorIndex { } impl VectorStoreIndex for InMemoryVectorIndex { - type SearchParams = (); - async fn top_n_from_query( &self, query: &str, n: usize, - search_params: Self::SearchParams, ) -> Result, VectorStoreError> { let prompt_embedding = self.model.embed_document(query).await?; - self.top_n_from_embedding(&prompt_embedding, n, search_params) - .await + self.top_n_from_embedding(&prompt_embedding, n).await } async fn top_n_from_embedding( &self, query_embedding: &Embedding, n: usize, - _search_params: Self::SearchParams, ) -> Result, VectorStoreError> { // Sort documents by best embedding distance let mut docs: EmbeddingRanking = BinaryHeap::new(); diff --git a/rig-core/src/vector_store/mod.rs b/rig-core/src/vector_store/mod.rs index 0e2d9983..c042f0c3 100644 --- a/rig-core/src/vector_store/mod.rs +++ b/rig-core/src/vector_store/mod.rs @@ -50,8 +50,6 @@ pub trait VectorStore: Send + Sync { /// Trait for vector store indexes pub trait VectorStoreIndex: Send + Sync { - type SearchParams: for<'a> Deserialize<'a> + Send + Sync; - /// Get the top n documents based on the distance to the given embedding. /// The distance is calculated as the cosine distance between the prompt and /// the document embedding. @@ -60,7 +58,6 @@ pub trait VectorStoreIndex: Send + Sync { &self, query: &str, n: usize, - search_params: Self::SearchParams, ) -> impl std::future::Future, VectorStoreError>> + Send; /// Same as `top_n_from_query` but returns the documents without its embeddings. @@ -69,10 +66,9 @@ pub trait VectorStoreIndex: Send + Sync { &self, query: &str, n: usize, - search_params: Self::SearchParams, ) -> impl std::future::Future, VectorStoreError>> + Send { async move { - let documents = self.top_n_from_query(query, n, search_params).await?; + let documents = self.top_n_from_query(query, n).await?; Ok(documents .into_iter() .map(|(distance, doc)| (distance, serde_json::from_value(doc.document).unwrap())) @@ -85,11 +81,10 @@ pub trait VectorStoreIndex: Send + Sync { &self, query: &str, n: usize, - search_params: Self::SearchParams, ) -> impl std::future::Future, VectorStoreError>> + Send { async move { - let documents = self.top_n_from_query(query, n, search_params).await?; + let documents = self.top_n_from_query(query, n).await?; Ok(documents .into_iter() .map(|(distance, doc)| (distance, doc.id)) @@ -105,7 +100,6 @@ pub trait VectorStoreIndex: Send + Sync { &self, prompt_embedding: &Embedding, n: usize, - search_params: Self::SearchParams, ) -> impl std::future::Future, VectorStoreError>> + Send; /// Same as `top_n_from_embedding` but returns the documents without its embeddings. @@ -114,12 +108,9 @@ pub trait VectorStoreIndex: Send + Sync { &self, prompt_embedding: &Embedding, n: usize, - search_params: Self::SearchParams, ) -> impl std::future::Future, VectorStoreError>> + Send { async move { - let documents = self - .top_n_from_embedding(prompt_embedding, n, search_params) - .await?; + let documents = self.top_n_from_embedding(prompt_embedding, n).await?; Ok(documents .into_iter() .map(|(distance, doc)| (distance, serde_json::from_value(doc.document).unwrap())) @@ -132,13 +123,10 @@ pub trait VectorStoreIndex: Send + Sync { &self, prompt_embedding: &Embedding, n: usize, - search_params: Self::SearchParams, ) -> impl std::future::Future, VectorStoreError>> + Send { async move { - let documents = self - .top_n_from_embedding(prompt_embedding, n, search_params) - .await?; + let documents = self.top_n_from_embedding(prompt_embedding, n).await?; Ok(documents .into_iter() .map(|(distance, doc)| (distance, doc.id)) @@ -152,17 +140,15 @@ pub trait VectorStoreIndexDyn: Send + Sync { &'a self, query: &'a str, n: usize, - search_params: &'a str, ) -> BoxFuture<'a, Result, VectorStoreError>>; fn top_n_ids_from_query<'a>( &'a self, query: &'a str, n: usize, - search_params: &'a str, ) -> BoxFuture<'a, Result, VectorStoreError>> { Box::pin(async move { - let documents = self.top_n_from_query(query, n, search_params).await?; + let documents = self.top_n_from_query(query, n).await?; Ok(documents .into_iter() .map(|(distance, doc)| (distance, doc.id)) @@ -174,19 +160,15 @@ pub trait VectorStoreIndexDyn: Send + Sync { &'a self, prompt_embedding: &'a Embedding, n: usize, - search_params: &'a str, ) -> BoxFuture<'a, Result, VectorStoreError>>; fn top_n_ids_from_embedding<'a>( &'a self, prompt_embedding: &'a Embedding, n: usize, - search_params: &'a str, ) -> BoxFuture<'a, Result, VectorStoreError>> { Box::pin(async move { - let documents = self - .top_n_from_embedding(prompt_embedding, n, search_params) - .await?; + let documents = self.top_n_from_embedding(prompt_embedding, n).await?; Ok(documents .into_iter() .map(|(distance, doc)| (distance, doc.id)) @@ -200,44 +182,26 @@ impl VectorStoreIndexDyn for I { &'a self, query: &'a str, n: usize, - search_params: &'a str, ) -> BoxFuture<'a, Result, VectorStoreError>> { - Box::pin(async move { - match serde_json::from_str(search_params) { - Ok(search_params) => self.top_n_from_query(query, n, search_params).await, - Err(e) => Err(VectorStoreError::JsonError(e)), - } - }) + Box::pin(async move { self.top_n_from_query(query, n).await }) } fn top_n_from_embedding<'a>( &'a self, prompt_embedding: &'a Embedding, n: usize, - search_params: &'a str, ) -> BoxFuture<'a, Result, VectorStoreError>> { - Box::pin(async move { - match serde_json::from_str(search_params) { - Ok(search_params) => { - self.top_n_from_embedding(prompt_embedding, n, search_params) - .await - } - Err(e) => Err(VectorStoreError::JsonError(e)), - } - }) + Box::pin(async move { self.top_n_from_embedding(prompt_embedding, n).await }) } } pub struct NoIndex; impl VectorStoreIndex for NoIndex { - type SearchParams = (); - async fn top_n_from_query( &self, _query: &str, _n: usize, - _search_params: Self::SearchParams, ) -> Result, VectorStoreError> { Ok(vec![]) } @@ -246,7 +210,6 @@ impl VectorStoreIndex for NoIndex { &self, _prompt_embedding: &Embedding, _n: usize, - _search_params: Self::SearchParams, ) -> Result, VectorStoreError> { Ok(vec![]) } diff --git a/rig-lancedb/examples/vector_search_local_ann.rs b/rig-lancedb/examples/vector_search_local_ann.rs index 3317a8df..4a6a518a 100644 --- a/rig-lancedb/examples/vector_search_local_ann.rs +++ b/rig-lancedb/examples/vector_search_local_ann.rs @@ -18,9 +18,11 @@ async fn main() -> Result<(), anyhow::Error> { // Select the embedding model and generate our embeddings let model = openai_client.embedding_model(&OpenAIEmbeddingModel::TextEmbeddingAda002); + let search_params = SearchParams::default().distance_type(DistanceType::Cosine); + // Initialize LanceDB locally. let db = lancedb::connect("data/lancedb-store").execute().await?; - let mut vector_store = LanceDbVectorStore::new(&db, &model).await?; + let mut vector_store = LanceDbVectorStore::new(&db, &model, &search_params).await?; // Generate test data for RAG demo let agent = openai_client @@ -52,24 +54,15 @@ async fn main() -> Result<(), anyhow::Error> { vector_store .create_index(lancedb::index::Index::IvfPq( IvfPqIndexBuilder::default() - // This overrides the default distance type of L2 + // This overrides the default distance type of L2. + // Needs to be the same distance type as the one used in search params. .distance_type(DistanceType::Cosine), )) .await?; // Query the index let results = vector_store - .top_n_from_query( - "My boss says I zindle too much, what does that mean?", - 1, - &serde_json::to_string(&SearchParams::new( - Some(DistanceType::Cosine), - None, - None, - None, - None, - ))?, - ) + .top_n_from_query("My boss says I zindle too much, what does that mean?", 1) .await? .into_iter() .map(|(score, doc)| (score, doc.id, doc.document)) diff --git a/rig-lancedb/examples/vector_search_local_enn.rs b/rig-lancedb/examples/vector_search_local_enn.rs index 125ea73c..94d6aaf8 100644 --- a/rig-lancedb/examples/vector_search_local_enn.rs +++ b/rig-lancedb/examples/vector_search_local_enn.rs @@ -18,7 +18,7 @@ async fn main() -> Result<(), anyhow::Error> { // Initialize LanceDB locally. let db = lancedb::connect("data/lancedb-store").execute().await?; - let mut vector_store = LanceDbVectorStore::new(&db, &model).await?; + let mut vector_store = LanceDbVectorStore::new(&db, &model, &SearchParams::default()).await?; let embeddings = EmbeddingsBuilder::new(model.clone()) .simple_document("doc0", "Definition of *flumbrel (noun)*: a small, seemingly insignificant item that you constantly lose or misplace, such as a pen, hair tie, or remote control.") @@ -32,11 +32,7 @@ async fn main() -> Result<(), anyhow::Error> { // Query the index let results = vector_store - .top_n_from_query( - "My boss says I zindle too much, what does that mean?", - 1, - &serde_json::to_string(&SearchParams::new(None, None, None, None, None))?, - ) + .top_n_from_query("My boss says I zindle too much, what does that mean?", 1) .await? .into_iter() .map(|(score, doc)| (score, doc.id, doc.document)) diff --git a/rig-lancedb/examples/vector_search_s3_ann.rs b/rig-lancedb/examples/vector_search_s3_ann.rs index 6ff25d94..1ec4f94a 100644 --- a/rig-lancedb/examples/vector_search_s3_ann.rs +++ b/rig-lancedb/examples/vector_search_s3_ann.rs @@ -21,13 +21,15 @@ async fn main() -> Result<(), anyhow::Error> { // Select the embedding model and generate our embeddings let model = openai_client.embedding_model(&OpenAIEmbeddingModel::TextEmbeddingAda002); + let search_params = SearchParams::default().distance_type(DistanceType::Cosine); + // Initialize LanceDB on S3. // Note: see below docs for more options and IAM permission required to read/write to S3. // https://lancedb.github.io/lancedb/guides/storage/#aws-s3 let db = lancedb::connect("s3://lancedb-test-829666124233") .execute() .await?; - let mut vector_store = LanceDbVectorStore::new(&db, &model).await?; + let mut vector_store = LanceDbVectorStore::new(&db, &model, &search_params).await?; // Generate test data for RAG demo let agent = openai_client @@ -59,25 +61,15 @@ async fn main() -> Result<(), anyhow::Error> { vector_store .create_index(lancedb::index::Index::IvfPq( IvfPqIndexBuilder::default() - // This overrides the default distance type of L2 + // This overrides the default distance type of L2. + // Needs to be the same distance type as the one used in search params. .distance_type(DistanceType::Cosine), )) .await?; // Query the index let results = vector_store - .top_n_from_query( - "My boss says I zindle too much, what does that mean?", - 1, - &serde_json::to_string(&SearchParams::new( - // Important: use the same same distance type that was used to train the index. - Some(DistanceType::Cosine), - None, - None, - None, - None, - ))?, - ) + .top_n_from_query("My boss says I zindle too much, what does that mean?", 1) .await? .into_iter() .map(|(score, doc)| (score, doc.id, doc.document)) diff --git a/rig-lancedb/src/lib.rs b/rig-lancedb/src/lib.rs index f81a643c..a8463d2b 100644 --- a/rig-lancedb/src/lib.rs +++ b/rig-lancedb/src/lib.rs @@ -17,16 +17,90 @@ use utils::{Insert, Query}; mod table_schemas; mod utils; +fn lancedb_to_rig_error(e: lancedb::Error) -> VectorStoreError { + VectorStoreError::DatastoreError(Box::new(e)) +} + +fn serde_to_rig_error(e: serde_json::Error) -> VectorStoreError { + VectorStoreError::JsonError(e) +} + pub struct LanceDbVectorStore { + /// Defines which model is used to generate embeddings for the vector store model: M, + /// Table containing documents only document_table: lancedb::Table, + /// Table containing embeddings only. + /// Foreign key references the document in document table. embedding_table: lancedb::Table, + /// Vector search params that are used during vector search operations. + search_params: SearchParams, +} + +/// See [LanceDB vector search](https://lancedb.github.io/lancedb/search/) for more information. +#[derive(Deserialize, Serialize, Debug, Clone)] +pub enum SearchType { + // Flat search, also called ENN or kNN. + Flat, + /// Approximal Nearest Neighbor search, also called ANN. + Approximate, +} + +#[derive(Deserialize, Serialize, Debug, Clone, Default)] +pub struct SearchParams { + /// Always set the distance_type to match the value used to train the index + /// By default, set to L2 + distance_type: Option, + /// By default, ANN will be used if there is an index on the table. + /// By default, kNN will be used if there is NO index on the table. + /// To use defaults, set to None. + search_type: Option, + /// Set this value only when search type is ANN. + /// See [LanceDb ANN Search](https://lancedb.github.io/lancedb/ann_indexes/#querying-an-ann-index) for more information + nprobes: Option, + /// Set this value only when search type is ANN. + /// See [LanceDb ANN Search](https://lancedb.github.io/lancedb/ann_indexes/#querying-an-ann-index) for more information + refine_factor: Option, + /// If set to true, filtering will happen after the vector search instead of before + /// See [LanceDb pre/post filtering](https://lancedb.github.io/lancedb/sql/#pre-and-post-filtering) for more information + post_filter: Option, +} + +impl SearchParams { + pub fn distance_type(mut self, distance_type: DistanceType) -> Self { + self.distance_type = Some(distance_type); + self + } + + pub fn search_type(mut self, search_type: SearchType) -> Self { + self.search_type = Some(search_type); + self + } + + pub fn nprobes(mut self, nprobes: usize) -> Self { + self.nprobes = Some(nprobes); + self + } + + pub fn refine_factor(mut self, refine_factor: u32) -> Self { + self.refine_factor = Some(refine_factor); + self + } + + pub fn post_filter(mut self, post_filter: bool) -> Self { + self.post_filter = Some(post_filter); + self + } } impl LanceDbVectorStore { /// Note: Tables are created inside the new function rather than created outside and passed as reference to new function. /// This is because a specific schema needs to be enforced on the tables and this is done at creation time. - pub async fn new(db: &lancedb::Connection, model: &M) -> Result { + pub async fn new( + db: &lancedb::Connection, + model: &M, + search_params: &SearchParams, + ) -> Result { let document_table = db .create_empty_table("documents", Arc::new(Self::document_schema())) .execute() @@ -44,9 +118,11 @@ impl LanceDbVectorStore { document_table, embedding_table, model: model.clone(), + search_params: search_params.clone(), }) } + /// Schema of records in document table. fn document_schema() -> Schema { Schema::new(Fields::from(vec![ Field::new("id", DataType::Utf8, false), @@ -54,6 +130,8 @@ impl LanceDbVectorStore { ])) } + /// Schema of records in embeddings table. + /// Every embedding vector in the table must have the same size. fn embedding_schema(dimension: i32) -> Schema { Schema::new(Fields::from(vec![ Field::new("id", DataType::Utf8, false), @@ -70,6 +148,7 @@ impl LanceDbVectorStore { ])) } + /// Define index on document table `id` field for search optimization. pub async fn create_document_index(&self, index: Index) -> Result<(), lancedb::Error> { self.document_table .create_index(&["id"], index) @@ -77,6 +156,7 @@ impl LanceDbVectorStore { .await } + /// Define index on embedding table `id` and `document_id` fields for search optimization. pub async fn create_embedding_index(&self, index: Index) -> Result<(), lancedb::Error> { self.embedding_table .create_index(&["id", "document_id"], index) @@ -84,6 +164,7 @@ impl LanceDbVectorStore { .await } + /// Define index on embedding table `embedding` fields for vector search optimization. pub async fn create_index(&self, index: Index) -> Result<(), lancedb::Error> { self.embedding_table .create_index(&["embedding"], index) @@ -94,14 +175,6 @@ impl LanceDbVectorStore { } } -fn lancedb_to_rig_error(e: lancedb::Error) -> VectorStoreError { - VectorStoreError::DatastoreError(Box::new(e)) -} - -fn serde_to_rig_error(e: serde_json::Error) -> VectorStoreError { - VectorStoreError::JsonError(e) -} - impl VectorStore for LanceDbVectorStore { type Q = lancedb::query::Query; @@ -137,14 +210,14 @@ impl VectorStore for LanceDbVector let documents: DocumentRecords = self .document_table .query() - .only_if(format!("id = {id}")) + .only_if(format!("id = '{id}'")) .execute_query() .await?; let embeddings: EmbeddingRecordsBatch = self .embedding_table .query() - .only_if(format!("document_id = {id}")) + .only_if(format!("document_id = '{id}'")) .execute_query() .await?; @@ -158,7 +231,7 @@ impl VectorStore for LanceDbVector let documents: DocumentRecords = self .document_table .query() - .only_if(format!("id = {id}")) + .only_if(format!("id = '{id}'")) .execute_query() .await?; @@ -180,7 +253,14 @@ impl VectorStore for LanceDbVector let embeddings: EmbeddingRecordsBatch = self .embedding_table .query() - .only_if(format!("document_id IN [{}]", documents.ids().join(","))) + .only_if(format!( + "document_id IN ({})", + documents + .ids() + .map(|id| format!("'{id}'")) + .collect::>() + .join(",") + )) .execute_query() .await?; @@ -188,84 +268,34 @@ impl VectorStore for LanceDbVector } } -/// See [LanceDB vector search](https://lancedb.github.io/lancedb/search/) for more information. -#[derive(Deserialize, Serialize, Debug, Clone)] -pub enum SearchType { - // Flat search, also called ENN or kNN. - Flat, - /// Approximal Nearest Neighbor search, also called ANN. - Approximate, -} - -#[derive(Deserialize, Serialize, Debug, Clone)] -pub struct SearchParams { - /// Always set the distance_type to match the value used to train the index - /// By default, set to L2 - distance_type: Option, - /// By default, ANN will be used if there is an index on the table. - /// By default, kNN will be used if there is NO index on the table. - /// To use defaults, set to None. - search_type: Option, - /// Set this value only when search type is ANN. - /// See [LanceDb ANN Search](https://lancedb.github.io/lancedb/ann_indexes/#querying-an-ann-index) for more information - nprobes: Option, - /// Set this value only when search type is ANN. - /// See [LanceDb ANN Search](https://lancedb.github.io/lancedb/ann_indexes/#querying-an-ann-index) for more information - refine_factor: Option, - /// If set to true, filtering will happen after the vector search instead of before - /// See [LanceDb pre/post filtering](https://lancedb.github.io/lancedb/sql/#pre-and-post-filtering) for more information - post_filter: Option, -} - -impl SearchParams { - pub fn new( - distance_type: Option, - search_type: Option, - nprobes: Option, - refine_factor: Option, - post_filter: Option, - ) -> Self { - Self { - distance_type, - search_type, - nprobes, - refine_factor, - post_filter, - } - } -} - impl VectorStoreIndex for LanceDbVectorStore { async fn top_n_from_query( &self, query: &str, n: usize, - search_params: Self::SearchParams, ) -> Result, VectorStoreError> { let prompt_embedding = self.model.embed_document(query).await?; - self.top_n_from_embedding(&prompt_embedding, n, search_params) - .await + self.top_n_from_embedding(&prompt_embedding, n).await } async fn top_n_from_embedding( &self, prompt_embedding: &rig::embeddings::Embedding, n: usize, - search_params: Self::SearchParams, ) -> Result, VectorStoreError> { + let query = self + .embedding_table + .vector_search(prompt_embedding.vec.clone()) + .map_err(lancedb_to_rig_error)? + .limit(n); + let SearchParams { distance_type, search_type, nprobes, refine_factor, post_filter, - } = search_params.clone(); - - let query = self - .embedding_table - .vector_search(prompt_embedding.vec.clone()) - .map_err(lancedb_to_rig_error)? - .limit(n); + } = self.search_params.clone(); if let Some(distance_type) = distance_type { query.clone().distance_type(distance_type); @@ -317,6 +347,4 @@ impl VectorStoreIndex for LanceDbV }) .collect()) } - - type SearchParams = SearchParams; } diff --git a/rig-lancedb/src/table_schemas/document.rs b/rig-lancedb/src/table_schemas/document.rs index 56b19bef..384eb4bf 100644 --- a/rig-lancedb/src/table_schemas/document.rs +++ b/rig-lancedb/src/table_schemas/document.rs @@ -1,10 +1,10 @@ use std::sync::Arc; -use arrow_array::{ArrayRef, RecordBatch, StringArray}; +use arrow_array::{types::Utf8Type, ArrayRef, RecordBatch, StringArray}; use lancedb::arrow::arrow_schema::ArrowError; use rig::{embeddings::DocumentEmbeddings, vector_store::VectorStoreError}; -use crate::utils::DeserializeArrow; +use crate::utils::DeserializeByteArray; /// Schema of `documents` table in LanceDB defined as a struct. #[derive(Clone, Debug)] @@ -30,12 +30,12 @@ impl DocumentRecords { self.0.extend(records); } - fn documents(&self) -> Vec { - self.as_iter().map(|doc| doc.document.clone()).collect() + fn documents(&self) -> impl Iterator + '_ { + self.as_iter().map(|doc| doc.document.clone()) } - pub fn ids(&self) -> Vec { - self.as_iter().map(|doc| doc.id.clone()).collect() + pub fn ids(&self) -> impl Iterator + '_ { + self.as_iter().map(|doc| doc.id.clone()) } pub fn as_iter(&self) -> impl Iterator { @@ -97,8 +97,11 @@ impl TryFrom for DocumentRecords { type Error = ArrowError; fn try_from(record_batch: RecordBatch) -> Result { - let ids = record_batch.to_str(0)?; - let documents = record_batch.to_str(1)?; + let binding_0 = record_batch.column(0); + let ids = binding_0.to_str::()?; + + let binding_1 = record_batch.column(1); + let documents = binding_1.to_str::()?; Ok(DocumentRecords( ids.into_iter() diff --git a/rig-lancedb/src/table_schemas/embedding.rs b/rig-lancedb/src/table_schemas/embedding.rs index c73d4e53..7f74dd12 100644 --- a/rig-lancedb/src/table_schemas/embedding.rs +++ b/rig-lancedb/src/table_schemas/embedding.rs @@ -2,13 +2,13 @@ use std::{collections::HashMap, sync::Arc}; use arrow_array::{ builder::{FixedSizeListBuilder, Float64Builder}, - types::{Float32Type, Float64Type}, + types::{Float32Type, Float64Type, Utf8Type}, ArrayRef, RecordBatch, StringArray, }; use lancedb::arrow::arrow_schema::ArrowError; use rig::{embeddings::DocumentEmbeddings, vector_store::VectorStoreError}; -use crate::utils::{DeserializeArrow, DeserializePrimitiveArray}; +use crate::utils::{DeserializeByteArray, DeserializeListArray, DeserializePrimitiveArray}; /// Data format in the LanceDB table `embeddings` #[derive(Clone, Debug, PartialEq)] @@ -158,10 +158,16 @@ impl TryFrom for EmbeddingRecords { type Error = ArrowError; fn try_from(record_batch: RecordBatch) -> Result { - let ids = record_batch.to_str(0)?; - let document_ids = record_batch.to_str(1)?; - let contents = record_batch.to_str(2)?; - let embeddings = record_batch.to_float_list::(3)?; + let binding_0 = record_batch.column(0); + let ids = binding_0.to_str::()?; + + let binding_1 = record_batch.column(1); + let document_ids = binding_1.to_str::()?; + + let binding_2 = record_batch.column(2); + let contents = binding_2.to_str::()?; + + let embeddings = record_batch.column(3).to_float_list::()?; // There is a `_distance` field in the response if the executed query was a VectorQuery // Otherwise, for normal queries, the `_distance` field is not present in the response. diff --git a/rig-lancedb/src/utils/mod.rs b/rig-lancedb/src/utils/mod.rs index a8ef758d..bf8874e2 100644 --- a/rig-lancedb/src/utils/mod.rs +++ b/rig-lancedb/src/utils/mod.rs @@ -1,8 +1,8 @@ use std::sync::Arc; use arrow_array::{ - Array, ArrowPrimitiveType, FixedSizeListArray, PrimitiveArray, RecordBatch, - RecordBatchIterator, StringArray, + types::ByteArrayType, Array, ArrowPrimitiveType, FixedSizeListArray, GenericByteArray, + PrimitiveArray, RecordBatch, RecordBatchIterator, }; use futures::TryStreamExt; use lancedb::{ @@ -13,6 +13,7 @@ use rig::vector_store::VectorStoreError; use crate::lancedb_to_rig_error; +/// Trait used to "deserialize" an arrow_array::Array as as list of primitive objects. pub trait DeserializePrimitiveArray { fn to_float( &self, @@ -32,44 +33,39 @@ impl DeserializePrimitiveArray for &Arc { } } -/// Trait used to "deserialize" a column of a RecordBatch object into a list o primitive types -pub trait DeserializeArrow { - /// Define the column number that contains strings, i. - /// For each item in the column, convert it to a string and collect the result in a vector of strings. - fn to_str(&self, i: usize) -> Result, ArrowError>; - /// Define the column number that contains the list of floats, i. - /// For each item in the column, convert it to a list and for each item in the list, convert it to a float. - /// Collect the result as a vector of vectors of floats. - fn to_float_list( - &self, - i: usize, - ) -> Result::Native>>, ArrowError>; +/// Trait used to "deserialize" an arrow_array::Array as as list of byte objects. +pub trait DeserializeByteArray { + fn to_str(&self) -> Result::Native>, ArrowError>; } -impl DeserializeArrow for RecordBatch { - fn to_str(&self, i: usize) -> Result, ArrowError> { - let column = self.column(i); - match column.as_any().downcast_ref::() { - Some(str_array) => Ok((0..str_array.len()) - .map(|j| str_array.value(j)) - .collect::>()), +impl DeserializeByteArray for &Arc { + fn to_str(&self) -> Result::Native>, ArrowError> { + match self.as_any().downcast_ref::>() { + Some(array) => Ok((0..array.len()).map(|j| array.value(j)).collect::>()), None => Err(ArrowError::CastError(format!( - "Can't cast column {i} to string array" + "Can't cast array: {self:?} to float array" ))), } } +} + +/// Trait used to "deserialize" an arrow_array::Array as as list of lists of primitive objects. +pub trait DeserializeListArray { + fn to_float_list( + &self, + ) -> Result::Native>>, ArrowError>; +} +impl DeserializeListArray for &Arc { fn to_float_list( &self, - i: usize, ) -> Result::Native>>, ArrowError> { - let column = self.column(i); - match column.as_any().downcast_ref::() { + match self.as_any().downcast_ref::() { Some(list_array) => (0..list_array.len()) .map(|j| (&list_array.value(j)).to_float::()) .collect::, _>>(), None => Err(ArrowError::CastError(format!( - "Can't cast column {i} to fixed size list array" + "Can't cast column {self:?} to fixed size list array" ))), } } diff --git a/rig-mongodb/examples/vector_search_mongodb.rs b/rig-mongodb/examples/vector_search_mongodb.rs index d51d6d48..894499e7 100644 --- a/rig-mongodb/examples/vector_search_mongodb.rs +++ b/rig-mongodb/examples/vector_search_mongodb.rs @@ -49,11 +49,11 @@ async fn main() -> Result<(), anyhow::Error> { // Create a vector index on our vector store // IMPORTANT: Reuse the same model that was used to generate the embeddings - let index = vector_store.index(model, "context_vector_index"); + let index = vector_store.index(model, "context_vector_index", SearchParams::new()); // Query the index let results = index - .top_n_from_query("What is a linglingdong?", 1, SearchParams::new()) + .top_n_from_query("What is a linglingdong?", 1) .await? .into_iter() .map(|(score, doc)| (score, doc.id, doc.document)) diff --git a/rig-mongodb/src/lib.rs b/rig-mongodb/src/lib.rs index 67a04664..41bf8500 100644 --- a/rig-mongodb/src/lib.rs +++ b/rig-mongodb/src/lib.rs @@ -87,8 +87,13 @@ impl MongoDbVectorStore { /// /// The index (of type "vector") must already exist for the MongoDB collection. /// See the MongoDB [documentation](https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-type/) for more information on creating indexes. - pub fn index(&self, model: M, index_name: &str) -> MongoDbVectorIndex { - MongoDbVectorIndex::new(self.collection.clone(), model, index_name) + pub fn index( + &self, + model: M, + index_name: &str, + search_params: SearchParams, + ) -> MongoDbVectorIndex { + MongoDbVectorIndex::new(self.collection.clone(), model, index_name, search_params) } } @@ -97,6 +102,7 @@ pub struct MongoDbVectorIndex { collection: mongodb::Collection, model: M, index_name: String, + search_params: SearchParams, } impl MongoDbVectorIndex { @@ -104,11 +110,13 @@ impl MongoDbVectorIndex { collection: mongodb::Collection, model: M, index_name: &str, + search_params: SearchParams, ) -> Self { Self { collection, model, index_name: index_name.to_string(), + search_params, } } } @@ -134,6 +142,21 @@ impl SearchParams { num_candidates: None, } } + + pub fn filter(mut self, filter: mongodb::bson::Document) -> Self { + self.filter = filter; + self + } + + pub fn exact(mut self, exact: bool) -> Self { + self.exact = Some(exact); + self + } + + pub fn num_candidates(mut self, num_candidates: u32) -> Self { + self.num_candidates = Some(num_candidates); + self + } } impl Default for SearchParams { @@ -147,19 +170,22 @@ impl VectorStoreIndex for MongoDbV &self, query: &str, n: usize, - search_params: Self::SearchParams, ) -> Result, VectorStoreError> { let prompt_embedding = self.model.embed_document(query).await?; - self.top_n_from_embedding(&prompt_embedding, n, search_params) - .await + self.top_n_from_embedding(&prompt_embedding, n).await } async fn top_n_from_embedding( &self, prompt_embedding: &Embedding, n: usize, - search_params: Self::SearchParams, ) -> Result, VectorStoreError> { + let SearchParams { + filter, + exact, + num_candidates, + } = &self.search_params; + let mut cursor = self .collection .aggregate( @@ -168,11 +194,11 @@ impl VectorStoreIndex for MongoDbV "$vectorSearch": { "queryVector": &prompt_embedding.vec, "index": &self.index_name, - "exact": search_params.exact.unwrap_or(false), + "exact": exact.unwrap_or(false), "path": "embeddings.vec", - "numCandidates": search_params.num_candidates.unwrap_or((n * 10) as u32), + "numCandidates": num_candidates.unwrap_or((n * 10) as u32), "limit": n as u32, - "filter": &search_params.filter, + "filter": filter, } }, doc! { @@ -206,6 +232,4 @@ impl VectorStoreIndex for MongoDbV Ok(results) } - - type SearchParams = SearchParams; }