diff --git a/rig-lancedb/Cargo.toml b/rig-lancedb/Cargo.toml
index e09d2d5b..bf97fbe6 100644
--- a/rig-lancedb/Cargo.toml
+++ b/rig-lancedb/Cargo.toml
@@ -30,3 +30,7 @@ required-features = ["rig-core/derive"]
 [[example]]
 name = "vector_search_s3_ann"
 required-features = ["rig-core/derive"]
+
+[[test]]
+name = "integration_tests"
+required-features = ["rig-core/derive"]
\ No newline at end of file
diff --git a/rig-lancedb/examples/fixtures/lib.rs b/rig-lancedb/examples/fixtures/lib.rs
index 94704822..f2de914c 100644
--- a/rig-lancedb/examples/fixtures/lib.rs
+++ b/rig-lancedb/examples/fixtures/lib.rs
@@ -7,23 +7,23 @@ use rig::{Embed, OneOrMany};
 use serde::Deserialize;
 
 #[derive(Embed, Clone, Deserialize, Debug)]
-pub struct WordDefinition {
+pub struct Word {
     pub id: String,
     #[embed]
     pub definition: String,
 }
 
-pub fn word_definitions() -> Vec<WordDefinition> {
+pub fn words() -> Vec<Word> {
     vec![
-        WordDefinition {
+        Word {
             id: "doc0".to_string(),
             definition: "Definition of *flumbrel (noun)*: a small, seemingly insignificant item that you constantly lose or misplace, such as a pen, hair tie, or remote control.".to_string()
         },
-        WordDefinition {
+        Word {
             id: "doc1".to_string(),
             definition: "Definition of *zindle (verb)*: to pretend to be working on something important while actually doing something completely unrelated or unproductive.".to_string()
         },
-        WordDefinition {
+        Word {
             id: "doc2".to_string(),
             definition: "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.".to_string()
         }
@@ -46,22 +46,22 @@ pub fn schema(dims: usize) -> Schema {
     ]))
 }
 
-// Convert WordDefinition objects and their embedding to a RecordBatch.
+// Convert Word objects and their embedding to a RecordBatch.
 pub fn as_record_batch(
-    records: Vec<(WordDefinition, OneOrMany<Embedding>)>,
+    records: Vec<(Word, OneOrMany<Embedding>)>,
     dims: usize,
 ) -> Result<RecordBatch, ArrowError> {
     let id = StringArray::from_iter_values(
         records
             .iter()
-            .map(|(WordDefinition { id, .. }, _)| id)
+            .map(|(Word { id, .. }, _)| id)
             .collect::<Vec<_>>(),
     );
 
     let definition = StringArray::from_iter_values(
         records
             .iter()
-            .map(|(WordDefinition { definition, .. }, _)| definition)
+            .map(|(Word { definition, .. }, _)| definition)
             .collect::<Vec<_>>(),
     );
 
diff --git a/rig-lancedb/examples/vector_search_local_ann.rs b/rig-lancedb/examples/vector_search_local_ann.rs
index 03636089..a4415ba3 100644
--- a/rig-lancedb/examples/vector_search_local_ann.rs
+++ b/rig-lancedb/examples/vector_search_local_ann.rs
@@ -1,7 +1,7 @@
-use std::{env, sync::Arc};
+use std::sync::Arc;
 
 use arrow_array::RecordBatchIterator;
-use fixture::{as_record_batch, schema, word_definitions, WordDefinition};
+use fixture::{as_record_batch, schema, words, Word};
 use lancedb::index::vector::IvfPqIndexBuilder;
 use rig::{
     embeddings::{EmbeddingModel, EmbeddingsBuilder},
@@ -16,8 +16,7 @@ mod fixture;
 #[tokio::main]
 async fn main() -> Result<(), anyhow::Error> {
     // Initialize OpenAI client. Use this to generate embeddings (and generate test data for RAG demo).
-    let openai_api_key = env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set");
-    let openai_client = Client::new(&openai_api_key);
+    let openai_client = Client::from_env();
 
     // Select an embedding model.
     let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002);
@@ -27,11 +26,11 @@
 
     // Generate embeddings for the test data.
     let embeddings = EmbeddingsBuilder::new(model.clone())
-        .documents(word_definitions())?
+        .documents(words())?
         // Note: need at least 256 rows in order to create an index so copy the definition 256 times for testing purposes.
         .documents(
             (0..256)
-                .map(|i| WordDefinition {
+                .map(|i| Word {
                     id: format!("doc{}", i),
                     definition: "Definition of *flumbuzzle (noun)*: A sudden, inexplicable urge to rearrange or reorganize small objects, such as desk items or books, for no apparent reason.".to_string()
                 })
@@ -65,7 +64,7 @@ async fn main() -> Result<(), anyhow::Error> {
 
     // Query the index
     let results = vector_store_index
-        .top_n::<WordDefinition>("My boss says I zindle too much, what does that mean?", 1)
+        .top_n::<Word>("My boss says I zindle too much, what does that mean?", 1)
         .await?;
 
     println!("Results: {:?}", results);
diff --git a/rig-lancedb/examples/vector_search_local_enn.rs b/rig-lancedb/examples/vector_search_local_enn.rs
index 0244d33e..5011238d 100644
--- a/rig-lancedb/examples/vector_search_local_enn.rs
+++ b/rig-lancedb/examples/vector_search_local_enn.rs
@@ -1,7 +1,7 @@
-use std::{env, sync::Arc};
+use std::sync::Arc;
 
 use arrow_array::RecordBatchIterator;
-use fixture::{as_record_batch, schema, word_definitions};
+use fixture::{as_record_batch, schema, words};
 use rig::{
     embeddings::{EmbeddingModel, EmbeddingsBuilder},
     providers::openai::{Client, TEXT_EMBEDDING_ADA_002},
@@ -15,15 +15,14 @@ mod fixture;
 #[tokio::main]
 async fn main() -> Result<(), anyhow::Error> {
     // Initialize OpenAI client. Use this to generate embeddings (and generate test data for RAG demo).
-    let openai_api_key = env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set");
-    let openai_client = Client::new(&openai_api_key);
+    let openai_client = Client::from_env();
 
     // Select the embedding model and generate our embeddings
     let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002);
 
     // Generate embeddings for the test data.
     let embeddings = EmbeddingsBuilder::new(model.clone())
-        .documents(word_definitions())?
+        .documents(words())?
         .build()
         .await?;
 
diff --git a/rig-lancedb/examples/vector_search_s3_ann.rs b/rig-lancedb/examples/vector_search_s3_ann.rs
index f296d1d7..61267e8f 100644
--- a/rig-lancedb/examples/vector_search_s3_ann.rs
+++ b/rig-lancedb/examples/vector_search_s3_ann.rs
@@ -1,7 +1,7 @@
-use std::{env, sync::Arc};
+use std::sync::Arc;
 
 use arrow_array::RecordBatchIterator;
-use fixture::{as_record_batch, schema, word_definitions, WordDefinition};
+use fixture::{as_record_batch, schema, words, Word};
 use lancedb::{index::vector::IvfPqIndexBuilder, DistanceType};
 use rig::{
     embeddings::{EmbeddingModel, EmbeddingsBuilder},
@@ -18,8 +18,7 @@ mod fixture;
 #[tokio::main]
 async fn main() -> Result<(), anyhow::Error> {
     // Initialize OpenAI client. Use this to generate embeddings (and generate test data for RAG demo).
-    let openai_api_key = env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set");
-    let openai_client = Client::new(&openai_api_key);
+    let openai_client = Client::from_env();
 
     // Select the embedding model and generate our embeddings
     let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002);
@@ -33,11 +32,11 @@
 
     // Generate embeddings for the test data.
     let embeddings = EmbeddingsBuilder::new(model.clone())
-        .documents(word_definitions())?
+        .documents(words())?
         // Note: need at least 256 rows in order to create an index so copy the definition 256 times for testing purposes.
         .documents(
             (0..256)
-                .map(|i| WordDefinition {
+                .map(|i| Word {
                     id: format!("doc{}", i),
                     definition: "Definition of *flumbuzzle (noun)*: A sudden, inexplicable urge to rearrange or reorganize small objects, such as desk items or books, for no apparent reason.".to_string()
                 })
@@ -77,7 +76,7 @@ async fn main() -> Result<(), anyhow::Error> {
 
     // Query the index
     let results = vector_store
-        .top_n::<WordDefinition>("I'm always looking for my phone, I always seem to forget it in the most counterintuitive places. What's the word for this feeling?", 1)
+        .top_n::<Word>("I'm always looking for my phone, I always seem to forget it in the most counterintuitive places. What's the word for this feeling?", 1)
         .await?;
 
     println!("Results: {:?}", results);
diff --git a/rig-lancedb/src/lib.rs b/rig-lancedb/src/lib.rs
index 7567bc60..04fdbf60 100644
--- a/rig-lancedb/src/lib.rs
+++ b/rig-lancedb/src/lib.rs
@@ -8,7 +8,7 @@ use rig::{
 };
 use serde::Deserialize;
 use serde_json::Value;
-use utils::{FilterEmbeddings, QueryToJson};
+use utils::{FilterTableColumns, QueryToJson};
 
 mod utils;
 
@@ -24,10 +24,12 @@ fn serde_to_rig_error(e: serde_json::Error) -> VectorStoreError {
 /// # Example
 /// ```
 /// use rig_lancedb::{LanceDbVectorIndex, SearchParams};
-/// use rig::embeddings::EmbeddingModel;
+/// use rig::providers::openai::{Client, TEXT_EMBEDDING_ADA_002, EmbeddingModel};
 ///
-/// let table: table: lancedb::Table = db.create_table(""); // <-- Replace with your lancedb table here.
-/// let model: model: EmbeddingModel = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); // <-- Replace with your embedding model here.
+/// let openai_client = Client::from_env();
+///
+/// let table: lancedb::Table = db.create_table(""); // <-- Replace with your lancedb table here.
+/// let model: EmbeddingModel = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); // <-- Replace with your embedding model here.
 /// let vector_store_index = LanceDbVectorIndex::new(table, model, "id", SearchParams::default()).await?;
 /// ```
 pub struct LanceDbVectorIndex<M: EmbeddingModel> {
@@ -112,7 +114,7 @@ pub enum SearchType {
 /// Parameters used to perform a vector search on a LanceDb table.
 /// # Example
 /// ```
-/// let search_params = SearchParams::default().distance_type(DistanceType::Cosine);
+/// let search_params = rig_lancedb::SearchParams::default().distance_type(lancedb::DistanceType::Cosine);
 /// ```
 #[derive(Debug, Clone, Default)]
 pub struct SearchParams {
@@ -179,7 +181,9 @@ impl<M: EmbeddingModel + Sync + Send> VectorStoreIndex for LanceDbVectorIndex<M> {
     /// # Example
     /// ```
     /// use rig_lancedb::{LanceDbVectorIndex, SearchParams};
-    /// use rig::embeddings::EmbeddingModel;
+    /// use rig::providers::openai::{EmbeddingModel, Client, TEXT_EMBEDDING_ADA_002};
+    ///
+    /// let openai_client = Client::from_env();
     ///
     /// let table: lancedb::Table = db.create_table("fake_definitions"); // <-- Replace with your lancedb table here.
     /// let model: EmbeddingModel = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); // <-- Replace with your embedding model here.
@@ -201,7 +205,14 @@ impl<M: EmbeddingModel + Sync + Send> VectorStoreIndex for LanceDbVectorIndex<M> {
             .table
             .vector_search(prompt_embedding.vec.clone())
             .map_err(lancedb_to_rig_error)?
-            .limit(n);
+            .limit(n)
+            .select(lancedb::query::Select::Columns(
+                self.table
+                    .schema()
+                    .await
+                    .map_err(lancedb_to_rig_error)?
+                    .filter_embeddings(),
+            ));
 
         self.build_query(query)
             .execute_query()
             .await?
             .into_iter()
             .enumerate()
             .map(|(i, value)| {
-                let filtered_value = value
-                    .filter(self.search_params.column.clone())
-                    .map_err(serde_to_rig_error)?;
                 Ok((
-                    match filtered_value.get("_distance") {
+                    match value.get("_distance") {
                         Some(Value::Number(distance)) => distance.as_f64().unwrap_or_default(),
                         _ => 0.0,
                     },
-                    match filtered_value.get(self.id_field.clone()) {
+                    match value.get(self.id_field.clone()) {
                         Some(Value::String(id)) => id.to_string(),
                         _ => format!("unknown{i}"),
                     },
-                    serde_json::from_value(filtered_value).map_err(serde_to_rig_error)?,
+                    serde_json::from_value(value).map_err(serde_to_rig_error)?,
                 ))
             })
             .collect()
@@ -230,8 +238,13 @@
     /// Implement the `top_n_ids` method of the `VectorStoreIndex` trait for `LanceDbVectorIndex`.
     /// # Example
     /// ```
-    /// let table: table: lancedb::Table = db.create_table(""); // <-- Replace with your lancedb table here.
-    /// let model: model: EmbeddingModel = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); // <-- Replace with your embedding model here.
+    /// use rig_lancedb::{LanceDbVectorIndex, SearchParams};
+    /// use rig::providers::openai::{Client, TEXT_EMBEDDING_ADA_002, EmbeddingModel};
+    ///
+    /// let openai_client = Client::from_env();
+    ///
+    /// let table: lancedb::Table = db.create_table(""); // <-- Replace with your lancedb table here.
+    /// let model: EmbeddingModel = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); // <-- Replace with your embedding model here.
     /// let vector_store_index = LanceDbVectorIndex::new(table, model, "id", SearchParams::default()).await?;
     ///
     /// // Query the index
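The `top_n` change above moves embedding removal out of post-processing and into the query itself. A minimal sketch of the same pattern against the public lancedb API, assuming an existing `table: lancedb::Table` and a query vector `embedding: Vec<f64>`; the `QueryBase`/`ExecutableQuery` trait imports are my assumption about what needs to be in scope:

```rust
use futures::TryStreamExt;
use lancedb::query::{ExecutableQuery, QueryBase, Select};

async fn search_without_vectors(
    table: &lancedb::Table,
    embedding: Vec<f64>,
) -> anyhow::Result<()> {
    // Prune columns inside the query, as the new `top_n` does, so the
    // embedding vectors are never shipped back with each result row.
    let batches: Vec<_> = table
        .vector_search(embedding)?
        .limit(1)
        .select(Select::Columns(vec![
            "id".to_string(),
            "definition".to_string(),
        ]))
        .execute()
        .await?
        .try_collect()
        .await?;
    println!("{batches:?}");
    Ok(())
}
```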
diff --git a/rig-lancedb/src/utils/mod.rs b/rig-lancedb/src/utils/mod.rs
index 9b5f6cb5..548c0db0 100644
--- a/rig-lancedb/src/utils/mod.rs
+++ b/rig-lancedb/src/utils/mod.rs
@@ -1,10 +1,14 @@
 mod deserializer;
 
+use std::sync::Arc;
+
 use deserializer::RecordBatchDeserializer;
 use futures::TryStreamExt;
-use lancedb::query::ExecutableQuery;
+use lancedb::{
+    arrow::arrow_schema::{DataType, Schema},
+    query::ExecutableQuery,
+};
 use rig::vector_store::VectorStoreError;
-use serde::de::Error;
 
 use crate::lancedb_to_rig_error;
 
@@ -28,60 +32,82 @@ impl QueryToJson for lancedb::query::VectorQuery {
     }
 }
 
-pub(crate) trait FilterEmbeddings {
-    fn filter(self, embeddings_col: Option<String>) -> serde_json::Result<serde_json::Value>;
+/// Filter out the embedding columns from a table schema, returning the names of the remaining columns.
+pub(crate) trait FilterTableColumns {
+    fn filter_embeddings(self) -> Vec<String>;
 }
 
-impl FilterEmbeddings for serde_json::Value {
-    fn filter(mut self, embeddings_col: Option<String>) -> serde_json::Result<serde_json::Value> {
-        match self.as_object_mut() {
-            Some(obj) => {
-                obj.remove(&embeddings_col.unwrap_or("embedding".to_string()));
-                serde_json::to_value(obj)
-            }
-            None => Err(serde_json::Error::custom(format!(
-                "{} is not an object",
-                self
-            ))),
-        }
+impl FilterTableColumns for Arc<Schema> {
+    fn filter_embeddings(self) -> Vec<String> {
+        self.fields()
+            .iter()
+            .filter_map(|field| match field.data_type() {
+                DataType::FixedSizeList(inner, ..) => match inner.data_type() {
+                    DataType::Float64 => None,
+                    _ => Some(field.name().to_string()),
+                },
+                _ => Some(field.name().to_string()),
+            })
+            .collect()
     }
 }
 
 #[cfg(test)]
 mod tests {
-    use crate::utils::FilterEmbeddings;
+    use std::sync::Arc;
 
-    #[test]
-    fn test_filter_default() {
-        let json = serde_json::json!({
-            "id": "doc0",
-            "text": "Hello world",
-            "embedding": vec![0.3889, 0.6987, 0.7758, 0.7750, 0.7289, 0.3380, 0.1165, 0.1551, 0.3783, 0.1458,
-            0.3060, 0.2155, 0.8966, 0.5498, 0.7419, 0.8120, 0.2306, 0.5155, 0.9947, 0.0805]
-        });
+    use lancedb::arrow::arrow_schema::{DataType, Field, Schema};
 
-        let filtered_json = json.filter(None).unwrap();
+    use super::FilterTableColumns;
 
-        assert_eq!(
-            filtered_json,
-            serde_json::json!({"id": "doc0", "text": "Hello world"})
+    #[tokio::test]
+    async fn test_column_filtering() {
+        let field_a = Field::new("id", DataType::Int64, false);
+        let field_b = Field::new("my_bool", DataType::Boolean, false);
+        let field_c = Field::new(
+            "my_embeddings",
+            DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float64, true)), 10),
+            false,
+        );
+        let field_d = Field::new(
+            "my_list",
+            DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), 10),
+            false,
         );
-    }
 
-    #[test]
-    fn test_filter_non_default() {
-        let json = serde_json::json!({
-            "id": "doc0",
-            "text": "Hello world",
-            "vectors": vec![0.3889, 0.6987, 0.7758, 0.7750, 0.7289, 0.3380, 0.1165, 0.1551, 0.3783, 0.1458,
-            0.3060, 0.2155, 0.8966, 0.5498, 0.7419, 0.8120, 0.2306, 0.5155, 0.9947, 0.0805]
-        });
+        let schema = Schema::new(vec![field_a, field_b, field_c, field_d]);
 
-        let filtered_json = json.filter(Some("vectors".to_string())).unwrap();
+        let columns = Arc::new(schema).filter_embeddings();
 
         assert_eq!(
-            filtered_json,
-            serde_json::json!({"id": "doc0", "text": "Hello world"})
+            columns,
+            vec![
+                "id".to_string(),
+                "my_bool".to_string(),
+                "my_list".to_string()
+            ]
         )
     }
+
+    #[tokio::test]
+    async fn test_column_filtering_2() {
+        let field_a = Field::new("id", DataType::Int64, false);
+        let field_b = Field::new("my_bool", DataType::Boolean, false);
+        let field_c = Field::new(
+            "my_embeddings",
+            DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float64, true)), 10),
+            false,
+        );
+        let field_d = Field::new(
+            "my_other_embeddings",
+            DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float64, true)), 10),
+            false,
+        );
+
+        let schema = Schema::new(vec![field_a, field_b, field_c, field_d]);
+
+        let columns = Arc::new(schema).filter_embeddings();
+
+        assert_eq!(columns, vec!["id".to_string(), "my_bool".to_string()])
+    }
 }
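Worth noting about `filter_embeddings` as tested above: a column counts as an embedding column only when its Arrow type is a fixed-size list of Float64 items, which is why `my_list` (Float32 items) survives filtering in the first test. A standalone restatement of that rule, as a sketch with an illustrative function name:

```rust
use lancedb::arrow::arrow_schema::{DataType, Schema};

// Keep every field except FixedSizeList<Float64> ones; mirrors the rule the
// new FilterTableColumns impl encodes. Float32 lists are *not* filtered out.
fn non_embedding_columns(schema: &Schema) -> Vec<String> {
    schema
        .fields()
        .iter()
        .filter(|field| {
            !matches!(
                field.data_type(),
                DataType::FixedSizeList(inner, _) if inner.data_type() == &DataType::Float64
            )
        })
        .map(|field| field.name().to_string())
        .collect()
}
```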
diff --git a/rig-lancedb/tests/fixtures/lib.rs b/rig-lancedb/tests/fixtures/lib.rs
new file mode 100644
index 00000000..f2de914c
--- /dev/null
+++ b/rig-lancedb/tests/fixtures/lib.rs
@@ -0,0 +1,90 @@
+use std::sync::Arc;
+
+use arrow_array::{types::Float64Type, ArrayRef, FixedSizeListArray, RecordBatch, StringArray};
+use lancedb::arrow::arrow_schema::{ArrowError, DataType, Field, Fields, Schema};
+use rig::embeddings::Embedding;
+use rig::{Embed, OneOrMany};
+use serde::Deserialize;
+
+#[derive(Embed, Clone, Deserialize, Debug)]
+pub struct Word {
+    pub id: String,
+    #[embed]
+    pub definition: String,
+}
+
+pub fn words() -> Vec<Word> {
+    vec![
+        Word {
+            id: "doc0".to_string(),
+            definition: "Definition of *flumbrel (noun)*: a small, seemingly insignificant item that you constantly lose or misplace, such as a pen, hair tie, or remote control.".to_string()
+        },
+        Word {
+            id: "doc1".to_string(),
+            definition: "Definition of *zindle (verb)*: to pretend to be working on something important while actually doing something completely unrelated or unproductive.".to_string()
+        },
+        Word {
+            id: "doc2".to_string(),
+            definition: "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.".to_string()
+        }
+    ]
+}
+
+// Schema of table in LanceDB.
+pub fn schema(dims: usize) -> Schema {
+    Schema::new(Fields::from(vec![
+        Field::new("id", DataType::Utf8, false),
+        Field::new("definition", DataType::Utf8, false),
+        Field::new(
+            "embedding",
+            DataType::FixedSizeList(
+                Arc::new(Field::new("item", DataType::Float64, true)),
+                dims as i32,
+            ),
+            false,
+        ),
+    ]))
+}
+
+// Convert Word objects and their embedding to a RecordBatch.
+pub fn as_record_batch(
+    records: Vec<(Word, OneOrMany<Embedding>)>,
+    dims: usize,
+) -> Result<RecordBatch, ArrowError> {
+    let id = StringArray::from_iter_values(
+        records
+            .iter()
+            .map(|(Word { id, .. }, _)| id)
+            .collect::<Vec<_>>(),
+    );
+
+    let definition = StringArray::from_iter_values(
+        records
+            .iter()
+            .map(|(Word { definition, .. }, _)| definition)
+            .collect::<Vec<_>>(),
+    );
+
+    let embedding = FixedSizeListArray::from_iter_primitive::<Float64Type, _, _>(
+        records
+            .into_iter()
+            .map(|(_, embeddings)| {
+                Some(
+                    embeddings
+                        .first()
+                        .vec
+                        .into_iter()
+                        .map(Some)
+                        .collect::<Vec<_>>(),
+                )
+            })
+            .collect::<Vec<_>>(),
+        dims as i32,
+    );
+
+    RecordBatch::try_from_iter(vec![
+        ("id", Arc::new(id) as ArrayRef),
+        ("definition", Arc::new(definition) as ArrayRef),
+        ("embedding", Arc::new(embedding) as ArrayRef),
+    ])
+}
diff --git a/rig-lancedb/tests/integration_tests.rs b/rig-lancedb/tests/integration_tests.rs
new file mode 100644
index 00000000..29a427fb
--- /dev/null
+++ b/rig-lancedb/tests/integration_tests.rs
@@ -0,0 +1,94 @@
+use serde_json::json;
+
+use arrow_array::RecordBatchIterator;
+use fixture::{as_record_batch, schema, words, Word};
+use lancedb::index::vector::IvfPqIndexBuilder;
+use rig::{
+    embeddings::{EmbeddingModel, EmbeddingsBuilder},
+    providers::openai::{self, Client},
+    vector_store::VectorStoreIndex,
+};
+use rig_lancedb::{LanceDbVectorIndex, SearchParams};
+use std::sync::Arc;
+
+#[path = "./fixtures/lib.rs"]
+mod fixture;
+
+#[tokio::test]
+async fn vector_search_test() {
+    // Initialize OpenAI client. Use this to generate embeddings (and generate test data for RAG demo).
+    let openai_client = Client::from_env();
+
+    // Select an embedding model.
+    let model = openai_client.embedding_model(openai::TEXT_EMBEDDING_ADA_002);
+
+    // Initialize LanceDB locally.
+    let db = lancedb::connect("data/lancedb-store")
+        .execute()
+        .await
+        .unwrap();
+
+    // Generate embeddings for the test data.
+    let embeddings = EmbeddingsBuilder::new(model.clone())
+        .documents(words()).unwrap()
+        // Note: need at least 256 rows in order to create an index so copy the definition 256 times for testing purposes.
+        .documents(
+            (0..256)
+                .map(|i| Word {
+                    id: format!("doc{}", i),
+                    definition: "Definition of *flumbuzzle (noun)*: A sudden, inexplicable urge to rearrange or reorganize small objects, such as desk items or books, for no apparent reason.".to_string()
+                })
+        ).unwrap()
+        .build()
+        .await.unwrap();
+
+    let table = db
+        .create_table(
+            "words",
+            RecordBatchIterator::new(
+                vec![as_record_batch(embeddings, model.ndims())],
+                Arc::new(schema(model.ndims())),
+            ),
+        )
+        .execute()
+        .await
+        .unwrap();
+
+    // See [LanceDB indexing](https://lancedb.github.io/lancedb/concepts/index_ivfpq/#product-quantization) for more information
+    table
+        .create_index(
+            &["embedding"],
+            lancedb::index::Index::IvfPq(IvfPqIndexBuilder::default()),
+        )
+        .execute()
+        .await
+        .unwrap();
+
+    // Define the search params that will be used by the vector store to perform the vector search.
+    let search_params = SearchParams::default();
+    let vector_store_index = LanceDbVectorIndex::new(table, model, "id", search_params)
+        .await
+        .unwrap();
+
+    // Query the index
+    let results = vector_store_index
+        .top_n::<serde_json::Value>(
+            "My boss says I zindle too much, what does that mean?",
+            1,
+        )
+        .await
+        .unwrap();
+
+    let (distance, _, value) = &results.first().unwrap();
+
+    assert_eq!(
+        *value,
+        json!({
+            "_distance": distance,
+            "definition": "Definition of *zindle (verb)*: to pretend to be working on something important while actually doing something completely unrelated or unproductive.",
+            "id": "doc1"
+        })
+    );
+
+    db.drop_db().await.unwrap();
+}
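A downstream consequence of the column filtering, sketched under assumptions: because result rows no longer contain the embedding column, a `top_n` deserialization target only needs the payload fields. `WordRow` and the surrounding function are illustrative, not part of this PR:

```rust
use rig::vector_store::{VectorStoreError, VectorStoreIndex};
use serde::Deserialize;

// Hypothetical payload-only row type; no `embedding` field is required since
// Float64 fixed-size-list columns are excluded from the rows LanceDB returns.
#[derive(Deserialize, Debug)]
struct WordRow {
    id: String,
    definition: String,
}

async fn lookup(index: &impl VectorStoreIndex) -> Result<(), VectorStoreError> {
    // `top_n` yields (distance, id, deserialized row) tuples, as in the test above.
    let rows = index.top_n::<WordRow>("what does zindle mean?", 1).await?;
    for (distance, id, row) in rows {
        println!("{distance} {id} {row:?}");
    }
    Ok(())
}
```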