diff --git a/rig-core/Cargo.toml b/rig-core/Cargo.toml index 290e50ac..f8c44e8f 100644 --- a/rig-core/Cargo.toml +++ b/rig-core/Cargo.toml @@ -35,8 +35,8 @@ tracing-subscriber = "0.3.18" tokio-test = "0.4.4" [features] +all = ["derive", "pdf"] derive = ["dep:rig-derive"] -all = ["pdf"] pdf = ["dep:lopdf"] [[test]] diff --git a/rig-core/examples/rag.rs b/rig-core/examples/rag.rs index cecd20ce..376c37db 100644 --- a/rig-core/examples/rag.rs +++ b/rig-core/examples/rag.rs @@ -10,17 +10,24 @@ use rig::{ use serde::Serialize; // Data to be RAGged. -// A vector search needs to be performed on the `definitions` field, so we derive the `Embed` trait for `FakeDefinition` +// A vector search needs to be performed on the `definitions` field, so we derive the `Embed` trait for `WordDefinition` // and tag that field with `#[embed]`. #[derive(Embed, Serialize, Clone, Debug, Eq, PartialEq, Default)] -struct FakeDefinition { +struct WordDefinition { id: String, + word: String, #[embed] definitions: Vec, } #[tokio::main] async fn main() -> Result<(), anyhow::Error> { + // Initialize tracing + tracing_subscriber::fmt() + .with_max_level(tracing::Level::INFO) + .with_target(false) + .init(); + // Create OpenAI client let openai_api_key = env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set"); let openai_client = Client::new(&openai_api_key); @@ -30,25 +37,28 @@ async fn main() -> Result<(), anyhow::Error> { // Generate embeddings for the definitions of all the documents using the specified embedding model. let embeddings = EmbeddingsBuilder::new(embedding_model.clone()) .documents(vec![ - FakeDefinition { + WordDefinition { id: "doc0".to_string(), + word: "flurbo".to_string(), definitions: vec![ - "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets.".to_string(), - "Definition of a *flurbo*: A fictional digital currency that originated in the animated series Rick and Morty.".to_string() + "1. 
*flurbo* (name): A flurbo is a green alien that lives on cold planets.".to_string(), + "2. *flurbo* (name): A fictional digital currency that originated in the animated series Rick and Morty.".to_string() ] }, - FakeDefinition { + WordDefinition { id: "doc1".to_string(), + word: "glarb-glarb".to_string(), definitions: vec![ - "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string(), - "Definition of a *glarb-glarb*: A fictional creature found in the distant, swampy marshlands of the planet Glibbo in the Andromeda galaxy.".to_string() + "1. *glarb-glarb* (noun): A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string(), + "2. *glarb-glarb* (noun): A fictional creature found in the distant, swampy marshlands of the planet Glibbo in the Andromeda galaxy.".to_string() ] }, - FakeDefinition { + WordDefinition { id: "doc2".to_string(), + word: "linglingdong".to_string(), definitions: vec![ - "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.".to_string(), - "Definition of a *linglingdong*: A rare, mystical instrument crafted by the ancient monks of the Nebulon Mountain Ranges on the planet Quarm.".to_string() + "1. *linglingdong* (noun): A term used by inhabitants of the far side of the moon to describe humans.".to_string(), + "2. *linglingdong* (noun): A rare, mystical instrument crafted by the ancient monks of the Nebulon Mountain Ranges on the planet Quarm.".to_string() ] }, ])? diff --git a/rig-core/examples/rag_dynamic_tools.rs b/rig-core/examples/rag_dynamic_tools.rs index 459b017b..bc92f7c5 100644 --- a/rig-core/examples/rag_dynamic_tools.rs +++ b/rig-core/examples/rag_dynamic_tools.rs @@ -137,11 +137,6 @@ async fn main() -> Result<(), anyhow::Error> { .with_max_level(tracing::Level::INFO) // disable printing the name of the module in every log line. 
.with_target(false) - // this needs to be set to false, otherwise ANSI color codes will - // show up in a confusing manner in CloudWatch logs. - .with_ansi(false) - // disabling time is handy because CloudWatch will add the ingestion time. - .without_time() .init(); // Create OpenAI client diff --git a/rig-core/examples/vector_search.rs b/rig-core/examples/vector_search.rs index 622fa0ff..e8cbf894 100644 --- a/rig-core/examples/vector_search.rs +++ b/rig-core/examples/vector_search.rs @@ -11,7 +11,7 @@ use serde::{Deserialize, Serialize}; // Shape of data that needs to be RAG'ed. // The definition field will be used to generate embeddings. #[derive(Embed, Clone, Deserialize, Debug, Serialize, Eq, PartialEq, Default)] -struct FakeDefinition { +struct WordDefinition { id: String, word: String, #[embed] @@ -28,7 +28,7 @@ async fn main() -> Result<(), anyhow::Error> { let embeddings = EmbeddingsBuilder::new(model.clone()) .documents(vec![ - FakeDefinition { + WordDefinition { id: "doc0".to_string(), word: "flurbo".to_string(), definitions: vec![ @@ -36,7 +36,7 @@ async fn main() -> Result<(), anyhow::Error> { "A fictional digital currency that originated in the animated series Rick and Morty.".to_string() ] }, - FakeDefinition { + WordDefinition { id: "doc1".to_string(), word: "glarb-glarb".to_string(), definitions: vec![ @@ -44,7 +44,7 @@ async fn main() -> Result<(), anyhow::Error> { "A fictional creature found in the distant, swampy marshlands of the planet Glibbo in the Andromeda galaxy.".to_string() ] }, - FakeDefinition { + WordDefinition { id: "doc2".to_string(), word: "linglingdong".to_string(), definitions: vec![ @@ -61,7 +61,7 @@ async fn main() -> Result<(), anyhow::Error> { .index(model); let results = index - .top_n::("I need to buy something in a fictional universe. What type of money can I use for this?", 1) + .top_n::("I need to buy something in a fictional universe. What type of money can I use for this?", 1) .await? 
.into_iter() .map(|(score, id, doc)| (score, id, doc.word)) diff --git a/rig-core/examples/vector_search_cohere.rs b/rig-core/examples/vector_search_cohere.rs index f3a97498..aace89fa 100644 --- a/rig-core/examples/vector_search_cohere.rs +++ b/rig-core/examples/vector_search_cohere.rs @@ -11,7 +11,7 @@ use serde::{Deserialize, Serialize}; // Shape of data that needs to be RAG'ed. // The definition field will be used to generate embeddings. #[derive(Embed, Clone, Deserialize, Debug, Serialize, Eq, PartialEq, Default)] -struct FakeDefinition { +struct WordDefinition { id: String, word: String, #[embed] @@ -29,7 +29,7 @@ async fn main() -> Result<(), anyhow::Error> { let embeddings = EmbeddingsBuilder::new(document_model.clone()) .documents(vec![ - FakeDefinition { + WordDefinition { id: "doc0".to_string(), word: "flurbo".to_string(), definitions: vec![ @@ -37,7 +37,7 @@ async fn main() -> Result<(), anyhow::Error> { "A fictional digital currency that originated in the animated series Rick and Morty.".to_string() ] }, - FakeDefinition { + WordDefinition { id: "doc1".to_string(), word: "glarb-glarb".to_string(), definitions: vec![ @@ -45,7 +45,7 @@ async fn main() -> Result<(), anyhow::Error> { "A fictional creature found in the distant, swampy marshlands of the planet Glibbo in the Andromeda galaxy.".to_string() ] }, - FakeDefinition { + WordDefinition { id: "doc2".to_string(), word: "linglingdong".to_string(), definitions: vec![ @@ -62,7 +62,7 @@ async fn main() -> Result<(), anyhow::Error> { .index(search_model); let results = index - .top_n::( + .top_n::( "Which instrument is found in the Nebulon Mountain Ranges?", 1, ) diff --git a/rig-core/src/embeddings/builder.rs b/rig-core/src/embeddings/builder.rs index 8f0a5dd1..5a4b63c9 100644 --- a/rig-core/src/embeddings/builder.rs +++ b/rig-core/src/embeddings/builder.rs @@ -13,40 +13,7 @@ use crate::{ /// Builder for creating a collection of embeddings from a vector of documents of type `T`. 
/// Accumulate documents such that they can be embedded in a single batch to limit api calls to the provider. -pub struct EmbeddingsBuilder { - model: M, - documents: Vec<(T, Vec)>, -} - -impl EmbeddingsBuilder { - /// Create a new embedding builder with the given embedding model - pub fn new(model: M) -> Self { - Self { - model, - documents: vec![], - } - } - - /// Add a document that implements `Embed` to the builder. - pub fn document(mut self, document: T) -> Result { - let mut embedder = TextEmbedder::default(); - document.embed(&mut embedder)?; - - self.documents.push((document, embedder.texts)); - - Ok(self) - } - - /// Add many documents that implement `Embed` to the builder. - pub fn documents(self, documents: impl IntoIterator) -> Result { - let builder = documents - .into_iter() - .try_fold(self, |builder, doc| builder.document(doc))?; - - Ok(builder) - } -} - +/// /// # Example /// ```rust /// use std::env; @@ -62,7 +29,7 @@ impl EmbeddingsBuilder { /// // Shape of data that needs to be RAG'ed. /// // The definition field will be used to generate embeddings. 
/// #[derive(Embed, Clone, Deserialize, Debug, Serialize, Eq, PartialEq, Default)] -/// struct FakeDefinition { +/// struct WordDefinition { /// id: String, /// word: String, /// #[embed] @@ -77,7 +44,7 @@ impl EmbeddingsBuilder { /// /// let embeddings = EmbeddingsBuilder::new(model.clone()) /// .documents(vec![ -/// FakeDefinition { +/// WordDefinition { /// id: "doc0".to_string(), /// word: "flurbo".to_string(), /// definitions: vec![ @@ -85,7 +52,7 @@ impl EmbeddingsBuilder { /// "A fictional digital currency that originated in the animated series Rick and Morty.".to_string() /// ] /// }, -/// FakeDefinition { +/// WordDefinition { /// id: "doc1".to_string(), /// word: "glarb-glarb".to_string(), /// definitions: vec![ @@ -93,7 +60,7 @@ impl EmbeddingsBuilder { /// "A fictional creature found in the distant, swampy marshlands of the planet Glibbo in the Andromeda galaxy.".to_string() /// ] /// }, -/// FakeDefinition { +/// WordDefinition { /// id: "doc2".to_string(), /// word: "linglingdong".to_string(), /// definitions: vec![ @@ -105,6 +72,40 @@ impl EmbeddingsBuilder { /// .build() /// .await?; /// ``` +pub struct EmbeddingsBuilder { + model: M, + documents: Vec<(T, Vec)>, +} + +impl EmbeddingsBuilder { + /// Create a new embedding builder with the given embedding model + pub fn new(model: M) -> Self { + Self { + model, + documents: vec![], + } + } + + /// Add a document that implements `Embed` to the builder. + pub fn document(mut self, document: T) -> Result { + let mut embedder = TextEmbedder::default(); + document.embed(&mut embedder)?; + + self.documents.push((document, embedder.texts)); + + Ok(self) + } + + /// Add many documents that implement `Embed` to the builder. + pub fn documents(self, documents: impl IntoIterator) -> Result { + let builder = documents + .into_iter() + .try_fold(self, |builder, doc| builder.document(doc))?; + + Ok(builder) + } +} + impl EmbeddingsBuilder { /// Generate embeddings for all documents in the builder. 
/// Returns a vector of tuples, where the first element is the document and the second element is the embeddings (either one embedding or many). @@ -174,9 +175,9 @@ mod tests { use super::EmbeddingsBuilder; #[derive(Clone)] - struct FakeModel; + struct Model; - impl EmbeddingModel for FakeModel { + impl EmbeddingModel for Model { const MAX_DOCUMENTS: usize = 5; fn ndims(&self) -> usize { @@ -198,12 +199,12 @@ mod tests { } #[derive(Clone, Debug)] - struct FakeDefinition { + struct WordDefinition { id: String, definitions: Vec, } - impl Embed for FakeDefinition { + impl Embed for WordDefinition { fn embed(&self, embedder: &mut TextEmbedder) -> Result<(), EmbedError> { for definition in &self.definitions { embedder.embed(definition.clone()); @@ -212,16 +213,16 @@ mod tests { } } - fn fake_definitions_multiple_text() -> Vec { + fn definitions_multiple_text() -> Vec { vec![ - FakeDefinition { + WordDefinition { id: "doc0".to_string(), definitions: vec![ "A green alien that lives on cold planets.".to_string(), "A fictional digital currency that originated in the animated series Rick and Morty.".to_string() ] }, - FakeDefinition { + WordDefinition { id: "doc1".to_string(), definitions: vec![ "An ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string(), @@ -231,13 +232,13 @@ mod tests { ] } - fn fake_definitions_multiple_text_2() -> Vec { + fn definitions_multiple_text_2() -> Vec { vec![ - FakeDefinition { + WordDefinition { id: "doc2".to_string(), definitions: vec!["Another fake definitions".to_string()], }, - FakeDefinition { + WordDefinition { id: "doc3".to_string(), definitions: vec!["Some fake definition".to_string()], }, @@ -245,25 +246,25 @@ mod tests { } #[derive(Clone, Debug)] - struct FakeDefinitionSingle { + struct WordDefinitionSingle { id: String, definition: String, } - impl Embed for FakeDefinitionSingle { + impl Embed for WordDefinitionSingle { fn embed(&self, embedder: &mut TextEmbedder) -> Result<(), EmbedError> 
{ embedder.embed(self.definition.clone()); Ok(()) } } - fn fake_definitions_single_text() -> Vec { + fn definitions_single_text() -> Vec { vec![ - FakeDefinitionSingle { + WordDefinitionSingle { id: "doc0".to_string(), definition: "A green alien that lives on cold planets.".to_string(), }, - FakeDefinitionSingle { + WordDefinitionSingle { id: "doc1".to_string(), definition: "An ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string(), } @@ -272,9 +273,9 @@ mod tests { #[tokio::test] async fn test_build_multiple_text() { - let fake_definitions = fake_definitions_multiple_text(); + let fake_definitions = definitions_multiple_text(); - let fake_model = FakeModel; + let fake_model = Model; let mut result = EmbeddingsBuilder::new(fake_model) .documents(fake_definitions) .unwrap() @@ -306,9 +307,9 @@ mod tests { #[tokio::test] async fn test_build_single_text() { - let fake_definitions = fake_definitions_single_text(); + let fake_definitions = definitions_single_text(); - let fake_model = FakeModel; + let fake_model = Model; let mut result = EmbeddingsBuilder::new(fake_model) .documents(fake_definitions) .unwrap() @@ -340,10 +341,10 @@ mod tests { #[tokio::test] async fn test_build_multiple_and_single_text() { - let fake_definitions = fake_definitions_multiple_text(); - let fake_definitions_single = fake_definitions_multiple_text_2(); + let fake_definitions = definitions_multiple_text(); + let fake_definitions_single = definitions_multiple_text_2(); - let fake_model = FakeModel; + let fake_model = Model; let mut result = EmbeddingsBuilder::new(fake_model) .documents(fake_definitions) .unwrap() @@ -377,10 +378,10 @@ mod tests { #[tokio::test] async fn test_build_string() { - let bindings = fake_definitions_multiple_text(); + let bindings = definitions_multiple_text(); let fake_definitions = bindings.iter().map(|def| def.definitions.clone()); - let fake_model = FakeModel; + let fake_model = Model; let mut result = 
EmbeddingsBuilder::new(fake_model) .documents(fake_definitions) .unwrap() diff --git a/rig-core/src/embeddings/embed.rs b/rig-core/src/embeddings/embed.rs index 6b30bbc0..480a2930 100644 --- a/rig-core/src/embeddings/embed.rs +++ b/rig-core/src/embeddings/embed.rs @@ -29,15 +29,15 @@ impl EmbedError { /// use std::env; /// /// use serde::{Deserialize, Serialize}; -/// use rig::{Embed, embeddings::{TextEmbedder, EmbedError, to_texts}}; +/// use rig::{Embed, embeddings::{TextEmbedder, EmbedError}}; /// -/// struct FakeDefinition { +/// struct WordDefinition { /// id: String, /// word: String, /// definitions: String, /// } /// -/// impl Embed for FakeDefinition { +/// impl Embed for WordDefinition { /// fn embed(&self, embedder: &mut TextEmbedder) -> Result<(), EmbedError> { /// // Embeddings only need to be generated for `definition` field. /// // Split the definitions by comma and collect them into a vector of strings. @@ -52,13 +52,13 @@ impl EmbedError { /// } /// } /// -/// let fake_definition = FakeDefinition { +/// let fake_definition = WordDefinition { /// id: "1".to_string(), /// word: "apple".to_string(), /// definitions: "a fruit, a tech company".to_string(), /// }; /// -/// assert_eq!(to_texts(fake_definition).unwrap(), vec!["a fruit", " a tech company"]); +/// assert_eq!(embeddings::to_texts(fake_definition).unwrap(), vec!["a fruit", " a tech company"]); /// ``` pub trait Embed { fn embed(&self, embedder: &mut TextEmbedder) -> Result<(), EmbedError>; @@ -174,6 +174,12 @@ impl Embed for serde_json::Value { } } +impl Embed for &T { + fn embed(&self, embedder: &mut TextEmbedder) -> Result<(), EmbedError> { + (*self).embed(embedder) + } +} + impl Embed for Vec { fn embed(&self, embedder: &mut TextEmbedder) -> Result<(), EmbedError> { for item in self { diff --git a/rig-core/src/lib.rs b/rig-core/src/lib.rs index 4fe84a6b..f1b5427b 100644 --- a/rig-core/src/lib.rs +++ b/rig-core/src/lib.rs @@ -91,7 +91,7 @@ pub mod tool; pub mod vector_store; // Re-export 
commonly used types and traits -pub use embeddings::{to_texts, Embed}; +pub use embeddings::Embed; pub use one_or_many::{EmptyListError, OneOrMany}; #[cfg(feature = "derive")] diff --git a/rig-core/src/loaders/file.rs b/rig-core/src/loaders/file.rs index 17c2f1f3..84eb1864 100644 --- a/rig-core/src/loaders/file.rs +++ b/rig-core/src/loaders/file.rs @@ -162,7 +162,7 @@ impl<'a, T: 'a> FileLoader<'a, Result> { } } -impl<'a> FileLoader<'a, Result> { +impl FileLoader<'_, Result> { /// Creates a new [FileLoader] using a glob pattern to match files. /// /// # Example @@ -227,7 +227,7 @@ impl<'a, T> IntoIterator for FileLoader<'a, T> { } } -impl<'a, T> Iterator for IntoIter<'a, T> { +impl Iterator for IntoIter<'_, T> { type Item = T; fn next(&mut self) -> Option { diff --git a/rig-core/src/loaders/pdf.rs b/rig-core/src/loaders/pdf.rs index ea18e4e6..410643f8 100644 --- a/rig-core/src/loaders/pdf.rs +++ b/rig-core/src/loaders/pdf.rs @@ -335,7 +335,7 @@ impl<'a, T: 'a> PdfFileLoader<'a, Result> { } } -impl<'a> PdfFileLoader<'a, Result> { +impl PdfFileLoader<'_, Result> { /// Creates a new [PdfFileLoader] using a glob pattern to match files. 
/// /// # Example @@ -396,7 +396,7 @@ impl<'a, T> IntoIterator for PdfFileLoader<'a, T> { } } -impl<'a, T> Iterator for IntoIter<'a, T> { +impl Iterator for IntoIter<'_, T> { type Item = T; fn next(&mut self) -> Option { diff --git a/rig-core/src/providers/openai.rs b/rig-core/src/providers/openai.rs index 77d3c5a8..6b5a3079 100644 --- a/rig-core/src/providers/openai.rs +++ b/rig-core/src/providers/openai.rs @@ -219,7 +219,7 @@ pub struct EmbeddingData { pub index: usize, } -#[derive(Debug, Deserialize)] +#[derive(Clone, Debug, Deserialize)] pub struct Usage { pub prompt_tokens: usize, pub total_tokens: usize, @@ -229,7 +229,7 @@ impl std::fmt::Display for Usage { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!( f, - "Prompt tokens: {}\nTotal tokens: {}", + "Prompt tokens: {} Total tokens: {}", self.prompt_tokens, self.total_tokens ) } @@ -535,7 +535,7 @@ impl completion::CompletionModel for CompletionModel { ApiResponse::Ok(response) => { tracing::info!(target: "rig", "OpenAI completion token usage: {:?}", - response.usage + response.usage.clone().map(|usage| format!("{usage}")).unwrap_or("N/A".to_string()) ); response.try_into() } diff --git a/rig-core/src/providers/perplexity.rs b/rig-core/src/providers/perplexity.rs index 41be57f7..fa1e34fb 100644 --- a/rig-core/src/providers/perplexity.rs +++ b/rig-core/src/providers/perplexity.rs @@ -155,7 +155,7 @@ impl std::fmt::Display for Usage { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!( f, - "Prompt tokens: {}\nCompletion tokens: {}\nTotal tokens: {}", + "Prompt tokens: {}\nCompletion tokens: {} Total tokens: {}", self.prompt_tokens, self.completion_tokens, self.total_tokens ) } diff --git a/rig-core/tests/embed_macro.rs b/rig-core/tests/embed_macro.rs index 778b70bd..daf6dcf8 100644 --- a/rig-core/tests/embed_macro.rs +++ b/rig-core/tests/embed_macro.rs @@ -1,33 +1,38 @@ use rig::{ - embeddings::{embed::EmbedError, TextEmbedder}, - to_texts, Embed, + 
embeddings::{self, embed::EmbedError, TextEmbedder}, + Embed, }; use serde::Serialize; -fn serialize(embedder: &mut TextEmbedder, definition: Definition) -> Result<(), EmbedError> { - embedder.embed(serde_json::to_string(&definition).map_err(EmbedError::new)?); - - Ok(()) -} - -#[derive(Embed)] -struct FakeDefinition { - id: String, - word: String, - #[embed(embed_with = "serialize")] - definition: Definition, -} - -#[derive(Serialize, Clone)] -struct Definition { - word: String, - link: String, - speech: String, -} - #[test] fn test_custom_embed() { - let fake_definition = FakeDefinition { + #[derive(Embed)] + struct WordDefinition { + #[allow(dead_code)] + id: String, + #[allow(dead_code)] + word: String, + #[embed(embed_with = "custom_embedding_function")] + definition: Definition, + } + + #[derive(Serialize, Clone)] + struct Definition { + word: String, + link: String, + speech: String, + } + + fn custom_embedding_function( + embedder: &mut TextEmbedder, + definition: Definition, + ) -> Result<(), EmbedError> { + embedder.embed(serde_json::to_string(&definition).map_err(EmbedError::new)?); + + Ok(()) + } + + let definition = WordDefinition { id: "doc1".to_string(), word: "house".to_string(), definition: Definition { @@ -37,30 +42,41 @@ fn test_custom_embed() { }, }; - println!( - "FakeDefinition: {}, {}", - fake_definition.id, fake_definition.word - ); - assert_eq!( - to_texts(fake_definition).unwrap().first().unwrap().clone(), - "{\"word\":\"a building in which people live; residence for human beings.\",\"link\":\"https://www.dictionary.com/browse/house\",\"speech\":\"noun\"}".to_string() - + embeddings::to_texts(definition).unwrap(), + vec!["{\"word\":\"a building in which people live; residence for human beings.\",\"link\":\"https://www.dictionary.com/browse/house\",\"speech\":\"noun\"}".to_string()] ) } -#[derive(Embed)] -struct FakeDefinition2 { - id: String, - #[embed] - word: String, - #[embed(embed_with = "serialize")] - definition: Definition, -} - 
#[test] fn test_custom_and_basic_embed() { - let fake_definition = FakeDefinition2 { + #[derive(Embed)] + struct WordDefinition { + #[allow(dead_code)] + id: String, + #[embed] + word: String, + #[embed(embed_with = "custom_embedding_function")] + definition: Definition, + } + + #[derive(Serialize, Clone)] + struct Definition { + word: String, + link: String, + speech: String, + } + + fn custom_embedding_function( + embedder: &mut TextEmbedder, + definition: Definition, + ) -> Result<(), EmbedError> { + embedder.embed(serde_json::to_string(&definition).map_err(EmbedError::new)?); + + Ok(()) + } + + let definition = WordDefinition { id: "doc1".to_string(), word: "house".to_string(), definition: Definition { @@ -70,69 +86,63 @@ fn test_custom_and_basic_embed() { }, }; - println!( - "FakeDefinition: {}, {}", - fake_definition.id, fake_definition.word - ); - - let texts = to_texts(fake_definition).unwrap(); - - assert_eq!(texts.first().unwrap().clone(), "house".to_string()); + let texts = embeddings::to_texts(definition).unwrap(); assert_eq!( - texts.last().unwrap().clone(), - "{\"word\":\"a building in which people live; residence for human beings.\",\"link\":\"https://www.dictionary.com/browse/house\",\"speech\":\"noun\"}".to_string() - ) -} - -#[derive(Embed)] -struct FakeDefinition3 { - id: String, - word: String, - #[embed] - definition: String, + texts, + vec![ + "house".to_string(), + "{\"word\":\"a building in which people live; residence for human beings.\",\"link\":\"https://www.dictionary.com/browse/house\",\"speech\":\"noun\"}".to_string() + ] + ); } #[test] fn test_single_embed() { + #[derive(Embed)] + struct WordDefinition { + #[allow(dead_code)] + id: String, + #[allow(dead_code)] + word: String, + #[embed] + definition: String, + } + let definition = "a building in which people live; residence for human beings.".to_string(); - let fake_definition = FakeDefinition3 { + let word_definition = WordDefinition { id: "doc1".to_string(), word: 
"house".to_string(), definition: definition.clone(), }; - println!( - "FakeDefinition3: {}, {}", - fake_definition.id, fake_definition.word - ); assert_eq!( - to_texts(fake_definition).unwrap().first().unwrap().clone(), - definition + embeddings::to_texts(word_definition).unwrap(), + vec![definition] ) } -#[derive(Embed)] -struct Company { - id: String, - company: String, - #[embed] - employee_ages: Vec, -} - #[test] -fn test_multiple_embed_strings() { +fn test_embed_vec_non_string() { + #[derive(Embed)] + struct Company { + #[allow(dead_code)] + id: String, + #[allow(dead_code)] + company: String, + #[embed] + employee_ages: Vec, + } + let company = Company { id: "doc1".to_string(), company: "Google".to_string(), employee_ages: vec![25, 30, 35, 40], }; - println!("Company: {}, {}", company.id, company.company); - assert_eq!( - to_texts(company).unwrap(), + embeddings::to_texts(company).unwrap(), vec![ "25".to_string(), "30".to_string(), @@ -142,27 +152,60 @@ fn test_multiple_embed_strings() { ); } -#[derive(Embed)] -struct Company2 { - id: String, - #[embed] - company: String, - #[embed] - employee_ages: Vec, +#[test] +fn test_embed_vec_string() { + #[derive(Embed)] + struct Company { + #[allow(dead_code)] + id: String, + #[allow(dead_code)] + company: String, + #[embed] + employee_names: Vec, + } + + let company = Company { + id: "doc1".to_string(), + company: "Google".to_string(), + employee_names: vec![ + "Alice".to_string(), + "Bob".to_string(), + "Charlie".to_string(), + "David".to_string(), + ], + }; + + assert_eq!( + embeddings::to_texts(company).unwrap(), + vec![ + "Alice".to_string(), + "Bob".to_string(), + "Charlie".to_string(), + "David".to_string() + ] + ); } #[test] fn test_multiple_embed_tags() { - let company = Company2 { + #[derive(Embed)] + struct Company { + #[allow(dead_code)] + id: String, + #[embed] + company: String, + #[embed] + employee_ages: Vec, + } + + let company = Company { id: "doc1".to_string(), company: "Google".to_string(), 
employee_ages: vec![25, 30, 35, 40], }; - println!("Company: {}", company.id); - assert_eq!( - to_texts(company).unwrap(), + embeddings::to_texts(company).unwrap(), vec![ "Google".to_string(), "25".to_string(), diff --git a/rig-lancedb/examples/fixtures/lib.rs b/rig-lancedb/examples/fixtures/lib.rs index 954494e5..b12156fb 100644 --- a/rig-lancedb/examples/fixtures/lib.rs +++ b/rig-lancedb/examples/fixtures/lib.rs @@ -7,23 +7,23 @@ use rig::{Embed, OneOrMany}; use serde::Deserialize; #[derive(Embed, Clone, Deserialize, Debug)] -pub struct FakeDefinition { +pub struct WordDefinition { pub id: String, #[embed] pub definition: String, } -pub fn fake_definitions() -> Vec { +pub fn fake_definitions() -> Vec { vec![ - FakeDefinition { + WordDefinition { id: "doc0".to_string(), definition: "Definition of *flumbrel (noun)*: a small, seemingly insignificant item that you constantly lose or misplace, such as a pen, hair tie, or remote control.".to_string() }, - FakeDefinition { + WordDefinition { id: "doc1".to_string(), definition: "Definition of *zindle (verb)*: to pretend to be working on something important while actually doing something completely unrelated or unproductive.".to_string() }, - FakeDefinition { + WordDefinition { id: "doc2".to_string(), definition: "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.".to_string() } @@ -46,22 +46,22 @@ pub fn schema(dims: usize) -> Schema { ])) } -// Convert FakeDefinition objects and their embedding to a RecordBatch. +// Convert WordDefinition objects and their embedding to a RecordBatch. pub fn as_record_batch( - records: Vec<(FakeDefinition, OneOrMany)>, + records: Vec<(WordDefinition, OneOrMany)>, dims: usize, ) -> Result { let id = StringArray::from_iter_values( records .iter() - .map(|(FakeDefinition { id, .. }, _)| id) + .map(|(WordDefinition { id, .. 
}, _)| id) .collect::>(), ); let definition = StringArray::from_iter_values( records .iter() - .map(|(FakeDefinition { definition, .. }, _)| definition) + .map(|(WordDefinition { definition, .. }, _)| definition) .collect::>(), ); diff --git a/rig-lancedb/examples/vector_search_local_ann.rs b/rig-lancedb/examples/vector_search_local_ann.rs index 84679e3f..0b75f080 100644 --- a/rig-lancedb/examples/vector_search_local_ann.rs +++ b/rig-lancedb/examples/vector_search_local_ann.rs @@ -1,7 +1,7 @@ use std::{env, sync::Arc}; use arrow_array::RecordBatchIterator; -use fixture::{as_record_batch, fake_definitions, schema, FakeDefinition}; +use fixture::{as_record_batch, fake_definitions, schema, WordDefinition}; use lancedb::index::vector::IvfPqIndexBuilder; use rig::{ embeddings::{EmbeddingModel, EmbeddingsBuilder}, @@ -31,7 +31,7 @@ async fn main() -> Result<(), anyhow::Error> { // Note: need at least 256 rows in order to create an index so copy the definition 256 times for testing purposes. .documents( (0..256) - .map(|i| FakeDefinition { + .map(|i| WordDefinition { id: format!("doc{}", i), definition: "Definition of *flumbuzzle (noun)*: A sudden, inexplicable urge to rearrange or reorganize small objects, such as desk items or books, for no apparent reason.".to_string() }) @@ -65,7 +65,7 @@ async fn main() -> Result<(), anyhow::Error> { // Query the index let results = vector_store_index - .top_n::("My boss says I zindle too much, what does that mean?", 1) + .top_n::("My boss says I zindle too much, what does that mean?", 1) .await?; println!("Results: {:?}", results); diff --git a/rig-lancedb/examples/vector_search_s3_ann.rs b/rig-lancedb/examples/vector_search_s3_ann.rs index 160dfa10..8aca722b 100644 --- a/rig-lancedb/examples/vector_search_s3_ann.rs +++ b/rig-lancedb/examples/vector_search_s3_ann.rs @@ -1,7 +1,7 @@ use std::{env, sync::Arc}; use arrow_array::RecordBatchIterator; -use fixture::{as_record_batch, fake_definitions, schema, FakeDefinition}; +use 
fixture::{as_record_batch, fake_definitions, schema, WordDefinition}; use lancedb::{index::vector::IvfPqIndexBuilder, DistanceType}; use rig::{ embeddings::{EmbeddingModel, EmbeddingsBuilder}, @@ -37,7 +37,7 @@ async fn main() -> Result<(), anyhow::Error> { // Note: need at least 256 rows in order to create an index so copy the definition 256 times for testing purposes. .documents( (0..256) - .map(|i| FakeDefinition { + .map(|i| WordDefinition { id: format!("doc{}", i), definition: "Definition of *flumbuzzle (noun)*: A sudden, inexplicable urge to rearrange or reorganize small objects, such as desk items or books, for no apparent reason.".to_string() }) @@ -77,7 +77,7 @@ async fn main() -> Result<(), anyhow::Error> { // Query the index let results = vector_store - .top_n::("I'm always looking for my phone, I always seem to forget it in the most counterintuitive places. What's the word for this feeling?", 1) + .top_n::("I'm always looking for my phone, I always seem to forget it in the most counterintuitive places. What's the word for this feeling?", 1) .await?; println!("Results: {:?}", results); diff --git a/rig-lancedb/src/utils/deserializer.rs b/rig-lancedb/src/utils/deserializer.rs index fd890280..7f9d1d7d 100644 --- a/rig-lancedb/src/utils/deserializer.rs +++ b/rig-lancedb/src/utils/deserializer.rs @@ -356,9 +356,9 @@ fn type_matcher(column: &Arc) -> Result, VectorStoreError> } } -/////////////////////////////////////////////////////////////////////////////////// -/// Everything below includes helpers for the recursive function `type_matcher`./// -/////////////////////////////////////////////////////////////////////////////////// +// ================================================================ +// Everything below includes helpers for the recursive function `type_matcher` +// ================================================================ /// Trait used to "deserialize" an arrow_array::Array as as list of primitive objects. 
trait DeserializePrimitiveArray { diff --git a/rig-mongodb/examples/vector_search_mongodb.rs b/rig-mongodb/examples/vector_search_mongodb.rs index ecdb7a9f..5d0ed81b 100644 --- a/rig-mongodb/examples/vector_search_mongodb.rs +++ b/rig-mongodb/examples/vector_search_mongodb.rs @@ -11,7 +11,7 @@ use rig_mongodb::{MongoDbVectorIndex, SearchParams}; // Shape of data that needs to be RAG'ed. // The definition field will be used to generate embeddings. #[derive(Embed, Clone, Deserialize, Debug)] -struct FakeDefinition { +struct WordDefinition { #[serde(rename = "_id")] id: String, #[embed] @@ -58,15 +58,15 @@ async fn main() -> Result<(), anyhow::Error> { let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); let fake_definitions = vec![ - FakeDefinition { + WordDefinition { id: "doc0".to_string(), definition: "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets".to_string(), }, - FakeDefinition { + WordDefinition { id: "doc1".to_string(), definition: "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string(), }, - FakeDefinition { + WordDefinition { id: "doc2".to_string(), definition: "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.".to_string(), } @@ -80,7 +80,7 @@ async fn main() -> Result<(), anyhow::Error> { let mongo_documents = embeddings .iter() .map( - |(FakeDefinition { id, definition, .. }, embedding)| Document { + |(WordDefinition { id, definition, .. 
}, embedding)| Document { id: id.clone(), definition: definition.clone(), embedding: embedding.first().vec.clone(), @@ -101,7 +101,7 @@ async fn main() -> Result<(), anyhow::Error> { // Query the index let results = index - .top_n::<FakeDefinition>("What is a linglingdong?", 1) + .top_n::<WordDefinition>("What is a linglingdong?", 1) .await?; println!("Results: {:?}", results); diff --git a/rig-neo4j/examples/vector_search_simple.rs b/rig-neo4j/examples/vector_search_simple.rs index 0d3acf81..fca43d27 100644 --- a/rig-neo4j/examples/vector_search_simple.rs +++ b/rig-neo4j/examples/vector_search_simple.rs @@ -21,7 +21,7 @@ use rig_neo4j::{ }; #[derive(Embed, Clone, Debug)] -pub struct FakeDefinition { +pub struct WordDefinition { pub id: String, #[embed] pub definition: String, @@ -44,15 +44,15 @@ async fn main() -> Result<(), anyhow::Error> { let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); let embeddings = EmbeddingsBuilder::new(model.clone()) - .document(FakeDefinition { + .document(WordDefinition { id: "doc0".to_string(), definition: "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets".to_string(), })? - .document(FakeDefinition { + .document(WordDefinition { id: "doc1".to_string(), definition: "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string(), })? - .document(FakeDefinition { + .document(WordDefinition { id: "doc2".to_string(), definition: "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.".to_string(), })? 
diff --git a/rig-qdrant/examples/qdrant_vector_search.rs b/rig-qdrant/examples/qdrant_vector_search.rs index b1dc2d96..b1a91349 100644 --- a/rig-qdrant/examples/qdrant_vector_search.rs +++ b/rig-qdrant/examples/qdrant_vector_search.rs @@ -25,7 +25,7 @@ use rig_qdrant::QdrantVectorStore; use serde_json::json; #[derive(Embed)] -struct FakeDefinition { +struct WordDefinition { id: String, #[embed] definition: String, @@ -57,15 +57,15 @@ async fn main() -> Result<(), anyhow::Error> { let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); let documents = EmbeddingsBuilder::new(model.clone()) - .document(FakeDefinition { + .document(WordDefinition { id: "0981d983-a5f8-49eb-89ea-f7d3b2196d2e".to_string(), definition: "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets".to_string(), })? - .document(FakeDefinition { + .document(WordDefinition { id: "62a36d43-80b6-4fd6-990c-f75bb02287d1".to_string(), definition: "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string(), })? - .document(FakeDefinition { + .document(WordDefinition { id: "f9e17d59-32e5-440c-be02-b2759a654824".to_string(), definition: "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.".to_string(), })?