Skip to content

Commit

Permalink
Feat(integration-tests): LanceDB (#136)
Browse files Browse the repository at this point in the history
* feat: add lancedb integration test

* fix: fix doctests

* fix: cargo clippy
  • Loading branch information
marieaurore123 authored Dec 5, 2024
1 parent 5979c64 commit 412ea16
Show file tree
Hide file tree
Showing 9 changed files with 308 additions and 84 deletions.
4 changes: 4 additions & 0 deletions rig-lancedb/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,7 @@ required-features = ["rig-core/derive"]
[[example]]
name = "vector_search_s3_ann"
required-features = ["rig-core/derive"]

[[test]]
name = "integration_tests"
required-features = ["rig-core/derive"]
18 changes: 9 additions & 9 deletions rig-lancedb/examples/fixtures/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,23 +7,23 @@ use rig::{Embed, OneOrMany};
use serde::Deserialize;

#[derive(Embed, Clone, Deserialize, Debug)]
pub struct WordDefinition {
pub struct Word {
pub id: String,
#[embed]
pub definition: String,
}

pub fn word_definitions() -> Vec<WordDefinition> {
pub fn words() -> Vec<Word> {
vec![
WordDefinition {
Word {
id: "doc0".to_string(),
definition: "Definition of *flumbrel (noun)*: a small, seemingly insignificant item that you constantly lose or misplace, such as a pen, hair tie, or remote control.".to_string()
},
WordDefinition {
Word {
id: "doc1".to_string(),
definition: "Definition of *zindle (verb)*: to pretend to be working on something important while actually doing something completely unrelated or unproductive.".to_string()
},
WordDefinition {
Word {
id: "doc2".to_string(),
definition: "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.".to_string()
}
Expand All @@ -46,22 +46,22 @@ pub fn schema(dims: usize) -> Schema {
]))
}

// Convert WordDefinition objects and their embedding to a RecordBatch.
// Convert Word objects and their embedding to a RecordBatch.
pub fn as_record_batch(
records: Vec<(WordDefinition, OneOrMany<Embedding>)>,
records: Vec<(Word, OneOrMany<Embedding>)>,
dims: usize,
) -> Result<RecordBatch, lancedb::arrow::arrow_schema::ArrowError> {
let id = StringArray::from_iter_values(
records
.iter()
.map(|(WordDefinition { id, .. }, _)| id)
.map(|(Word { id, .. }, _)| id)
.collect::<Vec<_>>(),
);

let definition = StringArray::from_iter_values(
records
.iter()
.map(|(WordDefinition { definition, .. }, _)| definition)
.map(|(Word { definition, .. }, _)| definition)
.collect::<Vec<_>>(),
);

Expand Down
13 changes: 6 additions & 7 deletions rig-lancedb/examples/vector_search_local_ann.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::{env, sync::Arc};
use std::sync::Arc;

use arrow_array::RecordBatchIterator;
use fixture::{as_record_batch, schema, word_definitions, WordDefinition};
use fixture::{as_record_batch, schema, words, Word};
use lancedb::index::vector::IvfPqIndexBuilder;
use rig::{
embeddings::{EmbeddingModel, EmbeddingsBuilder},
Expand All @@ -16,8 +16,7 @@ mod fixture;
#[tokio::main]
async fn main() -> Result<(), anyhow::Error> {
// Initialize OpenAI client. Use this to generate embeddings (and generate test data for RAG demo).
let openai_api_key = env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set");
let openai_client = Client::new(&openai_api_key);
let openai_client = Client::from_env();

// Select an embedding model.
let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002);
Expand All @@ -27,11 +26,11 @@ async fn main() -> Result<(), anyhow::Error> {

// Generate embeddings for the test data.
let embeddings = EmbeddingsBuilder::new(model.clone())
.documents(word_definitions())?
.documents(words())?
// Note: need at least 256 rows in order to create an index so copy the definition 256 times for testing purposes.
.documents(
(0..256)
.map(|i| WordDefinition {
.map(|i| Word {
id: format!("doc{}", i),
definition: "Definition of *flumbuzzle (noun)*: A sudden, inexplicable urge to rearrange or reorganize small objects, such as desk items or books, for no apparent reason.".to_string()
})
Expand Down Expand Up @@ -65,7 +64,7 @@ async fn main() -> Result<(), anyhow::Error> {

// Query the index
let results = vector_store_index
.top_n::<WordDefinition>("My boss says I zindle too much, what does that mean?", 1)
.top_n::<Word>("My boss says I zindle too much, what does that mean?", 1)
.await?;

println!("Results: {:?}", results);
Expand Down
9 changes: 4 additions & 5 deletions rig-lancedb/examples/vector_search_local_enn.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::{env, sync::Arc};
use std::sync::Arc;

use arrow_array::RecordBatchIterator;
use fixture::{as_record_batch, schema, word_definitions};
use fixture::{as_record_batch, schema, words};
use rig::{
embeddings::{EmbeddingModel, EmbeddingsBuilder},
providers::openai::{Client, TEXT_EMBEDDING_ADA_002},
Expand All @@ -15,15 +15,14 @@ mod fixture;
#[tokio::main]
async fn main() -> Result<(), anyhow::Error> {
// Initialize OpenAI client. Use this to generate embeddings (and generate test data for RAG demo).
let openai_api_key = env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set");
let openai_client = Client::new(&openai_api_key);
let openai_client = Client::from_env();

// Select the embedding model and generate our embeddings
let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002);

// Generate embeddings for the test data.
let embeddings = EmbeddingsBuilder::new(model.clone())
.documents(word_definitions())?
.documents(words())?
.build()
.await?;

Expand Down
13 changes: 6 additions & 7 deletions rig-lancedb/examples/vector_search_s3_ann.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::{env, sync::Arc};
use std::sync::Arc;

use arrow_array::RecordBatchIterator;
use fixture::{as_record_batch, schema, word_definitions, WordDefinition};
use fixture::{as_record_batch, schema, words, Word};
use lancedb::{index::vector::IvfPqIndexBuilder, DistanceType};
use rig::{
embeddings::{EmbeddingModel, EmbeddingsBuilder},
Expand All @@ -18,8 +18,7 @@ mod fixture;
#[tokio::main]
async fn main() -> Result<(), anyhow::Error> {
// Initialize OpenAI client. Use this to generate embeddings (and generate test data for RAG demo).
let openai_api_key = env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set");
let openai_client = Client::new(&openai_api_key);
let openai_client = Client::from_env();

// Select the embedding model and generate our embeddings
let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002);
Expand All @@ -33,11 +32,11 @@ async fn main() -> Result<(), anyhow::Error> {

// Generate embeddings for the test data.
let embeddings = EmbeddingsBuilder::new(model.clone())
.documents(word_definitions())?
.documents(words())?
// Note: need at least 256 rows in order to create an index so copy the definition 256 times for testing purposes.
.documents(
(0..256)
.map(|i| WordDefinition {
.map(|i| Word {
id: format!("doc{}", i),
definition: "Definition of *flumbuzzle (noun)*: A sudden, inexplicable urge to rearrange or reorganize small objects, such as desk items or books, for no apparent reason.".to_string()
})
Expand Down Expand Up @@ -77,7 +76,7 @@ async fn main() -> Result<(), anyhow::Error> {

// Query the index
let results = vector_store
.top_n::<WordDefinition>("I'm always looking for my phone, I always seem to forget it in the most counterintuitive places. What's the word for this feeling?", 1)
.top_n::<Word>("I'm always looking for my phone, I always seem to forget it in the most counterintuitive places. What's the word for this feeling?", 1)
.await?;

println!("Results: {:?}", results);
Expand Down
43 changes: 28 additions & 15 deletions rig-lancedb/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use rig::{
};
use serde::Deserialize;
use serde_json::Value;
use utils::{FilterEmbeddings, QueryToJson};
use utils::{FilterTableColumns, QueryToJson};

mod utils;

Expand All @@ -24,10 +24,12 @@ fn serde_to_rig_error(e: serde_json::Error) -> VectorStoreError {
/// # Example
/// ```
/// use rig_lancedb::{LanceDbVectorIndex, SearchParams};
/// use rig::embeddings::EmbeddingModel;
/// use rig::providers::openai::{Client, TEXT_EMBEDDING_ADA_002, EmbeddingModel};
///
/// let table: table: lancedb::Table = db.create_table(""); // <-- Replace with your lancedb table here.
/// let model: model: EmbeddingModel = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); // <-- Replace with your embedding model here.
/// let openai_client = Client::from_env();
///
/// let table: lancedb::Table = db.create_table(""); // <-- Replace with your lancedb table here.
/// let model: EmbeddingModel = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); // <-- Replace with your embedding model here.
/// let vector_store_index = LanceDbVectorIndex::new(table, model, "id", SearchParams::default()).await?;
/// ```
pub struct LanceDbVectorIndex<M: EmbeddingModel> {
Expand Down Expand Up @@ -112,7 +114,7 @@ pub enum SearchType {
/// Parameters used to perform a vector search on a LanceDb table.
/// # Example
/// ```
/// let search_params = SearchParams::default().distance_type(DistanceType::Cosine);
/// let search_params = rig_lancedb::SearchParams::default().distance_type(lancedb::DistanceType::Cosine);
/// ```
#[derive(Debug, Clone, Default)]
pub struct SearchParams {
Expand Down Expand Up @@ -179,7 +181,9 @@ impl<M: EmbeddingModel + Sync + Send> VectorStoreIndex for LanceDbVectorIndex<M>
/// # Example
/// ```
/// use rig_lancedb::{LanceDbVectorIndex, SearchParams};
/// use rig::embeddings::EmbeddingModel;
/// use rig::providers::openai::{EmbeddingModel, Client, TEXT_EMBEDDING_ADA_002};
///
/// let openai_client = Client::from_env();
///
/// let table: lancedb::Table = db.create_table("fake_definitions"); // <-- Replace with your lancedb table here.
/// let model: EmbeddingModel = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); // <-- Replace with your embedding model here.
Expand All @@ -201,27 +205,31 @@ impl<M: EmbeddingModel + Sync + Send> VectorStoreIndex for LanceDbVectorIndex<M>
.table
.vector_search(prompt_embedding.vec.clone())
.map_err(lancedb_to_rig_error)?
.limit(n);
.limit(n)
.select(lancedb::query::Select::Columns(
self.table
.schema()
.await
.map_err(lancedb_to_rig_error)?
.filter_embeddings(),
));

self.build_query(query)
.execute_query()
.await?
.into_iter()
.enumerate()
.map(|(i, value)| {
let filtered_value = value
.filter(self.search_params.column.clone())
.map_err(serde_to_rig_error)?;
Ok((
match filtered_value.get("_distance") {
match value.get("_distance") {
Some(Value::Number(distance)) => distance.as_f64().unwrap_or_default(),
_ => 0.0,
},
match filtered_value.get(self.id_field.clone()) {
match value.get(self.id_field.clone()) {
Some(Value::String(id)) => id.to_string(),
_ => format!("unknown{i}"),
},
serde_json::from_value(filtered_value).map_err(serde_to_rig_error)?,
serde_json::from_value(value).map_err(serde_to_rig_error)?,
))
})
.collect()
Expand All @@ -230,8 +238,13 @@ impl<M: EmbeddingModel + Sync + Send> VectorStoreIndex for LanceDbVectorIndex<M>
/// Implement the `top_n_ids` method of the `VectorStoreIndex` trait for `LanceDbVectorIndex`.
/// # Example
/// ```
/// let table: table: lancedb::Table = db.create_table(""); // <-- Replace with your lancedb table here.
/// let model: model: EmbeddingModel = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); // <-- Replace with your embedding model here.
/// use rig_lancedb::{LanceDbVectorIndex, SearchParams};
/// use rig::providers::openai::{Client, TEXT_EMBEDDING_ADA_002, EmbeddingModel};
///
/// let openai_client = Client::from_env();
///
/// let table: lancedb::Table = db.create_table(""); // <-- Replace with your lancedb table here.
/// let model: EmbeddingModel = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); // <-- Replace with your embedding model here.
/// let vector_store_index = LanceDbVectorIndex::new(table, model, "id", SearchParams::default()).await?;
///
/// // Query the index
Expand Down
Loading

0 comments on commit 412ea16

Please sign in to comment.