From b7d146ff57f3ac2ac279834dab432c993e8baf7b Mon Sep 17 00:00:00 2001 From: Garance <garance.mary@gmail.com> Date: Wed, 2 Oct 2024 18:07:02 -0400 Subject: [PATCH 01/91] feat: setup derive macro --- Cargo.lock | 53 ++++++++++++++++--------- Cargo.toml | 2 +- rig-macros/Cargo.toml | 8 ++++ rig-macros/rig-macros-derive/Cargo.toml | 11 +++++ rig-macros/rig-macros-derive/src/lib.rs | 52 ++++++++++++++++++++++++ rig-macros/src/lib.rs | 34 ++++++++++++++++ 6 files changed, 141 insertions(+), 19 deletions(-) create mode 100644 rig-macros/Cargo.toml create mode 100644 rig-macros/rig-macros-derive/Cargo.toml create mode 100644 rig-macros/rig-macros-derive/src/lib.rs create mode 100644 rig-macros/src/lib.rs diff --git a/Cargo.lock b/Cargo.lock index 80967e9d..2da34761 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -59,7 +59,7 @@ checksum = "c6fa2087f2753a7da8cc1c0dbfcf89579dd57458e36769de5ac750b4671737ca" dependencies = [ "proc-macro2", "quote", - "syn 2.0.65", + "syn 2.0.79", ] [[package]] @@ -458,7 +458,7 @@ checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac" dependencies = [ "proc-macro2", "quote", - "syn 2.0.65", + "syn 2.0.79", ] [[package]] @@ -986,7 +986,7 @@ checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" dependencies = [ "proc-macro2", "quote", - "syn 2.0.65", + "syn 2.0.79", ] [[package]] @@ -1107,9 +1107,9 @@ checksum = "a1d01941d82fa2ab50be1e79e6714289dd7cde78eba4c074bc5a4374f650dfe0" [[package]] name = "quote" -version = "1.0.36" +version = "1.0.37" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0fa76aaf39101c457836aec0ce2316dbdc3ab723cdda1c6bd4e6ad4208acaca7" +checksum = "b5b9d34b8991d19d98081b46eacdd8eb58c6f2b201139f7c5f643cc155a633af" dependencies = [ "proc-macro2", ] @@ -1226,6 +1226,22 @@ dependencies = [ "tracing-subscriber", ] +[[package]] +name = "rig-macros" +version = "0.1.0" +dependencies = [ + "rig-macros-derive", + "serde_json", +] + +[[package]] +name = "rig-macros-derive" +version = "0.1.0" +dependencies = [ + "quote", + "syn 2.0.79", +] + [[package]] name = "rig-mongodb" version = "0.1.2" @@ -1369,7 +1385,7 @@ dependencies = [ "proc-macro2", "quote", "serde_derive_internals", - "syn 2.0.65", + "syn 2.0.79", ] [[package]] @@ -1458,7 +1474,7 @@ checksum = "500cbc0ebeb6f46627f50f3f5811ccf6bf00643be300b4c3eabc0ef55dc5b5ba" dependencies = [ "proc-macro2", "quote", - "syn 2.0.65", + "syn 2.0.79", ] [[package]] @@ -1469,17 +1485,18 @@ checksum = "18d26a20a969b9e3fdf2fc2d9f21eda6c40e2de84c9408bb5d3b05d499aae711" dependencies = [ "proc-macro2", "quote", - "syn 2.0.65", + "syn 2.0.79", ] [[package]] name = "serde_json" -version = "1.0.117" +version = "1.0.128" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "455182ea6142b14f93f4bc5320a2b31c1f266b66a4a5c858b013302a5d8cbfc3" +checksum = "6ff5456707a1de34e7e37f2a6fd3d3f808c318259cbd01ab6377795054b483d8" dependencies = [ "indexmap", "itoa", + "memchr", "ryu", "serde", ] @@ -1635,9 +1652,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.65" +version = "2.0.79" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d2863d96a84c6439701d7a38f9de935ec562c8832cc55d1dde0f513b52fad106" +checksum = "89132cd0bf050864e1d38dc3bbc07a0eb8e7530af26344d3d2bbbef83499f590" dependencies = [ "proc-macro2", "quote", @@ -1712,7 +1729,7 @@ checksum = "46c3384250002a6d5af4d114f2845d37b57521033f30d5c3f46c4d70e1197533" dependencies = [ "proc-macro2", "quote", - "syn 2.0.65", + "syn 2.0.79", ] [[package]] @@ 
-1798,7 +1815,7 @@ checksum = "5f5ae998a069d4b5aba8ee9dad856af7d520c3699e6159b185c2acd48155d39a" dependencies = [ "proc-macro2", "quote", - "syn 2.0.65", + "syn 2.0.79", ] [[package]] @@ -1860,7 +1877,7 @@ checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.65", + "syn 2.0.79", ] [[package]] @@ -2068,7 +2085,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn 2.0.65", + "syn 2.0.79", "wasm-bindgen-shared", ] @@ -2102,7 +2119,7 @@ checksum = "e94f17b526d0a461a191c78ea52bbce64071ed5c04c9ffe424dcb38f74171bb7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.65", + "syn 2.0.79", "wasm-bindgen-backend", "wasm-bindgen-shared", ] @@ -2341,5 +2358,5 @@ checksum = "15e934569e47891f7d9411f1a451d947a60e000ab3bd24fbb970f000387d1b3b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.65", + "syn 2.0.79", ] diff --git a/Cargo.toml b/Cargo.toml index 2501b86f..d3a0c372 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [workspace] resolver = "2" members = [ - "rig-core", + "rig-core", "rig-macros", "rig-macros/rig-macros-derive", "rig-mongodb", ] diff --git a/rig-macros/Cargo.toml b/rig-macros/Cargo.toml new file mode 100644 index 00000000..729582f1 --- /dev/null +++ b/rig-macros/Cargo.toml @@ -0,0 +1,8 @@ +[package] +name = "rig-macros" +version = "0.1.0" +edition = "2021" + +[dependencies] +serde_json = "1.0.128" +rig-macros-derive = { path = "./rig-macros-derive" } diff --git a/rig-macros/rig-macros-derive/Cargo.toml b/rig-macros/rig-macros-derive/Cargo.toml new file mode 100644 index 00000000..51807c28 --- /dev/null +++ b/rig-macros/rig-macros-derive/Cargo.toml @@ -0,0 +1,11 @@ +[package] +name = "rig-macros-derive" +version = "0.1.0" +edition = "2021" + +[dependencies] +quote = "1.0.37" +syn = "2.0.79" + +[lib] +proc-macro = true diff --git a/rig-macros/rig-macros-derive/src/lib.rs b/rig-macros/rig-macros-derive/src/lib.rs new file mode 100644 index 00000000..356546a8 --- /dev/null +++ b/rig-macros/rig-macros-derive/src/lib.rs @@ -0,0 +1,52 @@ +extern crate proc_macro; +use proc_macro::TokenStream; +use quote::quote; +use syn::{Attribute, Meta}; + +// https://doc.rust-lang.org/book/ch19-06-macros.html#how-to-write-a-custom-derive-macro +// https://doc.rust-lang.org/reference/procedural-macros.html + +#[proc_macro_derive(Embedding, attributes(embed))] +pub fn derive_embed_trait(item: TokenStream) -> TokenStream { + let ast = syn::parse(item).unwrap(); + + impl_embeddable_macro(&ast) +} + +fn impl_embeddable_macro(ast: &syn::DeriveInput) -> TokenStream { + let name = &ast.ident; + + match &ast.data { + syn::Data::Struct(data_struct) => { + let field_to_embed = data_struct.fields.clone().into_iter().find(|field| { + field + .attrs + .clone() + .into_iter() + .find(|attribute| match attribute { + Attribute { + meta: Meta::Path(path), + .. + } => match path.get_ident() { + Some(attribute_name) => attribute_name == "embed", + None => false, + }, + _ => return false, + }) + .is_some() + }); + } + _ => {} + }; + + let gen = quote! 
{ + impl Embeddable for #name { + type Kind = String; + + fn embeddable(&self) { + println!("{}", stringify!(#name)); + } + } + }; + gen.into() +} diff --git a/rig-macros/src/lib.rs b/rig-macros/src/lib.rs new file mode 100644 index 00000000..0512c177 --- /dev/null +++ b/rig-macros/src/lib.rs @@ -0,0 +1,34 @@ +enum Kind { + Single, + Many, +} + +trait Embeddable { + type Kind; + fn embeddable(&self); +} + +#[cfg(test)] +mod tests { + use super::Embeddable; + use rig_macros_derive::Embedding; + + #[derive(Embedding)] + struct MyStruct { + id: String, + #[embed] + name: String, + } + + #[test] + fn test_macro() { + let my_struct = MyStruct { + id: "1".to_string(), + name: "John".to_string(), + }; + + my_struct.embeddable(); + + assert!(false) + } +} From 5904734d4fe0f51a324fdea1465fee3683e26add Mon Sep 17 00:00:00 2001 From: Garance <garance.mary@gmail.com> Date: Thu, 3 Oct 2024 16:13:16 -0400 Subject: [PATCH 02/91] test: test out writing embeddable macro --- rig-macros/rig-macros-derive/src/lib.rs | 61 ++++++---- rig-macros/src/lib.rs | 27 ++++- rig-macros/test.rs | 150 ++++++++++++++++++++++++ 3 files changed, 207 insertions(+), 31 deletions(-) create mode 100644 rig-macros/test.rs diff --git a/rig-macros/rig-macros-derive/src/lib.rs b/rig-macros/rig-macros-derive/src/lib.rs index 356546a8..2cf7f251 100644 --- a/rig-macros/rig-macros-derive/src/lib.rs +++ b/rig-macros/rig-macros-derive/src/lib.rs @@ -1,50 +1,59 @@ extern crate proc_macro; use proc_macro::TokenStream; use quote::quote; -use syn::{Attribute, Meta}; +use syn::{parse_macro_input, Attribute, DeriveInput, Meta}; // https://doc.rust-lang.org/book/ch19-06-macros.html#how-to-write-a-custom-derive-macro // https://doc.rust-lang.org/reference/procedural-macros.html #[proc_macro_derive(Embedding, attributes(embed))] pub fn derive_embed_trait(item: TokenStream) -> TokenStream { - let ast = syn::parse(item).unwrap(); + let input = parse_macro_input!(item as DeriveInput); - impl_embeddable_macro(&ast) + impl_embeddable_macro(&input) } -fn impl_embeddable_macro(ast: &syn::DeriveInput) -> TokenStream { - let name = &ast.ident; +fn impl_embeddable_macro(input: &syn::DeriveInput) -> TokenStream { + let name = &input.ident; - match &ast.data { + let embeddings = match &input.data { syn::Data::Struct(data_struct) => { - let field_to_embed = data_struct.fields.clone().into_iter().find(|field| { - field - .attrs - .clone() - .into_iter() - .find(|attribute| match attribute { - Attribute { - meta: Meta::Path(path), - .. - } => match path.get_ident() { - Some(attribute_name) => attribute_name == "embed", - None => false, - }, - _ => return false, - }) - .is_some() - }); + data_struct.fields.clone().into_iter().filter(|field| { + field + .attrs + .clone() + .into_iter() + .any(|attribute| match attribute { + Attribute { + meta: Meta::Path(path), + .. + } => match path.get_ident() { + Some(attribute_name) => attribute_name == "embed", + None => false + } + _ => false, + }) + }).map(|field| { + let field_name = field.ident.expect(""); + + quote! { + self.#field_name.embeddable() + } + + }).collect::<Vec<_>>() } - _ => {} + _ => vec![] }; let gen = quote! 
{ impl Embeddable for #name { type Kind = String; - fn embeddable(&self) { - println!("{}", stringify!(#name)); + fn embeddable(&self) -> Vec<String> { + vec![ + #(#embeddings),* + ].into_iter().flatten().collect() + } } }; diff --git a/rig-macros/src/lib.rs b/rig-macros/src/lib.rs index 0512c177..fedf2cde 100644 --- a/rig-macros/src/lib.rs +++ b/rig-macros/src/lib.rs @@ -5,17 +5,34 @@ enum Kind { trait Embeddable { type Kind; - fn embeddable(&self); + fn embeddable(&self) -> Vec<String>; } #[cfg(test)] mod tests { - use super::Embeddable; + use super::{Embeddable, Kind}; use rig_macros_derive::Embedding; + impl Embeddable for usize { + type Kind = Kind; + + fn embeddable(&self) -> Vec<String> { + vec![self.to_string()] + } + } + + impl Embeddable for String { + type Kind = Kind; + + fn embeddable(&self) -> Vec<String> { + vec![self.clone()] + } + } + #[derive(Embedding)] struct MyStruct { - id: String, + #[embed] + id: usize, #[embed] name: String, } @@ -23,11 +40,11 @@ mod tests { #[test] fn test_macro() { let my_struct = MyStruct { - id: "1".to_string(), + id: 1, name: "John".to_string(), }; - my_struct.embeddable(); + println!("{:?}", my_struct.embeddable()); assert!(false) } diff --git a/rig-macros/test.rs b/rig-macros/test.rs new file mode 100644 index 00000000..4f9e80e9 --- /dev/null +++ b/rig-macros/test.rs @@ -0,0 +1,150 @@ +/// Builder for creating a collection of embeddings +pub struct EmbeddingsBuilder<M: EmbeddingModel, T: Embeddable, V: Serialize> { + model: M, + documents: Vec<(T, Vec<V>)>, +} + +trait Embeddable<V: Serialize> { + // Return list of strings that need to be embedded. + // Instead of Vec<String>, should be Vec<T: Serialize> + fn embeddable(&self) -> Vec<V>; +} + +type EmbeddingVector = Vec<f64>; + +impl<M: EmbeddingModel, T: Embeddable, V: Serialize> EmbeddingsBuilder<M, T, V> { + /// Create a new embedding builder with the given embedding model + pub fn new(model: M) -> Self { + Self { + model, + documents: vec![], + } + } + + pub fn add<T: Embeddable>( + mut self, + document: T, + ) -> Self { + let embed_documents: Vec<V> = document.embeddable(); + + self.documents.push(( + document, + embed_documents, + )); + self + } + + pub fn build(&self) -> Result<Vec<(T, Vec<EmbeddingVector>)>, EmbeddingError> { + self.documents.iter().map(|(doc, values_to_embed)| { + values_to_embed.iter().map(|value| { + let value_str = serde_json::to_string(value)?; + generate_embedding(value_str) + }) + }) + } + + pub fn build_simple(&self) -> Result<Vec<(T, EmbeddingVector)>, EmbeddingError> { + self.documents.iter().map(|(doc, value_to_embed)| { + let value_str = serde_json::to_string(value_to_embed)?; + generate_embedding(value_str) + }) + } +} + + +// Example +#[derive(Embeddable)] +struct DictionaryEntry { + word: String, + #[embed] + definitions: String, +} + +#[derive(Embeddable)] +struct MetadataEmbedding { + pub id: String, + #[embed(with = serde_json::to_value)] + pub content: CategoryMetadata, + pub created: Option<DateTime<Utc>>, + pub modified: Option<DateTime<Utc>>, + pub dataset_ids: Vec<String>, +} + +#[derive(serde::Serialize)] +struct CategoryMetadata { + pub name: String, + pub description: String, + pub tags: Vec<String>, + pub links: Vec<String>, +} + +// Inside macro: +impl Embeddable for DictionaryEntry { + fn embeddable(&self) -> Vec<String> { + // Find the field tagged with #[embed] and return its value + // If there are no embedding tags, return the entire struct + } +} + +fn main() { + let embeddings: Vec<(DictionaryEntry, Vec<EmbeddingVector>)> = 
EmbeddingsBuilder::new(model.clone()) + .add(DictionaryEntry::new("blah", vec!["definition of blah"])) + .add(DictionaryEntry::new("foo", vec!["definition of foo"])) + .build()?; + + // In relational vector store like LanceDB, need to flatten result (create row for each item in definitions vector): + // Column: word (string) + // Column: definition (vector) + + // In document vector store like MongoDB, might need to merge the vector results back with their corresponding definition string: + // Field: word (string) + // Field: definitions + // // Field: definition (string) + // // Field: vector + + Ok(()) +} + + + +// Iterations: +// 1 - Multiple fields to embed? +#[derive(Embedding)] +struct DictionaryEntry { + word: String, + #[embed] + definitions: Vec<String>, + #[embed] + synonyms: Vec<String> +} + +// 2 - Embed recursion? Ex: +#[derive(Embedding)] +struct DictionaryEntry { + word: String, + #[embed] + definitions: Vec<Definition>, +} +struct Definition { + definition: String, + #[embed] + links: Vec<String> +} + +// { +// word: "blah", +// definitions: [ +// { +// definition: "definition of blah", +// links: ["link1", "link2"] +// }, +// { +// definition: "another definition for blah", +// links: ["link3"] +// } +// ] +// } + +// blah | definition of blah | link1 | embedding for link1 +// blah | definition of blah | link2 | embedding for link2 +// blah | another definition for blah | link3 | embedding for link3 \ No newline at end of file From ee9b5c332b322621395757ef4708e010776a8362 Mon Sep 17 00:00:00 2001 From: Garance <garance.mary@gmail.com> Date: Fri, 4 Oct 2024 17:23:39 -0400 Subject: [PATCH 03/91] test: continue testing custom macro implementation --- Cargo.lock | 11 +- rig-macros/rig-macros-derive/Cargo.toml | 3 +- rig-macros/rig-macros-derive/src/lib.rs | 129 ++++++++++++++++++------ rig-macros/src/lib.rs | 6 +- rig-macros/test.rs | 12 +-- 5 files changed, 119 insertions(+), 42 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 2da34761..e62e3101 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -712,6 +712,12 @@ dependencies = [ "hashbrown", ] +[[package]] +name = "indoc" +version = "2.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b248f5224d1d606005e02c97f5aa4e88eeb230488bcc03bc9ca4d7991399f2b5" + [[package]] name = "ipconfig" version = "0.3.2" @@ -1092,9 +1098,9 @@ checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" [[package]] name = "proc-macro2" -version = "1.0.83" +version = "1.0.86" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0b33eb56c327dec362a9e55b3ad14f9d2f0904fb5a5b03b513ab5465399e9f43" +checksum = "5e719e8df665df0d1c8fbfd238015744736151d4445ec0836b8e628aae103b77" dependencies = [ "unicode-ident", ] @@ -1238,6 +1244,7 @@ dependencies = [ name = "rig-macros-derive" version = "0.1.0" dependencies = [ + "indoc", "quote", "syn 2.0.79", ] diff --git a/rig-macros/rig-macros-derive/Cargo.toml b/rig-macros/rig-macros-derive/Cargo.toml index 51807c28..f2f22f6d 100644 --- a/rig-macros/rig-macros-derive/Cargo.toml +++ b/rig-macros/rig-macros-derive/Cargo.toml @@ -4,8 +4,9 @@ version = "0.1.0" edition = "2021" [dependencies] +indoc = "2.0.5" quote = "1.0.37" -syn = "2.0.79" +syn = { version = "2.0.79", features = ["full"]} [lib] proc-macro = true diff --git a/rig-macros/rig-macros-derive/src/lib.rs b/rig-macros/rig-macros-derive/src/lib.rs index 2cf7f251..f9a5bd57 100644 --- a/rig-macros/rig-macros-derive/src/lib.rs +++ b/rig-macros/rig-macros-derive/src/lib.rs @@ -1,11 +1,17 @@ 
extern crate proc_macro; +use indoc::indoc; use proc_macro::TokenStream; -use quote::quote; -use syn::{parse_macro_input, Attribute, DeriveInput, Meta}; +use quote::{quote, ToTokens}; +use syn::{ + meta::ParseNestedMeta, parse_macro_input, spanned::Spanned, Attribute, DataStruct, DeriveInput, Meta, ExprPath +}; // https://doc.rust-lang.org/book/ch19-06-macros.html#how-to-write-a-custom-derive-macro // https://doc.rust-lang.org/reference/procedural-macros.html +const EMBED: &str = "embed"; +const EMBED_WITH: &str = "embed_with"; + #[proc_macro_derive(Embedding, attributes(embed))] pub fn derive_embed_trait(item: TokenStream) -> TokenStream { let input = parse_macro_input!(item as DeriveInput); @@ -18,44 +24,105 @@ fn impl_embeddable_macro(input: &syn::DeriveInput) -> TokenStream { let embeddings = match &input.data { syn::Data::Struct(data_struct) => { - data_struct.fields.clone().into_iter().filter(|field| { - field - .attrs - .clone() - .into_iter() - .any(|attribute| match attribute { - Attribute { - meta: Meta::Path(path), - .. - } => match path.get_ident() { - Some(attribute_name) => attribute_name == "embed", - None => false - } - _ => false, - }) - }).map(|field| { - let field_name = field.ident.expect(""); - - quote! { - self.#field_name.embeddable() - } - - }).collect::<Vec<_>>() + // let invoke_trait = invoke_trait(data_struct) + // .map(|field_name| { + // quote! { + // self.#field_name.embeddable() + // } + // }) + // .collect::<Vec<_>>(); + custom_trait_implementation(data_struct) } - _ => vec![] - }; + _ => Ok(false), + } + .unwrap(); let gen = quote! { impl Embeddable for #name { type Kind = String; fn embeddable(&self) -> Vec<String> { - vec![ - #(#embeddings),* - ].into_iter().flatten().collect() - + // vec![ + // #(#embeddings),* + // ].into_iter().flatten().collect() + println!("{}", #embeddings); + vec![] } } }; gen.into() } + +fn custom_trait_implementation(data_struct: &DataStruct) -> Result<bool, syn::Error> { + let t = data_struct + .fields + .clone() + .into_iter() + .for_each(|field| { + let _t = field.attrs.clone().into_iter().map(|attr| { + let t = if attr.path().is_ident(EMBED) { + attr.parse_nested_meta(|meta| { + if meta.path.is_ident(EMBED_WITH) { + let path = parse_embed_with(&meta)?; + + let tokens = meta.path.into_token_stream(); + }; + Ok(()) + }) + } else { + todo!() + }; + }).collect::<Vec<_>>(); + }); + Ok(false) +} + +fn parse_embed_with(meta: &ParseNestedMeta) -> Result<ExprPath, syn::Error> { + // #[embed(embed_with = "...")] + let expr = meta.value().unwrap().parse::<syn::Expr>().unwrap(); + let mut value = &expr; + while let syn::Expr::Group(e) = value { + value = &e.expr; + } + let string = if let syn::Expr::Lit(syn::ExprLit { + lit: syn::Lit::Str(lit_str), + .. + }) = value + { + let suffix = lit_str.suffix(); + if !suffix.is_empty() { + return Err(syn::Error::new( + lit_str.span(), + format!("unexpected suffix `{}` on string literal", suffix) + )) + } + lit_str.clone() + } else { + return Err(syn::Error::new( + value.span(), + format!("expected {} attribute to be a string: `{} = \"...\"`", EMBED_WITH, EMBED_WITH) + )) + }; + + string.parse() +} + +fn invoke_trait(data_struct: &DataStruct) -> impl Iterator<Item = syn::Ident> { + data_struct.fields.clone().into_iter().filter_map(|field| { + let found_embed = field + .attrs + .clone() + .into_iter() + .any(|attribute| match attribute { + Attribute { + meta: Meta::Path(path), + .. 
+ } => path.is_ident("embed"), + _ => false, + }); + match found_embed { + true => Some(field.ident.expect("")), + false => None, + } + }) +} diff --git a/rig-macros/src/lib.rs b/rig-macros/src/lib.rs index fedf2cde..628fe6d2 100644 --- a/rig-macros/src/lib.rs +++ b/rig-macros/src/lib.rs @@ -33,7 +33,7 @@ mod tests { struct MyStruct { #[embed] id: usize, - #[embed] + #[embed(embed_with = "something")] name: String, } @@ -44,7 +44,9 @@ mod tests { name: "John".to_string(), }; - println!("{:?}", my_struct.embeddable()); + my_struct.embeddable(); + + // println!("{:?}", my_struct.embeddable()); assert!(false) } diff --git a/rig-macros/test.rs b/rig-macros/test.rs index 4f9e80e9..f230af56 100644 --- a/rig-macros/test.rs +++ b/rig-macros/test.rs @@ -1,18 +1,18 @@ /// Builder for creating a collection of embeddings -pub struct EmbeddingsBuilder<M: EmbeddingModel, T: Embeddable, V: Serialize> { +pub struct EmbeddingsBuilder<M: EmbeddingModel, T: Embeddable> { model: M, - documents: Vec<(T, Vec<V>)>, + documents: Vec<(T, Vec<String>)>, } -trait Embeddable<V: Serialize> { +trait Embeddable { // Return list of strings that need to be embedded. // Instead of Vec<String>, should be Vec<T: Serialize> - fn embeddable(&self) -> Vec<V>; + fn embeddable(&self) -> Vec<String>; } type EmbeddingVector = Vec<f64>; -impl<M: EmbeddingModel, T: Embeddable, V: Serialize> EmbeddingsBuilder<M, T, V> { +impl<M: EmbeddingModel, T: Embeddable> EmbeddingsBuilder<M, T> { /// Create a new embedding builder with the given embedding model pub fn new(model: M) -> Self { Self { @@ -25,7 +25,7 @@ impl<M: EmbeddingModel, T: Embeddable, V: Serialize> EmbeddingsBuilder<M, T, V> mut self, document: T, ) -> Self { - let embed_documents: Vec<V> = document.embeddable(); + let embed_documents: Vec<String> = document.embeddable(); self.documents.push(( document, From c8c1e9ca0c38bba8fc1388f07377e4526f549963 Mon Sep 17 00:00:00 2001 From: Garance <garance.mary@gmail.com> Date: Sat, 5 Oct 2024 23:24:02 -0400 Subject: [PATCH 04/91] feat: macro generate trait bounds --- rig-macros/rig-macros-derive/src/lib.rs | 74 +++++++++++++++++-------- rig-macros/src/lib.rs | 8 --- 2 files changed, 50 insertions(+), 32 deletions(-) diff --git a/rig-macros/rig-macros-derive/src/lib.rs b/rig-macros/rig-macros-derive/src/lib.rs index f9a5bd57..2a6e4aaa 100644 --- a/rig-macros/rig-macros-derive/src/lib.rs +++ b/rig-macros/rig-macros-derive/src/lib.rs @@ -3,7 +3,7 @@ use indoc::indoc; use proc_macro::TokenStream; use quote::{quote, ToTokens}; use syn::{ - meta::ParseNestedMeta, parse_macro_input, spanned::Spanned, Attribute, DataStruct, DeriveInput, Meta, ExprPath + meta::ParseNestedMeta, parse_macro_input, parse_quote, spanned::Spanned, Attribute, DataStruct, DeriveInput, ExprPath, Meta, Path }; // https://doc.rust-lang.org/book/ch19-06-macros.html#how-to-write-a-custom-derive-macro @@ -14,31 +14,39 @@ const EMBED_WITH: &str = "embed_with"; #[proc_macro_derive(Embedding, attributes(embed))] pub fn derive_embed_trait(item: TokenStream) -> TokenStream { - let input = parse_macro_input!(item as DeriveInput); + let mut input = parse_macro_input!(item as DeriveInput); - impl_embeddable_macro(&input) + impl_embeddable_macro(&mut input) } -fn impl_embeddable_macro(input: &syn::DeriveInput) -> TokenStream { +fn impl_embeddable_macro(input: &mut syn::DeriveInput) -> TokenStream { let name = &input.ident; let embeddings = match &input.data { syn::Data::Struct(data_struct) => { - // let invoke_trait = invoke_trait(data_struct) + 
basic_embed_fields(data_struct).for_each(|(_, field_type)| { + add_struct_bounds(&mut input.generics, &field_type) + }); + + // let basic_embed_fields = basic_embed_fields(data_struct) // .map(|field_name| { // quote! { // self.#field_name.embeddable() // } // }) // .collect::<Vec<_>>(); - custom_trait_implementation(data_struct) + + let func_names = custom_trait_implementation(data_struct).unwrap(); + + false } - _ => Ok(false), - } - .unwrap(); + _ => false, + }; + + let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl(); let gen = quote! { - impl Embeddable for #name { + impl #impl_generics Embeddable for #name #ty_generics #where_clause { type Kind = String; fn embeddable(&self) -> Vec<String> { @@ -50,31 +58,37 @@ fn impl_embeddable_macro(input: &syn::DeriveInput) -> TokenStream { } } }; + eprintln!("Generated code:\n{}", gen); + gen.into() } -fn custom_trait_implementation(data_struct: &DataStruct) -> Result<bool, syn::Error> { - let t = data_struct +fn custom_trait_implementation(data_struct: &DataStruct) -> Result<Vec<ExprPath>, syn::Error> { + Ok(data_struct .fields .clone() .into_iter() - .for_each(|field| { - let _t = field.attrs.clone().into_iter().map(|attr| { - let t = if attr.path().is_ident(EMBED) { + .map(|field| { + let mut path = None; + field.attrs.clone().into_iter().map(|attr| { + if attr.path().is_ident(EMBED) { attr.parse_nested_meta(|meta| { if meta.path.is_ident(EMBED_WITH) { - let path = parse_embed_with(&meta)?; + path = Some(parse_embed_with(&meta)?); - let tokens = meta.path.into_token_stream(); + // let tokens = meta.path.into_token_stream(); }; Ok(()) }) } else { - todo!() - }; - }).collect::<Vec<_>>(); - }); - Ok(false) + Ok(()) + } + }).collect::<Result<Vec<_>,_>>()?; + Ok::<_, syn::Error>(path) + }).collect::<Result<Vec<_>,_>>()? + .into_iter() + .filter_map(|i| i) + .collect()) } fn parse_embed_with(meta: &ParseNestedMeta) -> Result<ExprPath, syn::Error> { @@ -107,7 +121,16 @@ fn parse_embed_with(meta: &ParseNestedMeta) -> Result<ExprPath, syn::Error> { string.parse() } -fn invoke_trait(data_struct: &DataStruct) -> impl Iterator<Item = syn::Ident> { +fn add_struct_bounds(generics: &mut syn::Generics, field_type: &syn::Type) { + let where_clause = generics.make_where_clause(); + + where_clause.predicates.push(parse_quote! 
{ + #field_type: Embeddable + }); +} + + +fn basic_embed_fields(data_struct: &DataStruct) -> impl Iterator<Item = (syn::Ident, syn::Type)> { data_struct.fields.clone().into_iter().filter_map(|field| { let found_embed = field .attrs @@ -121,7 +144,10 @@ fn invoke_trait(data_struct: &DataStruct) -> impl Iterator<Item = syn::Ident> { _ => false, }); match found_embed { - true => Some(field.ident.expect("")), + true => Some(( + field.ident.expect(""), + field.ty + )), false => None, } }) diff --git a/rig-macros/src/lib.rs b/rig-macros/src/lib.rs index 628fe6d2..b57bf150 100644 --- a/rig-macros/src/lib.rs +++ b/rig-macros/src/lib.rs @@ -21,14 +21,6 @@ mod tests { } } - impl Embeddable for String { - type Kind = Kind; - - fn embeddable(&self) -> Vec<String> { - vec![self.clone()] - } - } - #[derive(Embedding)] struct MyStruct { #[embed] From dff0aebb32986ae3df32ce748987c3272e6c9aa6 Mon Sep 17 00:00:00 2001 From: Garance <garance.mary@gmail.com> Date: Mon, 7 Oct 2024 14:51:07 -0400 Subject: [PATCH 05/91] refactor: split up macro into multiple files --- Cargo.lock | 9 +- rig-macros/Cargo.toml | 1 + rig-macros/rig-macros-derive/src/embedding.rs | 216 ++++++++++++++++++ rig-macros/rig-macros-derive/src/lib.rs | 147 +----------- rig-macros/src/lib.rs | 32 ++- rig-macros/test.rs | 39 ++-- 6 files changed, 275 insertions(+), 169 deletions(-) create mode 100644 rig-macros/rig-macros-derive/src/embedding.rs diff --git a/Cargo.lock b/Cargo.lock index e62e3101..48f46588 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1237,6 +1237,7 @@ name = "rig-macros" version = "0.1.0" dependencies = [ "rig-macros-derive", + "serde", "serde_json", ] @@ -1457,9 +1458,9 @@ checksum = "388a1df253eca08550bef6c72392cfe7c30914bf41df5269b68cbd6ff8f570a3" [[package]] name = "serde" -version = "1.0.203" +version = "1.0.210" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7253ab4de971e72fb7be983802300c30b5a7f0c2e56fab8abfc6a214307c0094" +checksum = "c8e3592472072e6e22e0a54d5904d9febf8508f65fb8552499a1abc7d1078c3a" dependencies = [ "serde_derive", ] @@ -1475,9 +1476,9 @@ dependencies = [ [[package]] name = "serde_derive" -version = "1.0.203" +version = "1.0.210" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "500cbc0ebeb6f46627f50f3f5811ccf6bf00643be300b4c3eabc0ef55dc5b5ba" +checksum = "243902eda00fad750862fc144cea25caca5e20d615af0a81bee94ca738f1df1f" dependencies = [ "proc-macro2", "quote", diff --git a/rig-macros/Cargo.toml b/rig-macros/Cargo.toml index 729582f1..d1cd2cd9 100644 --- a/rig-macros/Cargo.toml +++ b/rig-macros/Cargo.toml @@ -6,3 +6,4 @@ edition = "2021" [dependencies] serde_json = "1.0.128" rig-macros-derive = { path = "./rig-macros-derive" } +serde = {version = "1.0.210", features = ["derive"]} diff --git a/rig-macros/rig-macros-derive/src/embedding.rs b/rig-macros/rig-macros-derive/src/embedding.rs new file mode 100644 index 00000000..2bfdb0ab --- /dev/null +++ b/rig-macros/rig-macros-derive/src/embedding.rs @@ -0,0 +1,216 @@ +use proc_macro::TokenStream; +use quote::quote; +use syn::{ + meta::ParseNestedMeta, parse_quote, punctuated::Punctuated, spanned::Spanned, Attribute, + DataStruct, ExprPath, Meta, Token, +}; + +const EMBED: &str = "embed"; +const EMBED_WITH: &str = "embed_with"; + +pub fn expand_derive_embedding(input: &mut syn::DeriveInput) -> TokenStream { + let name = &input.ident; + + let func_calls = + match &input.data { + syn::Data::Struct(data_struct) => { + // Handles fields tagged with #[embed] + let mut function_calls = data_struct + 
+                    .basic_embed_fields()
+                    .map(|field| {
+                        add_struct_bounds(&mut input.generics, &field.ty);
+
+                        let field_name = field.ident;
+                        quote! {
+                            self.#field_name.embeddable()
+                        }
+                    })
+                    .collect::<Vec<_>>();
+
+                // Handles fields tagged with #[embed(embed_with = "...")]
+                function_calls.extend(data_struct.custom_embed_fields().unwrap().map(
+                    |(field, _)| {
+                        let field_name = field.ident;
+
+                        quote! {
+                            embeddable(&self.#field_name)
+                        }
+                    },
+                ));
+
+                function_calls
+            }
+            _ => vec![],
+        };
+
+    // Import the paths to the custom functions.
+    let custom_func_paths = match &input.data {
+        syn::Data::Struct(data_struct) => data_struct
+            .custom_embed_fields()
+            .unwrap()
+            .map(|(_, custom_func_path)| {
+                quote! {
+                    use #custom_func_path::embeddable;
+                }
+            })
+            .collect::<Vec<_>>(),
+        _ => vec![],
+    };
+
+    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
+
+    let gen = quote! {
+        #(#custom_func_paths);*
+
+        impl #impl_generics Embeddable for #name #ty_generics #where_clause {
+            type Kind = String;
+
+            fn embeddable(&self) -> Vec<String> {
+                vec![
+                    #(#func_calls),*
+                ].into_iter().flatten().collect()
+            }
+        }
+    };
+    eprintln!("Generated code:\n{}", gen);
+
+    gen.into()
+}
+
+// Adds bounds to where clause that force all fields tagged with #[embed] to implement the Embeddable trait.
+fn add_struct_bounds(generics: &mut syn::Generics, field_type: &syn::Type) {
+    let where_clause = generics.make_where_clause();
+
+    where_clause.predicates.push(parse_quote! {
+        #field_type: Embeddable
+    });
+}
+
+trait AttributeParser {
+    /// Finds and returns fields with simple #[embed] attribute tags only.
+    fn basic_embed_fields(&self) -> impl Iterator<Item = syn::Field>;
+    /// Finds and returns fields with #[embed(embed_with = "...")] attribute tags only.
+    /// Also returns the attribute in question.
+    fn custom_embed_fields(
+        &self,
+    ) -> Result<impl Iterator<Item = (syn::Field, syn::ExprPath)>, syn::Error>;
+}
+
+impl AttributeParser for DataStruct {
+    fn basic_embed_fields(&self) -> impl Iterator<Item = syn::Field> {
+        self.fields.clone().into_iter().filter(|field| {
+            field
+                .attrs
+                .clone()
+                .into_iter()
+                .any(|attribute| match attribute {
+                    Attribute {
+                        meta: Meta::Path(path),
+                        ..
+                    } => path.is_ident(EMBED),
+                    _ => false,
+                })
+        })
+    }
+
+    fn custom_embed_fields(
+        &self,
+    ) -> Result<impl Iterator<Item = (syn::Field, syn::ExprPath)>, syn::Error> {
+        // Determine if field is tagged with #[embed(embed_with = "...")] attribute.
+        fn is_custom_embed(attribute: &syn::Attribute) -> Result<bool, syn::Error> {
+            let is_custom_embed = match attribute.meta {
+                Meta::List(_) => attribute
+                    .parse_args_with(Punctuated::<Meta, Token![=]>::parse_terminated)?
+                    .into_iter()
+                    .any(|meta| meta.path().is_ident(EMBED_WITH)),
+                _ => false,
+            };
+
+            Ok(attribute.path().is_ident(EMBED) && is_custom_embed)
+        }
+
+        // Get the "..." part of the #[embed(embed_with = "...")] attribute.
+        // Ex: If attribute is tagged with #[embed(embed_with = "my_embed")], returns "my_embed".
+ fn expand_tag(attribute: &syn::Attribute) -> Result<syn::ExprPath, syn::Error> { + let mut custom_func_path = None; + + attribute.parse_nested_meta(|meta| { + custom_func_path = Some(meta.function_path()?); + Ok(()) + })?; + + match custom_func_path { + Some(path) => Ok(path), + None => Err(syn::Error::new( + attribute.span(), + format!( + "expected {} attribute to have format: `#[embed(embed_with = \"...\")]`", + EMBED_WITH + ), + )), + } + } + + Ok(self + .fields + .clone() + .into_iter() + .map(|field| { + field + .attrs + .clone() + .into_iter() + .map(|attribute| { + if is_custom_embed(&attribute)? { + Ok::<_, syn::Error>(Some((field.clone(), expand_tag(&attribute)?))) + } else { + Ok(None) + } + }) + .collect::<Result<Vec<_>, _>>() + }) + .collect::<Result<Vec<_>, _>>()? + .into_iter() + .flatten() + .flatten()) + } +} + +trait CustomFunction { + fn function_path(&self) -> Result<ExprPath, syn::Error>; +} + +impl CustomFunction for ParseNestedMeta<'_> { + fn function_path(&self) -> Result<ExprPath, syn::Error> { + // #[embed(embed_with = "...")] + let expr = self.value().unwrap().parse::<syn::Expr>().unwrap(); + let mut value = &expr; + while let syn::Expr::Group(e) = value { + value = &e.expr; + } + let string = if let syn::Expr::Lit(syn::ExprLit { + lit: syn::Lit::Str(lit_str), + .. + }) = value + { + let suffix = lit_str.suffix(); + if !suffix.is_empty() { + return Err(syn::Error::new( + lit_str.span(), + format!("unexpected suffix `{}` on string literal", suffix), + )); + } + lit_str.clone() + } else { + return Err(syn::Error::new( + value.span(), + format!( + "expected {} attribute to be a string: `{} = \"...\"`", + EMBED_WITH, EMBED_WITH + ), + )); + }; + + string.parse() + } +} diff --git a/rig-macros/rig-macros-derive/src/lib.rs b/rig-macros/rig-macros-derive/src/lib.rs index 2a6e4aaa..239d9b3a 100644 --- a/rig-macros/rig-macros-derive/src/lib.rs +++ b/rig-macros/rig-macros-derive/src/lib.rs @@ -1,154 +1,15 @@ extern crate proc_macro; -use indoc::indoc; use proc_macro::TokenStream; -use quote::{quote, ToTokens}; -use syn::{ - meta::ParseNestedMeta, parse_macro_input, parse_quote, spanned::Spanned, Attribute, DataStruct, DeriveInput, ExprPath, Meta, Path -}; +use syn::{parse_macro_input, DeriveInput}; + +mod embedding; // https://doc.rust-lang.org/book/ch19-06-macros.html#how-to-write-a-custom-derive-macro // https://doc.rust-lang.org/reference/procedural-macros.html -const EMBED: &str = "embed"; -const EMBED_WITH: &str = "embed_with"; - #[proc_macro_derive(Embedding, attributes(embed))] pub fn derive_embed_trait(item: TokenStream) -> TokenStream { let mut input = parse_macro_input!(item as DeriveInput); - impl_embeddable_macro(&mut input) -} - -fn impl_embeddable_macro(input: &mut syn::DeriveInput) -> TokenStream { - let name = &input.ident; - - let embeddings = match &input.data { - syn::Data::Struct(data_struct) => { - basic_embed_fields(data_struct).for_each(|(_, field_type)| { - add_struct_bounds(&mut input.generics, &field_type) - }); - - // let basic_embed_fields = basic_embed_fields(data_struct) - // .map(|field_name| { - // quote! { - // self.#field_name.embeddable() - // } - // }) - // .collect::<Vec<_>>(); - - let func_names = custom_trait_implementation(data_struct).unwrap(); - - false - } - _ => false, - }; - - let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl(); - - let gen = quote! 
{ - impl #impl_generics Embeddable for #name #ty_generics #where_clause { - type Kind = String; - - fn embeddable(&self) -> Vec<String> { - // vec![ - // #(#embeddings),* - // ].into_iter().flatten().collect() - println!("{}", #embeddings); - vec![] - } - } - }; - eprintln!("Generated code:\n{}", gen); - - gen.into() -} - -fn custom_trait_implementation(data_struct: &DataStruct) -> Result<Vec<ExprPath>, syn::Error> { - Ok(data_struct - .fields - .clone() - .into_iter() - .map(|field| { - let mut path = None; - field.attrs.clone().into_iter().map(|attr| { - if attr.path().is_ident(EMBED) { - attr.parse_nested_meta(|meta| { - if meta.path.is_ident(EMBED_WITH) { - path = Some(parse_embed_with(&meta)?); - - // let tokens = meta.path.into_token_stream(); - }; - Ok(()) - }) - } else { - Ok(()) - } - }).collect::<Result<Vec<_>,_>>()?; - Ok::<_, syn::Error>(path) - }).collect::<Result<Vec<_>,_>>()? - .into_iter() - .filter_map(|i| i) - .collect()) -} - -fn parse_embed_with(meta: &ParseNestedMeta) -> Result<ExprPath, syn::Error> { - // #[embed(embed_with = "...")] - let expr = meta.value().unwrap().parse::<syn::Expr>().unwrap(); - let mut value = &expr; - while let syn::Expr::Group(e) = value { - value = &e.expr; - } - let string = if let syn::Expr::Lit(syn::ExprLit { - lit: syn::Lit::Str(lit_str), - .. - }) = value - { - let suffix = lit_str.suffix(); - if !suffix.is_empty() { - return Err(syn::Error::new( - lit_str.span(), - format!("unexpected suffix `{}` on string literal", suffix) - )) - } - lit_str.clone() - } else { - return Err(syn::Error::new( - value.span(), - format!("expected {} attribute to be a string: `{} = \"...\"`", EMBED_WITH, EMBED_WITH) - )) - }; - - string.parse() -} - -fn add_struct_bounds(generics: &mut syn::Generics, field_type: &syn::Type) { - let where_clause = generics.make_where_clause(); - - where_clause.predicates.push(parse_quote! { - #field_type: Embeddable - }); -} - - -fn basic_embed_fields(data_struct: &DataStruct) -> impl Iterator<Item = (syn::Ident, syn::Type)> { - data_struct.fields.clone().into_iter().filter_map(|field| { - let found_embed = field - .attrs - .clone() - .into_iter() - .any(|attribute| match attribute { - Attribute { - meta: Meta::Path(path), - .. 
- } => path.is_ident("embed"), - _ => false, - }); - match found_embed { - true => Some(( - field.ident.expect(""), - field.ty - )), - false => None, - } - }) + embedding::expand_derive_embedding(&mut input) } diff --git a/rig-macros/src/lib.rs b/rig-macros/src/lib.rs index b57bf150..d7151f03 100644 --- a/rig-macros/src/lib.rs +++ b/rig-macros/src/lib.rs @@ -8,8 +8,24 @@ trait Embeddable { fn embeddable(&self) -> Vec<String>; } +#[derive(serde::Serialize)] +pub struct JobStruct { + job_title: String, + company: String, +} + +mod something { + use super::JobStruct; + + pub fn embeddable(input: &JobStruct) -> Vec<String> { + vec![serde_json::to_string(input).unwrap()] + } +} + #[cfg(test)] mod tests { + use crate::JobStruct; + use super::{Embeddable, Kind}; use rig_macros_derive::Embedding; @@ -22,23 +38,27 @@ mod tests { } #[derive(Embedding)] - struct MyStruct { + struct SomeStruct { #[embed] id: usize, - #[embed(embed_with = "something")] name: String, + #[embed(embed_with = "super::something")] + job: JobStruct, } #[test] fn test_macro() { - let my_struct = MyStruct { + let job_struct = JobStruct { + job_title: "developer".to_string(), + company: "playgrounds".to_string(), + }; + let some_struct = SomeStruct { id: 1, name: "John".to_string(), + job: job_struct, }; - my_struct.embeddable(); - - // println!("{:?}", my_struct.embeddable()); + println!("{:?}", some_struct.embeddable()); assert!(false) } diff --git a/rig-macros/test.rs b/rig-macros/test.rs index f230af56..225a9b13 100644 --- a/rig-macros/test.rs +++ b/rig-macros/test.rs @@ -5,6 +5,7 @@ pub struct EmbeddingsBuilder<M: EmbeddingModel, T: Embeddable> { } trait Embeddable { + type Kind; // Return list of strings that need to be embedded. // Instead of Vec<String>, should be Vec<T: Serialize> fn embeddable(&self) -> Vec<String>; @@ -12,6 +13,28 @@ trait Embeddable { type EmbeddingVector = Vec<f64>; +impl<M: EmbeddingModel, T: Embeddable<Kind = Single>> EmbeddingsBuilder<M, T> { + pub fn build(&self) -> Result<Vec<(T, EmbeddingVector)>, EmbeddingError> { + self.documents.iter().map(|(doc, values_to_embed)| { + values_to_embed.iter().map(|value| { + let value_str = serde_json::to_string(value)?; + generate_embedding(value_str) + }) + }) + } +} + +impl<M: EmbeddingModel, T: Embeddable<Kind = Many>> EmbeddingsBuilder<M, T> { + pub fn build(&self) -> Result<Vec<(T, Vec<EmbeddingVector>)>, EmbeddingError> { + self.documents.iter().map(|(doc, values_to_embed)| { + values_to_embed.iter().map(|value| { + let value_str = serde_json::to_string(value)?; + generate_embedding(value_str) + }) + }) + } +} + impl<M: EmbeddingModel, T: Embeddable> EmbeddingsBuilder<M, T> { /// Create a new embedding builder with the given embedding model pub fn new(model: M) -> Self { @@ -33,22 +56,6 @@ impl<M: EmbeddingModel, T: Embeddable> EmbeddingsBuilder<M, T> { )); self } - - pub fn build(&self) -> Result<Vec<(T, Vec<EmbeddingVector>)>, EmbeddingError> { - self.documents.iter().map(|(doc, values_to_embed)| { - values_to_embed.iter().map(|value| { - let value_str = serde_json::to_string(value)?; - generate_embedding(value_str) - }) - }) - } - - pub fn build_simple(&self) -> Result<Vec<(T, EmbeddingVector)>, EmbeddingError> { - self.documents.iter().map(|(doc, value_to_embed)| { - let value_str = serde_json::to_string(value_to_embed)?; - generate_embedding(value_str) - }) - } } From 0d1001119a3a850cdfb189e4cecaeda20bc9d1da Mon Sep 17 00:00:00 2001 From: Garance <garance.mary@gmail.com> Date: Mon, 7 Oct 2024 16:56:29 -0400 Subject: [PATCH 06/91] refactor: move 
macro derive crate inside rig-core --- Cargo.lock | 12 +- Cargo.toml | 2 +- rig-core/Cargo.toml | 1 + .../rig-core-derive}/Cargo.toml | 2 +- .../rig-core-derive}/src/embedding.rs | 0 .../rig-core-derive}/src/lib.rs | 2 +- rig-core/src/embeddings.rs | 223 +++++++----------- rig-macros/Cargo.toml | 9 - rig-macros/src/lib.rs | 65 ----- rig-macros/test.rs | 157 ------------ 10 files changed, 87 insertions(+), 386 deletions(-) rename {rig-macros/rig-macros-derive => rig-core/rig-core-derive}/Cargo.toml (86%) rename {rig-macros/rig-macros-derive => rig-core/rig-core-derive}/src/embedding.rs (100%) rename {rig-macros/rig-macros-derive => rig-core/rig-core-derive}/src/lib.rs (86%) delete mode 100644 rig-macros/Cargo.toml delete mode 100644 rig-macros/src/lib.rs delete mode 100644 rig-macros/test.rs diff --git a/Cargo.lock b/Cargo.lock index 48f46588..ced6668c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1223,6 +1223,7 @@ dependencies = [ "futures", "ordered-float", "reqwest", + "rig-core-derive", "schemars", "serde", "serde_json", @@ -1233,16 +1234,7 @@ dependencies = [ ] [[package]] -name = "rig-macros" -version = "0.1.0" -dependencies = [ - "rig-macros-derive", - "serde", - "serde_json", -] - -[[package]] -name = "rig-macros-derive" +name = "rig-core-derive" version = "0.1.0" dependencies = [ "indoc", diff --git a/Cargo.toml b/Cargo.toml index d3a0c372..a37ba0e4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [workspace] resolver = "2" members = [ - "rig-core", "rig-macros", "rig-macros/rig-macros-derive", + "rig-core", "rig-core/rig-core-derive", "rig-mongodb", ] diff --git a/rig-core/Cargo.toml b/rig-core/Cargo.toml index 28561465..be83f04f 100644 --- a/rig-core/Cargo.toml +++ b/rig-core/Cargo.toml @@ -22,6 +22,7 @@ futures = "0.3.29" ordered-float = "4.2.0" schemars = "0.8.16" thiserror = "1.0.61" +rig-core-derive = { path = "./rig-core-derive" } [dev-dependencies] anyhow = "1.0.75" diff --git a/rig-macros/rig-macros-derive/Cargo.toml b/rig-core/rig-core-derive/Cargo.toml similarity index 86% rename from rig-macros/rig-macros-derive/Cargo.toml rename to rig-core/rig-core-derive/Cargo.toml index f2f22f6d..8f8a45d7 100644 --- a/rig-macros/rig-macros-derive/Cargo.toml +++ b/rig-core/rig-core-derive/Cargo.toml @@ -1,5 +1,5 @@ [package] -name = "rig-macros-derive" +name = "rig-core-derive" version = "0.1.0" edition = "2021" diff --git a/rig-macros/rig-macros-derive/src/embedding.rs b/rig-core/rig-core-derive/src/embedding.rs similarity index 100% rename from rig-macros/rig-macros-derive/src/embedding.rs rename to rig-core/rig-core-derive/src/embedding.rs diff --git a/rig-macros/rig-macros-derive/src/lib.rs b/rig-core/rig-core-derive/src/lib.rs similarity index 86% rename from rig-macros/rig-macros-derive/src/lib.rs rename to rig-core/rig-core-derive/src/lib.rs index 239d9b3a..eb7abaa9 100644 --- a/rig-macros/rig-macros-derive/src/lib.rs +++ b/rig-core/rig-core-derive/src/lib.rs @@ -8,7 +8,7 @@ mod embedding; // https://doc.rust-lang.org/reference/procedural-macros.html #[proc_macro_derive(Embedding, attributes(embed))] -pub fn derive_embed_trait(item: TokenStream) -> TokenStream { +pub fn derive_embedding_trait(item: TokenStream) -> TokenStream { let mut input = parse_macro_input!(item as DeriveInput); embedding::expand_derive_embedding(&mut input) diff --git a/rig-core/src/embeddings.rs b/rig-core/src/embeddings.rs index 2d40bbc5..c72255df 100644 --- a/rig-core/src/embeddings.rs +++ b/rig-core/src/embeddings.rs @@ -147,15 +147,21 @@ pub struct DocumentEmbeddings { pub embeddings: 
Vec<Embedding>, } -type Embeddings = Vec<DocumentEmbeddings>; +struct SingleEmbedding; +struct ManyEmbedding; + +pub trait Embeddable { + type Kind; + fn embeddable(&self) -> Vec<String>; +} /// Builder for creating a collection of embeddings -pub struct EmbeddingsBuilder<M: EmbeddingModel> { +pub struct EmbeddingsBuilder<M: EmbeddingModel, D: Embeddable + Send> { model: M, - documents: Vec<(String, serde_json::Value, Vec<String>)>, + documents: Vec<(D, Vec<String>)>, } -impl<M: EmbeddingModel> EmbeddingsBuilder<M> { +impl<M: EmbeddingModel, D: Embeddable + Send> EmbeddingsBuilder<M, D> { /// Create a new embedding builder with the given embedding model pub fn new(model: M) -> Self { Self { @@ -164,169 +170,102 @@ impl<M: EmbeddingModel> EmbeddingsBuilder<M> { } } - /// Add a simple document to the embedding collection. - /// The provided document string will be used for the embedding. - pub fn simple_document(mut self, id: &str, document: &str) -> Self { - self.documents.push(( - id.to_string(), - serde_json::Value::String(document.to_string()), - vec![document.to_string()], - )); - self - } - - /// Add multiple documents to the embedding collection. - /// Each element of the vector is a tuple of the form (id, document). - pub fn simple_documents(mut self, documents: Vec<(String, String)>) -> Self { - self.documents - .extend(documents.into_iter().map(|(id, document)| { - ( - id, - serde_json::Value::String(document.clone()), - vec![document], - ) - })); - self - } - - /// Add a tool to the embedding collection. - /// The `tool.context()` corresponds to the document being stored while - /// `tool.embedding_docs()` corresponds to the documents that will be used to generate the embeddings. - pub fn tool(mut self, tool: impl ToolEmbedding + 'static) -> Result<Self, EmbeddingError> { - self.documents.push(( - tool.name(), - serde_json::to_value(tool.context())?, - tool.embedding_docs(), - )); - Ok(self) - } - - /// Add the tools from the given toolset to the embedding collection. - pub fn tools(mut self, toolset: &ToolSet) -> Result<Self, EmbeddingError> { - for (name, tool) in toolset.tools.iter() { - if let ToolType::Embedding(tool) = tool { - self.documents.push(( - name.clone(), - tool.context().map_err(|e| { - EmbeddingError::DocumentError(format!( - "Failed to generate context for tool {}: {}", - name, e - )) - })?, - tool.embedding_docs(), - )); - } - } - Ok(self) - } - - /// Add a document to the embedding collection. - /// `embed_documents` are the documents that will be used to generate the embeddings - /// for `document`. - pub fn document<T: Serialize>( - mut self, - id: &str, - document: T, - embed_documents: Vec<String>, - ) -> Self { - self.documents.push(( - id.to_string(), - serde_json::to_value(document).expect("Document should serialize"), - embed_documents, - )); - self - } + pub fn document(mut self, document: D) -> Self { + let embed_targets = document.embeddable(); - /// Add multiple documents to the embedding collection. - /// Each element of the vector is a tuple of the form (id, document, embed_documents). - pub fn documents<T: Serialize>(mut self, documents: Vec<(String, T, Vec<String>)>) -> Self { - self.documents.extend( - documents - .into_iter() - .map(|(id, document, embed_documents)| { - ( - id, - serde_json::to_value(document).expect("Document should serialize"), - embed_documents, - ) - }), - ); - self - } - - /// Add a json document to the embedding collection. 
- pub fn json_document( - mut self, - id: &str, - document: serde_json::Value, - embed_documents: Vec<String>, - ) -> Self { - self.documents - .push((id.to_string(), document, embed_documents)); - self - } - - /// Add multiple json documents to the embedding collection. - pub fn json_documents( - mut self, - documents: Vec<(String, serde_json::Value, Vec<String>)>, - ) -> Self { - self.documents.extend(documents); + self.documents.push((document, embed_targets)); self } +} - /// Generate the embeddings for the given documents - pub async fn build(self) -> Result<Embeddings, EmbeddingError> { - // Create a temporary store for the documents +impl<M: EmbeddingModel, D: Embeddable<Kind = ManyEmbedding> + Send + Clone> + EmbeddingsBuilder<M, D> +{ + pub async fn build(self) -> Result<Vec<(D, Vec<Embedding>)>, EmbeddingError> { let documents_map = self .documents + .clone() .into_iter() - .map(|(id, document, docs)| (id, (document, docs))) + .enumerate() + .map(|(id, (document, _))| (id, document)) .collect::<HashMap<_, _>>(); - let embeddings = stream::iter(documents_map.iter()) + let embeddings = stream::iter(self.documents.into_iter().enumerate()) // Flatten the documents - .flat_map(|(id, (_, docs))| { - stream::iter(docs.iter().map(|doc| (id.clone(), doc.clone()))) + .flat_map(|(i, (_, embed_targets))| { + stream::iter( + embed_targets + .into_iter() + .map(move |target| (i, target.clone())), + ) }) // Chunk them into N (the emebdding API limit per request) .chunks(M::MAX_DOCUMENTS) // Generate the embeddings .map(|docs| async { - let (ids, docs): (Vec<_>, Vec<_>) = docs.into_iter().unzip(); + let (documents, embed_targets): (Vec<_>, Vec<_>) = docs.into_iter().unzip(); Ok::<_, EmbeddingError>( - ids.into_iter() - .zip(self.model.embed_documents(docs).await?.into_iter()) + documents + .into_iter() + .zip(self.model.embed_documents(embed_targets).await?.into_iter()) .collect::<Vec<_>>(), ) }) .boxed() // Parallelize the embeddings generation over 10 concurrent requests .buffer_unordered(max(1, 1024 / M::MAX_DOCUMENTS)) - .try_fold(vec![], |mut acc, mut embeddings| async move { - Ok({ - acc.append(&mut embeddings); - acc - }) + // .try_collect::<Vec<_>>() + // .await; + .try_fold(HashMap::new(), |mut acc, mut embeddings| async move { + embeddings.into_iter().for_each(|(i, embedding)| { + acc.entry(i).or_insert(vec![]).push(embedding); + }); + + Ok(acc) }) - .await?; + .await? 
+ .iter() + .fold(vec![], |mut acc, (i, embeddings_vec)| { + acc.push(( + documents_map.get(i).cloned().unwrap(), + embeddings_vec.clone(), + )); + acc + }); - // Assemble the DocumentEmbeddings - let mut document_embeddings: HashMap<String, DocumentEmbeddings> = HashMap::new(); - embeddings.into_iter().for_each(|(id, embedding)| { - let (document, _) = documents_map.get(&id).expect("Document not found"); - let document_embedding = - document_embeddings - .entry(id.clone()) - .or_insert_with(|| DocumentEmbeddings { - id: id.clone(), - document: document.clone(), - embeddings: vec![], - }); + Ok(embeddings) + } +} - document_embedding.embeddings.push(embedding); - }); +impl<M: EmbeddingModel, D: Embeddable<Kind = SingleEmbedding> + Send + Clone> + EmbeddingsBuilder<M, D> +{ + pub async fn build(self) -> Result<Vec<(D, Embedding)>, EmbeddingError> { + let embeddings = + stream::iter(self.documents.into_iter().map(|(document, embed_target)| { + (document, embed_target.first().cloned().unwrap()) + })) + // Chunk them into N (the emebdding API limit per request) + .chunks(M::MAX_DOCUMENTS) + // Generate the embeddings + .map(|docs| async { + let (documents, embed_targets): (Vec<_>, Vec<_>) = docs.into_iter().unzip(); + Ok::<_, EmbeddingError>( + documents + .into_iter() + .zip(self.model.embed_documents(embed_targets).await?.into_iter()) + .collect::<Vec<_>>(), + ) + }) + .boxed() + // Parallelize the embeddings generation over 10 concurrent requests + .buffer_unordered(max(1, 1024 / M::MAX_DOCUMENTS)) + .try_fold(vec![], |mut acc, embeddings| async move { + acc.extend(embeddings); + Ok(acc) + }) + .await?; - Ok(document_embeddings.into_values().collect()) + Ok(embeddings) } } diff --git a/rig-macros/Cargo.toml b/rig-macros/Cargo.toml deleted file mode 100644 index d1cd2cd9..00000000 --- a/rig-macros/Cargo.toml +++ /dev/null @@ -1,9 +0,0 @@ -[package] -name = "rig-macros" -version = "0.1.0" -edition = "2021" - -[dependencies] -serde_json = "1.0.128" -rig-macros-derive = { path = "./rig-macros-derive" } -serde = {version = "1.0.210", features = ["derive"]} diff --git a/rig-macros/src/lib.rs b/rig-macros/src/lib.rs deleted file mode 100644 index d7151f03..00000000 --- a/rig-macros/src/lib.rs +++ /dev/null @@ -1,65 +0,0 @@ -enum Kind { - Single, - Many, -} - -trait Embeddable { - type Kind; - fn embeddable(&self) -> Vec<String>; -} - -#[derive(serde::Serialize)] -pub struct JobStruct { - job_title: String, - company: String, -} - -mod something { - use super::JobStruct; - - pub fn embeddable(input: &JobStruct) -> Vec<String> { - vec![serde_json::to_string(input).unwrap()] - } -} - -#[cfg(test)] -mod tests { - use crate::JobStruct; - - use super::{Embeddable, Kind}; - use rig_macros_derive::Embedding; - - impl Embeddable for usize { - type Kind = Kind; - - fn embeddable(&self) -> Vec<String> { - vec![self.to_string()] - } - } - - #[derive(Embedding)] - struct SomeStruct { - #[embed] - id: usize, - name: String, - #[embed(embed_with = "super::something")] - job: JobStruct, - } - - #[test] - fn test_macro() { - let job_struct = JobStruct { - job_title: "developer".to_string(), - company: "playgrounds".to_string(), - }; - let some_struct = SomeStruct { - id: 1, - name: "John".to_string(), - job: job_struct, - }; - - println!("{:?}", some_struct.embeddable()); - - assert!(false) - } -} diff --git a/rig-macros/test.rs b/rig-macros/test.rs deleted file mode 100644 index 225a9b13..00000000 --- a/rig-macros/test.rs +++ /dev/null @@ -1,157 +0,0 @@ -/// Builder for creating a collection of embeddings -pub 
struct EmbeddingsBuilder<M: EmbeddingModel, T: Embeddable> { - model: M, - documents: Vec<(T, Vec<String>)>, -} - -trait Embeddable { - type Kind; - // Return list of strings that need to be embedded. - // Instead of Vec<String>, should be Vec<T: Serialize> - fn embeddable(&self) -> Vec<String>; -} - -type EmbeddingVector = Vec<f64>; - -impl<M: EmbeddingModel, T: Embeddable<Kind = Single>> EmbeddingsBuilder<M, T> { - pub fn build(&self) -> Result<Vec<(T, EmbeddingVector)>, EmbeddingError> { - self.documents.iter().map(|(doc, values_to_embed)| { - values_to_embed.iter().map(|value| { - let value_str = serde_json::to_string(value)?; - generate_embedding(value_str) - }) - }) - } -} - -impl<M: EmbeddingModel, T: Embeddable<Kind = Many>> EmbeddingsBuilder<M, T> { - pub fn build(&self) -> Result<Vec<(T, Vec<EmbeddingVector>)>, EmbeddingError> { - self.documents.iter().map(|(doc, values_to_embed)| { - values_to_embed.iter().map(|value| { - let value_str = serde_json::to_string(value)?; - generate_embedding(value_str) - }) - }) - } -} - -impl<M: EmbeddingModel, T: Embeddable> EmbeddingsBuilder<M, T> { - /// Create a new embedding builder with the given embedding model - pub fn new(model: M) -> Self { - Self { - model, - documents: vec![], - } - } - - pub fn add<T: Embeddable>( - mut self, - document: T, - ) -> Self { - let embed_documents: Vec<String> = document.embeddable(); - - self.documents.push(( - document, - embed_documents, - )); - self - } -} - - -// Example -#[derive(Embeddable)] -struct DictionaryEntry { - word: String, - #[embed] - definitions: String, -} - -#[derive(Embeddable)] -struct MetadataEmbedding { - pub id: String, - #[embed(with = serde_json::to_value)] - pub content: CategoryMetadata, - pub created: Option<DateTime<Utc>>, - pub modified: Option<DateTime<Utc>>, - pub dataset_ids: Vec<String>, -} - -#[derive(serde::Serialize)] -struct CategoryMetadata { - pub name: String, - pub description: String, - pub tags: Vec<String>, - pub links: Vec<String>, -} - -// Inside macro: -impl Embeddable for DictionaryEntry { - fn embeddable(&self) -> Vec<String> { - // Find the field tagged with #[embed] and return its value - // If there are no embedding tags, return the entire struct - } -} - -fn main() { - let embeddings: Vec<(DictionaryEntry, Vec<EmbeddingVector>)> = EmbeddingsBuilder::new(model.clone()) - .add(DictionaryEntry::new("blah", vec!["definition of blah"])) - .add(DictionaryEntry::new("foo", vec!["definition of foo"])) - .build()?; - - // In relational vector store like LanceDB, need to flatten result (create row for each item in definitions vector): - // Column: word (string) - // Column: definition (vector) - - // In document vector store like MongoDB, might need to merge the vector results back with their corresponding definition string: - // Field: word (string) - // Field: definitions - // // Field: definition (string) - // // Field: vector - - Ok(()) -} - - - -// Iterations: -// 1 - Multiple fields to embed? -#[derive(Embedding)] -struct DictionaryEntry { - word: String, - #[embed] - definitions: Vec<String>, - #[embed] - synonyms: Vec<String> -} - -// 2 - Embed recursion? 
Ex: -#[derive(Embedding)] -struct DictionaryEntry { - word: String, - #[embed] - definitions: Vec<Definition>, -} -struct Definition { - definition: String, - #[embed] - links: Vec<String> -} - -// { -// word: "blah", -// definitions: [ -// { -// definition: "definition of blah", -// links: ["link1", "link2"] -// }, -// { -// definition: "another definition for blah", -// links: ["link3"] -// } -// ] -// } - -// blah | definition of blah | link1 | embedding for link1 -// blah | definition of blah | link2 | embedding for link2 -// blah | another definition for blah | link3 | embedding for link3 \ No newline at end of file From 79754aa954b55e5d6e5fc466f5f880c9d9535f13 Mon Sep 17 00:00:00 2001 From: Garance <garance.mary@gmail.com> Date: Tue, 8 Oct 2024 14:28:03 -0400 Subject: [PATCH 07/91] feat: replace embedding logic with new embeddable trait and macro --- Cargo.lock | 6 +- rig-core/Cargo.toml | 4 +- rig-core/rig-core-derive/Cargo.toml | 2 +- .../src/{embedding.rs => embeddable.rs} | 40 +++++- rig-core/rig-core-derive/src/lib.rs | 6 +- rig-core/src/embeddings.rs | 135 ++++++++++++------ rig-core/src/providers/cohere.rs | 8 +- rig-core/src/providers/openai.rs | 7 +- rig-lancedb/Cargo.toml | 1 + rig-lancedb/examples/fixtures/lib.rs | 67 ++++++--- .../examples/vector_search_local_ann.rs | 26 ++-- .../examples/vector_search_local_enn.rs | 7 +- rig-lancedb/examples/vector_search_s3_ann.rs | 26 ++-- rig-mongodb/Cargo.toml | 2 + rig-mongodb/examples/vector_search_mongodb.rs | 71 ++++++--- rig-mongodb/src/lib.rs | 102 +++---------- 16 files changed, 287 insertions(+), 223 deletions(-) rename rig-core/rig-core-derive/src/{embedding.rs => embeddable.rs} (83%) diff --git a/Cargo.lock b/Cargo.lock index 19c004d7..311b2862 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3997,7 +3997,7 @@ dependencies = [ "futures", "ordered-float", "reqwest 0.11.27", - "rig-core-derive", + "rig-derive", "schemars", "serde", "serde_json", @@ -4008,7 +4008,7 @@ dependencies = [ ] [[package]] -name = "rig-core-derive" +name = "rig-derive" version = "0.1.0" dependencies = [ "indoc", @@ -4025,6 +4025,7 @@ dependencies = [ "futures", "lancedb", "rig-core", + "rig-derive", "serde", "serde_json", "tokio", @@ -4038,6 +4039,7 @@ dependencies = [ "futures", "mongodb", "rig-core", + "rig-derive", "serde", "serde_json", "tokio", diff --git a/rig-core/Cargo.toml b/rig-core/Cargo.toml index 49efafd0..658faf4b 100644 --- a/rig-core/Cargo.toml +++ b/rig-core/Cargo.toml @@ -23,9 +23,9 @@ futures = "0.3.29" ordered-float = "4.2.0" schemars = "0.8.16" thiserror = "1.0.61" -rig-core-derive = { path = "./rig-core-derive" } +rig-derive = { path = "./rig-core-derive" } [dev-dependencies] anyhow = "1.0.75" tokio = { version = "1.34.0", features = ["full"] } -tracing-subscriber = "0.3.18" +tracing-subscriber = "0.3.18" \ No newline at end of file diff --git a/rig-core/rig-core-derive/Cargo.toml b/rig-core/rig-core-derive/Cargo.toml index 8f8a45d7..008f492b 100644 --- a/rig-core/rig-core-derive/Cargo.toml +++ b/rig-core/rig-core-derive/Cargo.toml @@ -1,5 +1,5 @@ [package] -name = "rig-core-derive" +name = "rig-derive" version = "0.1.0" edition = "2021" diff --git a/rig-core/rig-core-derive/src/embedding.rs b/rig-core/rig-core-derive/src/embeddable.rs similarity index 83% rename from rig-core/rig-core-derive/src/embedding.rs rename to rig-core/rig-core-derive/src/embeddable.rs index 2bfdb0ab..e2c34afe 100644 --- a/rig-core/rig-core-derive/src/embedding.rs +++ b/rig-core/rig-core-derive/src/embeddable.rs @@ -1,8 +1,8 @@ use proc_macro::TokenStream; 
use quote::quote; use syn::{ - meta::ParseNestedMeta, parse_quote, punctuated::Punctuated, spanned::Spanned, Attribute, - DataStruct, ExprPath, Meta, Token, + meta::ParseNestedMeta, parse_quote, parse_str, punctuated::Punctuated, spanned::Spanned, + Attribute, DataStruct, ExprPath, Meta, Token, }; const EMBED: &str = "embed"; @@ -11,7 +11,7 @@ const EMBED_WITH: &str = "embed_with"; pub fn expand_derive_embedding(input: &mut syn::DeriveInput) -> TokenStream { let name = &input.ident; - let func_calls = + let (func_calls, embed_kind) = match &input.data { syn::Data::Struct(data_struct) => { // Handles fields tagged with #[embed] @@ -21,6 +21,7 @@ pub fn expand_derive_embedding(input: &mut syn::DeriveInput) -> TokenStream { add_struct_bounds(&mut input.generics, &field.ty); let field_name = field.ident; + quote! { self.#field_name.embeddable() } @@ -38,9 +39,9 @@ pub fn expand_derive_embedding(input: &mut syn::DeriveInput) -> TokenStream { }, )); - function_calls + (function_calls, data_struct.embed_kind().unwrap()) } - _ => vec![], + _ => panic!("Embeddable can only be derived for structs"), }; // Import the paths to the custom functions. @@ -60,10 +61,13 @@ pub fn expand_derive_embedding(input: &mut syn::DeriveInput) -> TokenStream { let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl(); let gen = quote! { + use rig::embeddings::Embeddable; + use rig::embeddings::#embed_kind; + #(#custom_func_paths);* impl #impl_generics Embeddable for #name #ty_generics #where_clause { - type Kind = String; + type Kind = #embed_kind; fn embeddable(&self) -> Vec<String> { vec![ @@ -86,6 +90,13 @@ fn add_struct_bounds(generics: &mut syn::Generics, field_type: &syn::Type) { }); } +fn embed_kind(field: &syn::Field) -> Result<syn::Expr, syn::Error> { + match &field.ty { + syn::Type::Array(_) => parse_str("ManyEmbedding"), + _ => parse_str("SingleEmbedding"), + } +} + trait AttributeParser { /// Finds and returns fields with simple #[embed] attribute tags only. fn basic_embed_fields(&self) -> impl Iterator<Item = syn::Field>; @@ -94,6 +105,23 @@ trait AttributeParser { fn custom_embed_fields( &self, ) -> Result<impl Iterator<Item = (syn::Field, syn::ExprPath)>, syn::Error>; + + /// If the total number of fields tagged with #[embed] or #[embed(embed_with = "...")] is 1, + /// returns the kind of embedding that field should be. + /// If the total number of fields tagged with #[embed] or #[embed(embed_with = "...")] is greater than 1, + /// return ManyEmbedding. 
+ fn embed_kind(&self) -> Result<syn::Expr, syn::Error> { + let fields = self + .basic_embed_fields() + .chain(self.custom_embed_fields().unwrap().map(|(f, _)| f)) + .collect::<Vec<_>>(); + + if fields.len() == 1 { + fields.iter().map(embed_kind).next().unwrap() + } else { + parse_str("ManyEmbedding") + } + } } impl AttributeParser for DataStruct { diff --git a/rig-core/rig-core-derive/src/lib.rs b/rig-core/rig-core-derive/src/lib.rs index eb7abaa9..19f6a845 100644 --- a/rig-core/rig-core-derive/src/lib.rs +++ b/rig-core/rig-core-derive/src/lib.rs @@ -2,14 +2,14 @@ extern crate proc_macro; use proc_macro::TokenStream; use syn::{parse_macro_input, DeriveInput}; -mod embedding; +mod embeddable; // https://doc.rust-lang.org/book/ch19-06-macros.html#how-to-write-a-custom-derive-macro // https://doc.rust-lang.org/reference/procedural-macros.html -#[proc_macro_derive(Embedding, attributes(embed))] +#[proc_macro_derive(Embed, attributes(embed))] pub fn derive_embedding_trait(item: TokenStream) -> TokenStream { let mut input = parse_macro_input!(item as DeriveInput); - embedding::expand_derive_embedding(&mut input) + embeddable::expand_derive_embedding(&mut input) } diff --git a/rig-core/src/embeddings.rs b/rig-core/src/embeddings.rs index c1d281db..b044b406 100644 --- a/rig-core/src/embeddings.rs +++ b/rig-core/src/embeddings.rs @@ -38,13 +38,11 @@ //! // ... //! ``` -use std::{cmp::max, collections::HashMap}; +use std::{cmp::max, collections::HashMap, marker::PhantomData}; use futures::{stream, StreamExt, TryStreamExt}; use serde::{Deserialize, Serialize}; -use crate::tool::{ToolEmbedding, ToolSet, ToolType}; - #[derive(Debug, thiserror::Error)] pub enum EmbeddingError { /// Http error (e.g.: connection error, timeout, etc.) @@ -102,7 +100,7 @@ pub trait EmbeddingModel: Clone + Sync + Send { } /// Struct that holds a single document and its embedding. 
-#[derive(Clone, Default, Deserialize, Serialize)] +#[derive(Clone, Default, Deserialize, Serialize, Debug)] pub struct Embedding { /// The document that was embedded pub document: String, @@ -149,25 +147,29 @@ pub struct DocumentEmbeddings { pub document: serde_json::Value, pub embeddings: Vec<Embedding>, } - -struct SingleEmbedding; -struct ManyEmbedding; +pub trait EmbeddingKind {} +pub struct SingleEmbedding; +impl EmbeddingKind for SingleEmbedding {} +pub struct ManyEmbedding; +impl EmbeddingKind for ManyEmbedding {} pub trait Embeddable { - type Kind; + type Kind: EmbeddingKind; fn embeddable(&self) -> Vec<String>; } /// Builder for creating a collection of embeddings -pub struct EmbeddingsBuilder<M: EmbeddingModel, D: Embeddable + Send> { +pub struct EmbeddingsBuilder<M: EmbeddingModel, D: Embeddable, K: EmbeddingKind> { + kind: PhantomData<K>, model: M, documents: Vec<(D, Vec<String>)>, } -impl<M: EmbeddingModel, D: Embeddable + Send> EmbeddingsBuilder<M, D> { +impl<M: EmbeddingModel, D: Embeddable<Kind = K>, K: EmbeddingKind> EmbeddingsBuilder<M, D, K> { /// Create a new embedding builder with the given embedding model pub fn new(model: M) -> Self { Self { + kind: PhantomData, model, documents: vec![], } @@ -179,12 +181,22 @@ impl<M: EmbeddingModel, D: Embeddable + Send> EmbeddingsBuilder<M, D> { self.documents.push((document, embed_targets)); self } + + pub fn documents(mut self, documents: Vec<D>) -> EmbeddingsBuilder<M, D, D::Kind> { + documents.into_iter().for_each(|doc| { + let embed_targets = doc.embeddable(); + + self.documents.push((doc, embed_targets)); + }); + + self + } } -impl<M: EmbeddingModel, D: Embeddable<Kind = ManyEmbedding> + Send + Clone> - EmbeddingsBuilder<M, D> +impl<M: EmbeddingModel, D: Embeddable + Send + Sync + Clone> + EmbeddingsBuilder<M, D, ManyEmbedding> { - pub async fn build(self) -> Result<Vec<(D, Vec<Embedding>)>, EmbeddingError> { + pub async fn build(&self) -> Result<Vec<(D, Vec<Embedding>)>, EmbeddingError> { let documents_map = self .documents .clone() @@ -193,7 +205,7 @@ impl<M: EmbeddingModel, D: Embeddable<Kind = ManyEmbedding> + Send + Clone> .map(|(id, (document, _))| (id, document)) .collect::<HashMap<_, _>>(); - let embeddings = stream::iter(self.documents.into_iter().enumerate()) + let embeddings = stream::iter(self.documents.clone().into_iter().enumerate()) // Flatten the documents .flat_map(|(i, (_, embed_targets))| { stream::iter( @@ -219,13 +231,16 @@ impl<M: EmbeddingModel, D: Embeddable<Kind = ManyEmbedding> + Send + Clone> .buffer_unordered(max(1, 1024 / M::MAX_DOCUMENTS)) // .try_collect::<Vec<_>>() // .await; - .try_fold(HashMap::new(), |mut acc, mut embeddings| async move { - embeddings.into_iter().for_each(|(i, embedding)| { - acc.entry(i).or_insert(vec![]).push(embedding); - }); + .try_fold( + HashMap::new(), + |mut acc: HashMap<_, Vec<_>>, embeddings| async move { + embeddings.into_iter().for_each(|(i, embedding)| { + acc.entry(i).or_default().push(embedding); + }); - Ok(acc) - }) + Ok(acc) + }, + ) .await? 
.iter() .fold(vec![], |mut acc, (i, embeddings_vec)| { @@ -240,35 +255,61 @@ impl<M: EmbeddingModel, D: Embeddable<Kind = ManyEmbedding> + Send + Clone> } } -impl<M: EmbeddingModel, D: Embeddable<Kind = SingleEmbedding> + Send + Clone> - EmbeddingsBuilder<M, D> +impl<M: EmbeddingModel, D: Embeddable + Send + Sync + Clone> + EmbeddingsBuilder<M, D, SingleEmbedding> { - pub async fn build(self) -> Result<Vec<(D, Embedding)>, EmbeddingError> { - let embeddings = - stream::iter(self.documents.into_iter().map(|(document, embed_target)| { - (document, embed_target.first().cloned().unwrap()) - })) - // Chunk them into N (the emebdding API limit per request) - .chunks(M::MAX_DOCUMENTS) - // Generate the embeddings - .map(|docs| async { - let (documents, embed_targets): (Vec<_>, Vec<_>) = docs.into_iter().unzip(); - Ok::<_, EmbeddingError>( - documents - .into_iter() - .zip(self.model.embed_documents(embed_targets).await?.into_iter()) - .collect::<Vec<_>>(), - ) - }) - .boxed() - // Parallelize the embeddings generation over 10 concurrent requests - .buffer_unordered(max(1, 1024 / M::MAX_DOCUMENTS)) - .try_fold(vec![], |mut acc, embeddings| async move { - acc.extend(embeddings); - Ok(acc) - }) - .await?; + pub async fn build(&self) -> Result<Vec<(D, Embedding)>, EmbeddingError> { + let embeddings = stream::iter( + self.documents + .clone() + .into_iter() + .map(|(document, embed_target)| (document, embed_target.first().cloned().unwrap())), + ) + // Chunk them into N (the emebdding API limit per request) + .chunks(M::MAX_DOCUMENTS) + // Generate the embeddings + .map(|docs| async { + let (documents, embed_targets): (Vec<_>, Vec<_>) = docs.into_iter().unzip(); + Ok::<_, EmbeddingError>( + documents + .into_iter() + .zip(self.model.embed_documents(embed_targets).await?.into_iter()) + .collect::<Vec<_>>(), + ) + }) + .boxed() + // Parallelize the embeddings generation over 10 concurrent requests + .buffer_unordered(max(1, 1024 / M::MAX_DOCUMENTS)) + .try_fold(vec![], |mut acc, embeddings| async move { + acc.extend(embeddings); + Ok(acc) + }) + .await?; Ok(embeddings) } } + +impl Embeddable for String { + type Kind = SingleEmbedding; + + fn embeddable(&self) -> Vec<String> { + vec![self.clone()] + } +} + +impl Embeddable for i32 { + type Kind = SingleEmbedding; + + fn embeddable(&self) -> Vec<String> { + vec![self.to_string()] + } +} + +impl<T: Embeddable> Embeddable for Vec<T> { + type Kind = ManyEmbedding; + + fn embeddable(&self) -> Vec<String> { + self.iter().flat_map(|i| i.embeddable()).collect() + } +} diff --git a/rig-core/src/providers/cohere.rs b/rig-core/src/providers/cohere.rs index ae874b21..c93709c7 100644 --- a/rig-core/src/providers/cohere.rs +++ b/rig-core/src/providers/cohere.rs @@ -13,7 +13,7 @@ use std::collections::HashMap; use crate::{ agent::AgentBuilder, completion::{self, CompletionError}, - embeddings::{self, EmbeddingError, EmbeddingsBuilder}, + embeddings::{self, Embeddable, EmbeddingError, EmbeddingsBuilder}, extractor::ExtractorBuilder, json_utils, }; @@ -85,7 +85,11 @@ impl Client { EmbeddingModel::new(self.clone(), model, input_type, ndims) } - pub fn embeddings(&self, model: &str, input_type: &str) -> EmbeddingsBuilder<EmbeddingModel> { + pub fn embeddings<D: Embeddable>( + &self, + model: &str, + input_type: &str, + ) -> EmbeddingsBuilder<EmbeddingModel, D, D::Kind> { EmbeddingsBuilder::new(self.embedding_model(model, input_type)) } diff --git a/rig-core/src/providers/openai.rs b/rig-core/src/providers/openai.rs index 8262e6ce..6bd1711f 100644 --- 
a/rig-core/src/providers/openai.rs +++ b/rig-core/src/providers/openai.rs @@ -11,7 +11,7 @@ use crate::{ agent::AgentBuilder, completion::{self, CompletionError, CompletionRequest}, - embeddings::{self, EmbeddingError}, + embeddings::{self, Embeddable, EmbeddingError}, extractor::ExtractorBuilder, json_utils, }; @@ -121,7 +121,10 @@ impl Client { /// .await /// .expect("Failed to embed documents"); /// ``` - pub fn embeddings(&self, model: &str) -> embeddings::EmbeddingsBuilder<EmbeddingModel> { + pub fn embeddings<D: Embeddable>( + &self, + model: &str, + ) -> embeddings::EmbeddingsBuilder<EmbeddingModel, D, D::Kind> { embeddings::EmbeddingsBuilder::new(self.embedding_model(model)) } diff --git a/rig-lancedb/Cargo.toml b/rig-lancedb/Cargo.toml index 031df2d3..6ee41d8d 100644 --- a/rig-lancedb/Cargo.toml +++ b/rig-lancedb/Cargo.toml @@ -6,6 +6,7 @@ edition = "2021" [dependencies] lancedb = "0.10.0" rig-core = { path = "../rig-core", version = "0.2.1" } +rig-derive = { path = "../rig-core/rig-core-derive" } arrow-array = "52.2.0" serde_json = "1.0.128" serde = "1.0.210" diff --git a/rig-lancedb/examples/fixtures/lib.rs b/rig-lancedb/examples/fixtures/lib.rs index 9a91432e..322c0a6a 100644 --- a/rig-lancedb/examples/fixtures/lib.rs +++ b/rig-lancedb/examples/fixtures/lib.rs @@ -2,13 +2,52 @@ use std::sync::Arc; use arrow_array::{types::Float64Type, ArrayRef, FixedSizeListArray, RecordBatch, StringArray}; use lancedb::arrow::arrow_schema::{DataType, Field, Fields, Schema}; -use rig::embeddings::DocumentEmbeddings; +use rig::embeddings::Embedding; +use rig_derive::Embed; +use serde::Deserialize; + +#[derive(Embed, Clone)] +pub struct FakeDefinition { + id: String, + #[embed] + definition: String, +} + +#[derive(Deserialize, Debug)] +pub struct VectorSearchResult { + pub id: String, + pub definition: String, +} + +pub fn fake_definitions() -> Vec<FakeDefinition> { + vec![ + FakeDefinition { + id: "doc0".to_string(), + definition: "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets".to_string() + }, + FakeDefinition { + id: "doc1".to_string(), + definition: "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string() + }, + FakeDefinition { + id: "doc2".to_string(), + definition: "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.".to_string() + } + ] +} + +pub fn fake_definition(id: String) -> FakeDefinition { + FakeDefinition { + id, + definition: "Definition of *flumbuzzle (verb)*: to bewilder or confuse someone completely, often by using nonsensical or overly complex explanations or instructions.".to_string() + } +} // Schema of table in LanceDB. pub fn schema(dims: usize) -> Schema { Schema::new(Fields::from(vec![ Field::new("id", DataType::Utf8, false), - Field::new("content", DataType::Utf8, false), + Field::new("definition", DataType::Utf8, false), Field::new( "embedding", DataType::FixedSizeList( @@ -22,46 +61,34 @@ pub fn schema(dims: usize) -> Schema { // Convert DocumentEmbeddings objects to a RecordBatch. pub fn as_record_batch( - records: Vec<DocumentEmbeddings>, + records: Vec<(FakeDefinition, Embedding)>, dims: usize, ) -> Result<RecordBatch, lancedb::arrow::arrow_schema::ArrowError> { let id = StringArray::from_iter_values( records .iter() - .flat_map(|record| (0..record.embeddings.len()).map(|i| format!("{}-{i}", record.id))) + .map(|(FakeDefinition { id, .. 
}, _)| id) .collect::<Vec<_>>(), ); - let content = StringArray::from_iter_values( + let definition = StringArray::from_iter_values( records .iter() - .flat_map(|record| { - record - .embeddings - .iter() - .map(|embedding| embedding.document.clone()) - }) + .map(|(_, Embedding { document, .. })| document) .collect::<Vec<_>>(), ); let embedding = FixedSizeListArray::from_iter_primitive::<Float64Type, _, _>( records .into_iter() - .flat_map(|record| { - record - .embeddings - .into_iter() - .map(|embedding| embedding.vec.into_iter().map(Some).collect::<Vec<_>>()) - .map(Some) - .collect::<Vec<_>>() - }) + .map(|(_, Embedding { vec, .. })| Some(vec.into_iter().map(Some).collect::<Vec<_>>())) .collect::<Vec<_>>(), dims as i32, ); RecordBatch::try_from_iter(vec![ ("id", Arc::new(id) as ArrayRef), - ("content", Arc::new(content) as ArrayRef), + ("definition", Arc::new(definition) as ArrayRef), ("embedding", Arc::new(embedding) as ArrayRef), ]) } diff --git a/rig-lancedb/examples/vector_search_local_ann.rs b/rig-lancedb/examples/vector_search_local_ann.rs index 3ecd6b23..3ac925bb 100644 --- a/rig-lancedb/examples/vector_search_local_ann.rs +++ b/rig-lancedb/examples/vector_search_local_ann.rs @@ -1,7 +1,7 @@ use std::{env, sync::Arc}; use arrow_array::RecordBatchIterator; -use fixture::{as_record_batch, schema}; +use fixture::{as_record_batch, fake_definition, fake_definitions, schema, VectorSearchResult}; use lancedb::index::vector::IvfPqIndexBuilder; use rig::vector_store::VectorStoreIndex; use rig::{ @@ -9,17 +9,10 @@ use rig::{ providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, }; use rig_lancedb::{LanceDbVectorStore, SearchParams}; -use serde::Deserialize; #[path = "./fixtures/lib.rs"] mod fixture; -#[derive(Deserialize, Debug)] -pub struct VectorSearchResult { - pub id: String, - pub content: String, -} - #[tokio::main] async fn main() -> Result<(), anyhow::Error> { // Initialize OpenAI client. Use this to generate embeddings (and generate test data for RAG demo). @@ -32,18 +25,15 @@ async fn main() -> Result<(), anyhow::Error> { // Initialize LanceDB locally. let db = lancedb::connect("data/lancedb-store").execute().await?; - // Set up test data for RAG demo - let definition = "Definition of *flumbuzzle (verb)*: to bewilder or confuse someone completely, often by using nonsensical or overly complex explanations or instructions.".to_string(); - - // Note: need at least 256 rows in order to create an index so copy the definition 256 times for testing purposes. - let definitions = vec![definition; 256]; - // Generate embeddings for the test data. let embeddings = EmbeddingsBuilder::new(model.clone()) - .simple_document("doc0", "Definition of *flumbrel (noun)*: a small, seemingly insignificant item that you constantly lose or misplace, such as a pen, hair tie, or remote control.") - .simple_document("doc1", "Definition of *zindle (verb)*: to pretend to be working on something important while actually doing something completely unrelated or unproductive") - .simple_document("doc2", "Definition of *glimber (adjective)*: describing a state of excitement mixed with nervousness, often experienced before an important event or decision.") - .simple_documents(definitions.clone().into_iter().enumerate().map(|(i, def)| (format!("doc{}", i+3), def)).collect()) + .documents(fake_definitions()) + // Note: need at least 256 rows in order to create an index so copy the definition 256 times for testing purposes. 
+ .documents( + (0..256) + .map(|i| fake_definition(format!("doc{}", i))) + .collect(), + ) .build() .await?; diff --git a/rig-lancedb/examples/vector_search_local_enn.rs b/rig-lancedb/examples/vector_search_local_enn.rs index 5932dcd0..1bf69481 100644 --- a/rig-lancedb/examples/vector_search_local_enn.rs +++ b/rig-lancedb/examples/vector_search_local_enn.rs @@ -1,7 +1,7 @@ use std::{env, sync::Arc}; use arrow_array::RecordBatchIterator; -use fixture::{as_record_batch, schema}; +use fixture::{as_record_batch, fake_definitions, schema}; use rig::{ embeddings::{EmbeddingModel, EmbeddingsBuilder}, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, @@ -21,10 +21,9 @@ async fn main() -> Result<(), anyhow::Error> { // Select the embedding model and generate our embeddings let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); + // Generate embeddings for the test data. let embeddings = EmbeddingsBuilder::new(model.clone()) - .simple_document("doc0", "Definition of *flumbrel (noun)*: a small, seemingly insignificant item that you constantly lose or misplace, such as a pen, hair tie, or remote control.") - .simple_document("doc1", "Definition of *zindle (verb)*: to pretend to be working on something important while actually doing something completely unrelated or unproductive") - .simple_document("doc2", "Definition of *glimber (adjective)*: describing a state of excitement mixed with nervousness, often experienced before an important event or decision.") + .documents(fake_definitions()) .build() .await?; diff --git a/rig-lancedb/examples/vector_search_s3_ann.rs b/rig-lancedb/examples/vector_search_s3_ann.rs index 70f0c8c5..358a1755 100644 --- a/rig-lancedb/examples/vector_search_s3_ann.rs +++ b/rig-lancedb/examples/vector_search_s3_ann.rs @@ -1,7 +1,7 @@ use std::{env, sync::Arc}; use arrow_array::RecordBatchIterator; -use fixture::{as_record_batch, schema}; +use fixture::{as_record_batch, fake_definition, fake_definitions, schema, VectorSearchResult}; use lancedb::{index::vector::IvfPqIndexBuilder, DistanceType}; use rig::{ embeddings::{EmbeddingModel, EmbeddingsBuilder}, @@ -9,17 +9,10 @@ use rig::{ vector_store::VectorStoreIndex, }; use rig_lancedb::{LanceDbVectorStore, SearchParams}; -use serde::Deserialize; #[path = "./fixtures/lib.rs"] mod fixture; -#[derive(Deserialize, Debug)] -pub struct VectorSearchResult { - pub id: String, - pub content: String, -} - // Note: see docs to deploy LanceDB on other cloud providers such as google and azure. // https://lancedb.github.io/lancedb/guides/storage/ #[tokio::main] @@ -38,18 +31,15 @@ async fn main() -> Result<(), anyhow::Error> { .execute() .await?; - // Set up test data for RAG demo - let definition = "Definition of *flumbuzzle (verb)*: to bewilder or confuse someone completely, often by using nonsensical or overly complex explanations or instructions.".to_string(); - - // Note: need at least 256 rows in order to create an index so copy the definition 256 times for testing purposes. - let definitions = vec![definition; 256]; - // Generate embeddings for the test data. 
let embeddings = EmbeddingsBuilder::new(model.clone()) - .simple_document("doc0", "Definition of *flumbrel (noun)*: a small, seemingly insignificant item that you constantly lose or misplace, such as a pen, hair tie, or remote control.") - .simple_document("doc1", "Definition of *zindle (verb)*: to pretend to be working on something important while actually doing something completely unrelated or unproductive") - .simple_document("doc2", "Definition of *glimber (adjective)*: describing a state of excitement mixed with nervousness, often experienced before an important event or decision.") - .simple_documents(definitions.clone().into_iter().enumerate().map(|(i, def)| (format!("doc{}", i+3), def)).collect()) + .documents(fake_definitions()) + // Note: need at least 256 rows in order to create an index so copy the definition 256 times for testing purposes. + .documents( + (0..256) + .map(|i| fake_definition(format!("doc{}", i))) + .collect(), + ) .build() .await?; diff --git a/rig-mongodb/Cargo.toml b/rig-mongodb/Cargo.toml index 6f313838..78b48892 100644 --- a/rig-mongodb/Cargo.toml +++ b/rig-mongodb/Cargo.toml @@ -13,6 +13,8 @@ repository = "https://github.com/0xPlaygrounds/rig" futures = "0.3.30" mongodb = "2.8.2" rig-core = { path = "../rig-core", version = "0.2.1" } +rig-derive = { path = "../rig-core/rig-core-derive" } + serde = { version = "1.0.203", features = ["derive"] } serde_json = "1.0.117" tracing = "0.1.40" diff --git a/rig-mongodb/examples/vector_search_mongodb.rs b/rig-mongodb/examples/vector_search_mongodb.rs index 3d062de3..4ddf1f99 100644 --- a/rig-mongodb/examples/vector_search_mongodb.rs +++ b/rig-mongodb/examples/vector_search_mongodb.rs @@ -1,13 +1,28 @@ -use mongodb::{options::ClientOptions, Client as MongoClient, Collection}; +use mongodb::{bson::doc, options::ClientOptions, Client as MongoClient, Collection}; +use rig::{embeddings::Embedding, providers::openai::TEXT_EMBEDDING_ADA_002}; +use rig_derive::Embed; +use serde::{Deserialize, Serialize}; use std::env; use rig::{ - embeddings::{DocumentEmbeddings, EmbeddingsBuilder}, - providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, - vector_store::{VectorStore, VectorStoreIndex}, + embeddings::EmbeddingsBuilder, providers::openai::Client, vector_store::VectorStoreIndex, }; use rig_mongodb::{MongoDbVectorStore, SearchParams}; +#[derive(Embed, Clone)] +struct FakeDefinition { + id: String, + #[embed] + definition: String, +} + +#[derive(Serialize, Debug, Deserialize)] +struct Document { + #[serde(rename = "_id")] + id: String, + definition: Embedding, +} + #[tokio::main] async fn main() -> Result<(), anyhow::Error> { // Initialize OpenAI client @@ -25,38 +40,62 @@ async fn main() -> Result<(), anyhow::Error> { MongoClient::with_options(options).expect("MongoDB client options should be valid"); // Initialize MongoDB vector store - let collection: Collection<DocumentEmbeddings> = mongodb_client + let collection: Collection<Document> = mongodb_client .database("knowledgebase") .collection("context"); - let mut vector_store = MongoDbVectorStore::new(collection); - // Select the embedding model and generate our embeddings let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); + let fake_definitions = vec![ + FakeDefinition { + id: "doc0".to_string(), + definition: "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets".to_string() + }, + FakeDefinition { + id: "doc1".to_string(), + definition: "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the 
inhabitants of planet Jiro to farm the land.".to_string() + }, + FakeDefinition { + id: "doc2".to_string(), + definition: "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.".to_string() + } + ]; + let embeddings = EmbeddingsBuilder::new(model.clone()) - .simple_document("doc0", "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets") - .simple_document("doc1", "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.") - .simple_document("doc2", "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.") + .documents(fake_definitions) .build() .await?; - // Add embeddings to vector store - match vector_store.add_documents(embeddings).await { + let mongo_documents = embeddings + .iter() + .map(|(FakeDefinition { id, .. }, embedding)| Document { + id: id.clone(), + definition: embedding.clone(), + }) + .collect::<Vec<_>>(); + + match collection.insert_many(mongo_documents, None).await { Ok(_) => println!("Documents added successfully"), Err(e) => println!("Error adding documents: {:?}", e), - } + }; + + let vector_store = MongoDbVectorStore::new(collection); // Create a vector index on our vector store // IMPORTANT: Reuse the same model that was used to generate the embeddings - let index = vector_store.index(model, "vector_index", SearchParams::default()); + let index = vector_store.index( + model, + "definitions_vector_index", + SearchParams::new("definition.vec"), + ); // Query the index let results = index - .top_n::<DocumentEmbeddings>("What is a linglingdong?", 1) + .top_n::<Document>("What is a linglingdong?", 1) .await? .into_iter() - .map(|(score, id, doc)| (score, id, doc.document)) + .map(|(score, id, doc)| (score, id, doc.definition.document)) .collect::<Vec<_>>(); println!("Results: {:?}", results); diff --git a/rig-mongodb/src/lib.rs b/rig-mongodb/src/lib.rs index 43869989..5ce33105 100644 --- a/rig-mongodb/src/lib.rs +++ b/rig-mongodb/src/lib.rs @@ -2,84 +2,23 @@ use futures::StreamExt; use mongodb::bson::{self, doc}; use rig::{ - embeddings::{DocumentEmbeddings, Embedding, EmbeddingModel}, - vector_store::{VectorStore, VectorStoreError, VectorStoreIndex}, + embeddings::{Embedding, EmbeddingModel}, + vector_store::{VectorStoreError, VectorStoreIndex}, }; use serde::Deserialize; /// A MongoDB vector store. -pub struct MongoDbVectorStore { - collection: mongodb::Collection<DocumentEmbeddings>, +pub struct MongoDbVectorStore<C> { + collection: mongodb::Collection<C>, } fn mongodb_to_rig_error(e: mongodb::error::Error) -> VectorStoreError { VectorStoreError::DatastoreError(Box::new(e)) } -impl VectorStore for MongoDbVectorStore { - type Q = mongodb::bson::Document; - - async fn add_documents( - &mut self, - documents: Vec<DocumentEmbeddings>, - ) -> Result<(), VectorStoreError> { - self.collection - .insert_many(documents, None) - .await - .map_err(mongodb_to_rig_error)?; - Ok(()) - } - - async fn get_document_embeddings( - &self, - id: &str, - ) -> Result<Option<DocumentEmbeddings>, VectorStoreError> { - self.collection - .find_one(doc! { "_id": id }, None) - .await - .map_err(mongodb_to_rig_error) - } - - async fn get_document<T: for<'a> serde::Deserialize<'a>>( - &self, - id: &str, - ) -> Result<Option<T>, VectorStoreError> { - Ok(self - .collection - .clone_with_type::<String>() - .aggregate( - [ - doc! {"$match": { "_id": id}}, - doc! 
{"$project": { "document": 1 }}, - doc! {"$replaceRoot": { "newRoot": "$document" }}, - ], - None, - ) - .await - .map_err(mongodb_to_rig_error)? - .with_type::<String>() - .next() - .await - .transpose() - .map_err(mongodb_to_rig_error)? - .map(|doc| serde_json::from_str(&doc)) - .transpose()?) - } - - async fn get_document_by_query( - &self, - query: Self::Q, - ) -> Result<Option<DocumentEmbeddings>, VectorStoreError> { - self.collection - .find_one(query, None) - .await - .map_err(mongodb_to_rig_error) - } -} - -impl MongoDbVectorStore { +impl<C> MongoDbVectorStore<C> { /// Create a new `MongoDbVectorStore` from a MongoDB collection. - pub fn new(collection: mongodb::Collection<DocumentEmbeddings>) -> Self { + pub fn new(collection: mongodb::Collection<C>) -> Self { Self { collection } } @@ -92,20 +31,20 @@ impl MongoDbVectorStore { model: M, index_name: &str, search_params: SearchParams, - ) -> MongoDbVectorIndex<M> { + ) -> MongoDbVectorIndex<M, C> { MongoDbVectorIndex::new(self.collection.clone(), model, index_name, search_params) } } /// A vector index for a MongoDB collection. -pub struct MongoDbVectorIndex<M: EmbeddingModel> { - collection: mongodb::Collection<DocumentEmbeddings>, +pub struct MongoDbVectorIndex<M: EmbeddingModel, C> { + collection: mongodb::Collection<C>, model: M, index_name: String, search_params: SearchParams, } -impl<M: EmbeddingModel> MongoDbVectorIndex<M> { +impl<M: EmbeddingModel, C> MongoDbVectorIndex<M, C> { /// Vector search stage of aggregation pipeline of mongoDB collection. /// To be used by implementations of top_n and top_n_ids methods on VectorStoreIndex trait for MongoDbVectorIndex. fn pipeline_search_stage(&self, prompt_embedding: &Embedding, n: usize) -> bson::Document { @@ -113,12 +52,13 @@ impl<M: EmbeddingModel> MongoDbVectorIndex<M> { filter, exact, num_candidates, + path, } = &self.search_params; doc! { "$vectorSearch": { "index": &self.index_name, - "path": "embeddings.vec", + "path": path, "queryVector": &prompt_embedding.vec, "numCandidates": num_candidates.unwrap_or((n * 10) as u32), "limit": n as u32, @@ -139,9 +79,9 @@ impl<M: EmbeddingModel> MongoDbVectorIndex<M> { } } -impl<M: EmbeddingModel> MongoDbVectorIndex<M> { +impl<M: EmbeddingModel, C> MongoDbVectorIndex<M, C> { pub fn new( - collection: mongodb::Collection<DocumentEmbeddings>, + collection: mongodb::Collection<C>, model: M, index_name: &str, search_params: SearchParams, @@ -159,17 +99,19 @@ impl<M: EmbeddingModel> MongoDbVectorIndex<M> { /// on each of the fields pub struct SearchParams { filter: mongodb::bson::Document, + path: String, exact: Option<bool>, num_candidates: Option<u32>, } impl SearchParams { /// Initializes a new `SearchParams` with default values. - pub fn new() -> Self { + pub fn new(path: &str) -> Self { Self { filter: doc! 
{}, exact: None, num_candidates: None, + path: path.to_string(), } } @@ -199,13 +141,9 @@ impl SearchParams { } } -impl Default for SearchParams { - fn default() -> Self { - Self::new() - } -} - -impl<M: EmbeddingModel + std::marker::Sync + Send> VectorStoreIndex for MongoDbVectorIndex<M> { +impl<M: EmbeddingModel + std::marker::Sync + Send, C: std::marker::Sync + Send> VectorStoreIndex + for MongoDbVectorIndex<M, C> +{ async fn top_n<T: for<'a> Deserialize<'a> + std::marker::Send>( &self, query: &str, From 3438fab667cfd787eb16e70d2266e2930e2c3fa0 Mon Sep 17 00:00:00 2001 From: Garance <garance.mary@gmail.com> Date: Tue, 8 Oct 2024 17:47:07 -0400 Subject: [PATCH 08/91] refactor: refactor rag examples, delete document embedding struct --- rig-core/rig-core-derive/src/embeddable.rs | 9 ++++++- rig-core/src/embeddings.rs | 16 ------------ rig-lancedb/examples/fixtures/lib.rs | 18 +++++-------- .../examples/vector_search_local_ann.rs | 4 +-- rig-lancedb/examples/vector_search_s3_ann.rs | 4 +-- rig-mongodb/examples/vector_search_mongodb.rs | 26 ++++++++++--------- 6 files changed, 32 insertions(+), 45 deletions(-) diff --git a/rig-core/rig-core-derive/src/embeddable.rs b/rig-core/rig-core-derive/src/embeddable.rs index e2c34afe..ec52d952 100644 --- a/rig-core/rig-core-derive/src/embeddable.rs +++ b/rig-core/rig-core-derive/src/embeddable.rs @@ -7,6 +7,7 @@ use syn::{ const EMBED: &str = "embed"; const EMBED_WITH: &str = "embed_with"; +const VEC_TYPE: &str = "Vec"; pub fn expand_derive_embedding(input: &mut syn::DeriveInput) -> TokenStream { let name = &input.ident; @@ -92,7 +93,13 @@ fn add_struct_bounds(generics: &mut syn::Generics, field_type: &syn::Type) { fn embed_kind(field: &syn::Field) -> Result<syn::Expr, syn::Error> { match &field.ty { - syn::Type::Array(_) => parse_str("ManyEmbedding"), + syn::Type::Path(path) => { + if path.path.segments.first().unwrap().ident == VEC_TYPE { + parse_str("ManyEmbedding") + } else { + parse_str("SingleEmbedding") + } + }, _ => parse_str("SingleEmbedding"), } } diff --git a/rig-core/src/embeddings.rs b/rig-core/src/embeddings.rs index b044b406..458c6eb3 100644 --- a/rig-core/src/embeddings.rs +++ b/rig-core/src/embeddings.rs @@ -131,22 +131,6 @@ impl Embedding { } } -/// Struct that holds a document and its embeddings. -/// -/// The struct is designed to model any kind of documents that can be serialized to JSON -/// (including a simple string). -/// -/// Moreover, it can hold multiple embeddings for the same document, thus allowing a -/// large document to be retrieved from a query that matches multiple smaller and -/// distinct text documents. For example, if the document is a textbook, a summary of -/// each chapter could serve as the book's embeddings. 
-#[derive(Clone, Eq, PartialEq, Serialize, Deserialize)] -pub struct DocumentEmbeddings { - #[serde(rename = "_id")] - pub id: String, - pub document: serde_json::Value, - pub embeddings: Vec<Embedding>, -} pub trait EmbeddingKind {} pub struct SingleEmbedding; impl EmbeddingKind for SingleEmbedding {} diff --git a/rig-lancedb/examples/fixtures/lib.rs b/rig-lancedb/examples/fixtures/lib.rs index 322c0a6a..e8ebead9 100644 --- a/rig-lancedb/examples/fixtures/lib.rs +++ b/rig-lancedb/examples/fixtures/lib.rs @@ -6,28 +6,22 @@ use rig::embeddings::Embedding; use rig_derive::Embed; use serde::Deserialize; -#[derive(Embed, Clone)] +#[derive(Embed, Clone, Deserialize, Debug)] pub struct FakeDefinition { id: String, #[embed] definition: String, } -#[derive(Deserialize, Debug)] -pub struct VectorSearchResult { - pub id: String, - pub definition: String, -} - pub fn fake_definitions() -> Vec<FakeDefinition> { vec![ FakeDefinition { id: "doc0".to_string(), - definition: "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets".to_string() + definition: "Definition of *flumbrel (noun)*: a small, seemingly insignificant item that you constantly lose or misplace, such as a pen, hair tie, or remote control.".to_string() }, FakeDefinition { id: "doc1".to_string(), - definition: "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string() + definition: "Definition of *zindle (verb)*: to pretend to be working on something important while actually doing something completely unrelated or unproductive.".to_string() }, FakeDefinition { id: "doc2".to_string(), @@ -39,7 +33,7 @@ pub fn fake_definitions() -> Vec<FakeDefinition> { pub fn fake_definition(id: String) -> FakeDefinition { FakeDefinition { id, - definition: "Definition of *flumbuzzle (verb)*: to bewilder or confuse someone completely, often by using nonsensical or overly complex explanations or instructions.".to_string() + definition: "Definition of *flumbuzzle (noun)*: A sudden, inexplicable urge to rearrange or reorganize small objects, such as desk items or books, for no apparent reason.".to_string() } } @@ -59,7 +53,7 @@ pub fn schema(dims: usize) -> Schema { ])) } -// Convert DocumentEmbeddings objects to a RecordBatch. +// Convert FakeDefinition objects and their embedding to a RecordBatch. pub fn as_record_batch( records: Vec<(FakeDefinition, Embedding)>, dims: usize, @@ -74,7 +68,7 @@ pub fn as_record_batch( let definition = StringArray::from_iter_values( records .iter() - .map(|(_, Embedding { document, .. })| document) + .map(|(FakeDefinition { definition, .. 
}, _)| definition) .collect::<Vec<_>>(), ); diff --git a/rig-lancedb/examples/vector_search_local_ann.rs b/rig-lancedb/examples/vector_search_local_ann.rs index 3ac925bb..af82ad49 100644 --- a/rig-lancedb/examples/vector_search_local_ann.rs +++ b/rig-lancedb/examples/vector_search_local_ann.rs @@ -1,7 +1,7 @@ use std::{env, sync::Arc}; use arrow_array::RecordBatchIterator; -use fixture::{as_record_batch, fake_definition, fake_definitions, schema, VectorSearchResult}; +use fixture::{as_record_batch, fake_definition, fake_definitions, schema, FakeDefinition}; use lancedb::index::vector::IvfPqIndexBuilder; use rig::vector_store::VectorStoreIndex; use rig::{ @@ -62,7 +62,7 @@ async fn main() -> Result<(), anyhow::Error> { // Query the index let results = vector_store - .top_n::<VectorSearchResult>("My boss says I zindle too much, what does that mean?", 1) + .top_n::<FakeDefinition>("My boss says I zindle too much, what does that mean?", 1) .await?; println!("Results: {:?}", results); diff --git a/rig-lancedb/examples/vector_search_s3_ann.rs b/rig-lancedb/examples/vector_search_s3_ann.rs index 358a1755..17c4cd7f 100644 --- a/rig-lancedb/examples/vector_search_s3_ann.rs +++ b/rig-lancedb/examples/vector_search_s3_ann.rs @@ -1,7 +1,7 @@ use std::{env, sync::Arc}; use arrow_array::RecordBatchIterator; -use fixture::{as_record_batch, fake_definition, fake_definitions, schema, VectorSearchResult}; +use fixture::{as_record_batch, fake_definition, fake_definitions, schema, FakeDefinition}; use lancedb::{index::vector::IvfPqIndexBuilder, DistanceType}; use rig::{ embeddings::{EmbeddingModel, EmbeddingsBuilder}, @@ -74,7 +74,7 @@ async fn main() -> Result<(), anyhow::Error> { // Query the index let results = vector_store - .top_n::<VectorSearchResult>("I'm always looking for my phone, I always seem to forget it in the most counterintuitive places. What's the word for this feeling?", 1) + .top_n::<FakeDefinition>("I'm always looking for my phone, I always seem to forget it in the most counterintuitive places. What's the word for this feeling?", 1) .await?; println!("Results: {:?}", results); diff --git a/rig-mongodb/examples/vector_search_mongodb.rs b/rig-mongodb/examples/vector_search_mongodb.rs index 4ddf1f99..18e0873a 100644 --- a/rig-mongodb/examples/vector_search_mongodb.rs +++ b/rig-mongodb/examples/vector_search_mongodb.rs @@ -1,5 +1,5 @@ use mongodb::{bson::doc, options::ClientOptions, Client as MongoClient, Collection}; -use rig::{embeddings::Embedding, providers::openai::TEXT_EMBEDDING_ADA_002}; +use rig::providers::openai::TEXT_EMBEDDING_ADA_002; use rig_derive::Embed; use serde::{Deserialize, Serialize}; use std::env; @@ -9,18 +9,22 @@ use rig::{ }; use rig_mongodb::{MongoDbVectorStore, SearchParams}; -#[derive(Embed, Clone)] +// Shape of data that needs to be RAG'ed. +// The definition field will be used to generate embeddings. +#[derive(Embed, Clone, Deserialize, Debug)] struct FakeDefinition { id: String, #[embed] definition: String, } -#[derive(Serialize, Debug, Deserialize)] +// Shape of the document to be stored in MongoDB. +#[derive(Serialize, Debug)] struct Document { #[serde(rename = "_id")] id: String, - definition: Embedding, + definition: String, + embedding: Vec<f64>, } #[tokio::main] @@ -69,9 +73,10 @@ async fn main() -> Result<(), anyhow::Error> { let mongo_documents = embeddings .iter() - .map(|(FakeDefinition { id, .. 
}, embedding)| Document { + .map(|(FakeDefinition { id, definition }, embedding)| Document { id: id.clone(), - definition: embedding.clone(), + definition: definition.clone(), + embedding: embedding.vec.clone(), }) .collect::<Vec<_>>(); @@ -87,16 +92,13 @@ async fn main() -> Result<(), anyhow::Error> { let index = vector_store.index( model, "definitions_vector_index", - SearchParams::new("definition.vec"), + SearchParams::new("embedding"), ); // Query the index let results = index - .top_n::<Document>("What is a linglingdong?", 1) - .await? - .into_iter() - .map(|(score, id, doc)| (score, id, doc.definition.document)) - .collect::<Vec<_>>(); + .top_n::<FakeDefinition>("What is a linglingdong?", 1) + .await?; println!("Results: {:?}", results); From 597e6c3e7e82b91a2df179cb04c66f29d83b1696 Mon Sep 17 00:00:00 2001 From: Garance <garance.mary@gmail.com> Date: Tue, 8 Oct 2024 21:45:43 -0400 Subject: [PATCH 09/91] feat: remove document embedding from in memory store --- rig-core/examples/rag.rs | 48 ++++- rig-core/examples/vector_search.rs | 59 +++++- rig-core/examples/vector_search_cohere.rs | 57 +++++- rig-core/rig-core-derive/src/embeddable.rs | 2 +- rig-core/src/vector_store/in_memory_store.rs | 179 ++++-------------- rig-core/src/vector_store/mod.rs | 18 +- rig-mongodb/examples/vector_search_mongodb.rs | 2 +- 7 files changed, 191 insertions(+), 174 deletions(-) diff --git a/rig-core/examples/rag.rs b/rig-core/examples/rag.rs index 3abd8ee9..55a99e63 100644 --- a/rig-core/examples/rag.rs +++ b/rig-core/examples/rag.rs @@ -6,6 +6,15 @@ use rig::{ providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, vector_store::{in_memory_store::InMemoryVectorStore, VectorStore}, }; +use rig_derive::Embed; +use serde::Serialize; + +#[derive(Embed, Clone, Serialize, Eq, PartialEq, Default)] +struct FakeDefinition { + id: String, + #[embed] + definitions: Vec<String>, +} #[tokio::main] async fn main() -> Result<(), anyhow::Error> { @@ -18,14 +27,45 @@ async fn main() -> Result<(), anyhow::Error> { // Create vector store, compute embeddings and load them in the store let mut vector_store = InMemoryVectorStore::default(); + let fake_definitions = vec![ + FakeDefinition { + id: "doc0".to_string(), + definitions: vec![ + "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets".to_string(), + "Definition of a *flurbo*: A unit of currency used in a bizarre or fantastical world, often associated with eccentric societies or sci-fi settings.".to_string() + ] + }, + FakeDefinition { + id: "doc1".to_string(), + definitions: vec![ + "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string(), + "Definition of a *glarb-glarb*: A mysterious, bubbling substance often found in swamps, alien planets, or under mysterious circumstances.".to_string() + ] + }, + FakeDefinition { + id: "doc2".to_string(), + definitions: vec![ + "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.".to_string() + ] + } + ]; + let embeddings = EmbeddingsBuilder::new(embedding_model.clone()) - .simple_document("doc0", "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets") - .simple_document("doc1", "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.") - .simple_document("doc2", "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to 
describe humans.") + .documents(fake_definitions) .build() .await?; - vector_store.add_documents(embeddings).await?; + vector_store + .add_documents( + embeddings + .into_iter() + .enumerate() + .map(|(i, (fake_definition, embeddings))| { + (format!("doc{i}"), fake_definition, embeddings) + }) + .collect(), + ) + .await?; // Create vector store index let index = vector_store.index(embedding_model); diff --git a/rig-core/examples/vector_search.rs b/rig-core/examples/vector_search.rs index e0d68bef..095ca296 100644 --- a/rig-core/examples/vector_search.rs +++ b/rig-core/examples/vector_search.rs @@ -1,10 +1,19 @@ use std::env; use rig::{ - embeddings::{DocumentEmbeddings, EmbeddingsBuilder}, + embeddings::EmbeddingsBuilder, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, - vector_store::{in_memory_store::InMemoryVectorIndex, VectorStoreIndex}, + vector_store::{in_memory_store::InMemoryVectorStore, VectorStore, VectorStoreIndex}, }; +use rig_derive::Embed; +use serde::{Deserialize, Serialize}; + +#[derive(Embed, Clone, Serialize, Default, Eq, PartialEq, Deserialize, Debug)] +struct FakeDefinition { + id: String, + #[embed] + definitions: Vec<String>, +} #[tokio::main] async fn main() -> Result<(), anyhow::Error> { @@ -14,20 +23,54 @@ async fn main() -> Result<(), anyhow::Error> { let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); + let fake_definitions = vec![ + FakeDefinition { + id: "doc0".to_string(), + definitions: vec![ + "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets".to_string(), + "Definition of a *flurbo*: A unit of currency used in a bizarre or fantastical world, often associated with eccentric societies or sci-fi settings.".to_string() + ] + }, + FakeDefinition { + id: "doc1".to_string(), + definitions: vec![ + "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string(), + "Definition of a *glarb-glarb*: A mysterious, bubbling substance often found in swamps, alien planets, or under mysterious circumstances.".to_string() + ] + }, + FakeDefinition { + id: "doc2".to_string(), + definitions: vec![ + "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.".to_string() + ] + } + ]; + let embeddings = EmbeddingsBuilder::new(model.clone()) - .simple_document("doc0", "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets") - .simple_document("doc1", "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.") - .simple_document("doc2", "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.") + .documents(fake_definitions) .build() .await?; - let index = InMemoryVectorIndex::from_embeddings(model, embeddings).await?; + let mut store = InMemoryVectorStore::default(); + store + .add_documents( + embeddings + .into_iter() + .enumerate() + .map(|(i, (fake_definition, embeddings))| { + (format!("doc{i}"), fake_definition, embeddings) + }) + .collect(), + ) + .await?; + + let index = store.index(model); let results = index - .top_n::<DocumentEmbeddings>("What is a linglingdong?", 1) + .top_n::<FakeDefinition>("What is a linglingdong?", 1) .await? 
.into_iter() - .map(|(score, id, doc)| (score, id, doc.document)) + .map(|(score, id, doc)| (score, id, doc)) .collect::<Vec<_>>(); println!("Results: {:?}", results); diff --git a/rig-core/examples/vector_search_cohere.rs b/rig-core/examples/vector_search_cohere.rs index a49ac231..6be094b4 100644 --- a/rig-core/examples/vector_search_cohere.rs +++ b/rig-core/examples/vector_search_cohere.rs @@ -1,10 +1,19 @@ use std::env; use rig::{ - embeddings::{DocumentEmbeddings, EmbeddingsBuilder}, + embeddings::EmbeddingsBuilder, providers::cohere::{Client, EMBED_ENGLISH_V3}, vector_store::{in_memory_store::InMemoryVectorStore, VectorStore, VectorStoreIndex}, }; +use rig_derive::Embed; +use serde::{Deserialize, Serialize}; + +#[derive(Embed, Clone, Serialize, Default, Eq, PartialEq, Deserialize, Debug)] +struct FakeDefinition { + id: String, + #[embed] + definitions: Vec<String>, +} #[tokio::main] async fn main() -> Result<(), anyhow::Error> { @@ -15,24 +24,54 @@ async fn main() -> Result<(), anyhow::Error> { let document_model = cohere_client.embedding_model(EMBED_ENGLISH_V3, "search_document"); let search_model = cohere_client.embedding_model(EMBED_ENGLISH_V3, "search_query"); - let mut vector_store = InMemoryVectorStore::default(); + let fake_definitions = vec![ + FakeDefinition { + id: "doc0".to_string(), + definitions: vec![ + "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets".to_string(), + "Definition of a *flurbo*: A unit of currency used in a bizarre or fantastical world, often associated with eccentric societies or sci-fi settings.".to_string() + ] + }, + FakeDefinition { + id: "doc1".to_string(), + definitions: vec![ + "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string(), + "Definition of a *glarb-glarb*: A mysterious, bubbling substance often found in swamps, alien planets, or under mysterious circumstances.".to_string() + ] + }, + FakeDefinition { + id: "doc2".to_string(), + definitions: vec![ + "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.".to_string() + ] + } + ]; let embeddings = EmbeddingsBuilder::new(document_model) - .simple_document("doc0", "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets") - .simple_document("doc1", "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.") - .simple_document("doc2", "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.") + .documents(fake_definitions) .build() .await?; - vector_store.add_documents(embeddings).await?; + let mut store = InMemoryVectorStore::default(); + store + .add_documents( + embeddings + .into_iter() + .enumerate() + .map(|(i, (fake_definition, embeddings))| { + (format!("doc{i}"), fake_definition, embeddings) + }) + .collect(), + ) + .await?; - let index = vector_store.index(search_model); + let index = store.index(search_model); let results = index - .top_n::<DocumentEmbeddings>("What is a linglingdong?", 1) + .top_n::<FakeDefinition>("What is a linglingdong?", 1) .await? 
.into_iter() - .map(|(score, id, doc)| (score, id, doc.document)) + .map(|(score, id, doc)| (score, id, doc)) .collect::<Vec<_>>(); println!("Results: {:?}", results); diff --git a/rig-core/rig-core-derive/src/embeddable.rs b/rig-core/rig-core-derive/src/embeddable.rs index ec52d952..8546f84a 100644 --- a/rig-core/rig-core-derive/src/embeddable.rs +++ b/rig-core/rig-core-derive/src/embeddable.rs @@ -99,7 +99,7 @@ fn embed_kind(field: &syn::Field) -> Result<syn::Expr, syn::Error> { } else { parse_str("SingleEmbedding") } - }, + } _ => parse_str("SingleEmbedding"), } } diff --git a/rig-core/src/vector_store/in_memory_store.rs b/rig-core/src/vector_store/in_memory_store.rs index a5db505f..5ed98671 100644 --- a/rig-core/src/vector_store/in_memory_store.rs +++ b/rig-core/src/vector_store/in_memory_store.rs @@ -8,27 +8,26 @@ use ordered_float::OrderedFloat; use serde::{Deserialize, Serialize}; use super::{VectorStore, VectorStoreError, VectorStoreIndex}; -use crate::embeddings::{DocumentEmbeddings, Embedding, EmbeddingModel, EmbeddingsBuilder}; +use crate::embeddings::{Embedding, EmbeddingModel}; /// InMemoryVectorStore is a simple in-memory vector store that stores embeddings /// in-memory using a HashMap. -#[derive(Clone, Default, Deserialize, Serialize)] -pub struct InMemoryVectorStore { +#[derive(Clone, Default)] +pub struct InMemoryVectorStore<D: Serialize> { /// The embeddings are stored in a HashMap with the document ID as the key. - embeddings: HashMap<String, DocumentEmbeddings>, + embeddings: HashMap<String, (D, Vec<Embedding>)>, } -impl InMemoryVectorStore { +impl<D: Serialize + Eq> InMemoryVectorStore<D> { /// Implement vector search on InMemoryVectorStore. /// To be used by implementations of top_n and top_n_ids methods on VectorStoreIndex trait for InMemoryVectorStore. - fn vector_search(&self, prompt_embedding: &Embedding, n: usize) -> EmbeddingRanking { + fn vector_search(&self, prompt_embedding: &Embedding, n: usize) -> EmbeddingRanking<D> { // Sort documents by best embedding distance - let mut docs: EmbeddingRanking = BinaryHeap::new(); + let mut docs = BinaryHeap::new(); - for (id, doc_embeddings) in self.embeddings.iter() { + for (id, (doc, embeddings)) in self.embeddings.iter() { // Get the best context for the document given the prompt - if let Some((distance, embed_doc)) = doc_embeddings - .embeddings + if let Some((distance, embed_doc)) = embeddings .iter() .map(|embedding| { ( @@ -38,12 +37,7 @@ impl InMemoryVectorStore { }) .min_by(|a, b| a.0.cmp(&b.0)) { - docs.push(Reverse(RankingItem( - distance, - id, - doc_embeddings, - embed_doc, - ))); + docs.push(Reverse(RankingItem(distance, id, doc, embed_doc))); }; // If the heap size exceeds n, pop the least old element. 
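// --- Illustrative sketch, not part of the patch ---
// Restates, next to the reworked generic store API, the usage pattern from the
// rag.rs / vector_search.rs examples in this same commit. `FakeDefinition` is the
// example document type from those examples, `model` is any EmbeddingModel, and
// `embeddings` is the Vec<(FakeDefinition, Vec<Embedding>)> returned by
// EmbeddingsBuilder::build().
//
//     let mut store = InMemoryVectorStore::default();
//     store
//         .add_documents(
//             embeddings
//                 .into_iter()
//                 .enumerate()
//                 .map(|(i, (doc, embeddings))| (format!("doc{i}"), doc, embeddings))
//                 .collect(),
//         )
//         .await?;
//
//     let index = store.index(model);
//     let results = index.top_n::<FakeDefinition>("What is a linglingdong?", 1).await?;
// --- end sketch ---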
@@ -67,73 +61,51 @@ impl InMemoryVectorStore { /// RankingItem(distance, document_id, document, embed_doc) #[derive(Eq, PartialEq)] -struct RankingItem<'a>( - OrderedFloat<f64>, - &'a String, - &'a DocumentEmbeddings, - &'a String, -); - -impl Ord for RankingItem<'_> { +struct RankingItem<'a, D: Serialize>(OrderedFloat<f64>, &'a String, &'a D, &'a String); + +impl<D: Serialize + Eq> Ord for RankingItem<'_, D> { fn cmp(&self, other: &Self) -> std::cmp::Ordering { self.0.cmp(&other.0) } } -impl PartialOrd for RankingItem<'_> { +impl<D: Serialize + Eq> PartialOrd for RankingItem<'_, D> { fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> { Some(self.cmp(other)) } } -type EmbeddingRanking<'a> = BinaryHeap<Reverse<RankingItem<'a>>>; +type EmbeddingRanking<'a, D> = BinaryHeap<Reverse<RankingItem<'a, D>>>; -impl VectorStore for InMemoryVectorStore { +impl<D: Serialize + Send + Sync + Clone> VectorStore<D> for InMemoryVectorStore<D> { type Q = (); async fn add_documents( &mut self, - documents: Vec<DocumentEmbeddings>, + documents: Vec<(String, D, Vec<Embedding>)>, ) -> Result<(), VectorStoreError> { - for doc in documents { - self.embeddings.insert(doc.id.clone(), doc); + for (id, doc, embeddings) in documents { + self.embeddings.insert(id, (doc, embeddings)); } Ok(()) } - async fn get_document<T: for<'a> Deserialize<'a>>( - &self, - id: &str, - ) -> Result<Option<T>, VectorStoreError> { - Ok(self - .embeddings - .get(id) - .map(|document| serde_json::from_value(document.document.clone())) - .transpose()?) - } - - async fn get_document_embeddings( - &self, - id: &str, - ) -> Result<Option<DocumentEmbeddings>, VectorStoreError> { - Ok(self.embeddings.get(id).cloned()) + async fn get_document_embeddings(&self, id: &str) -> Result<Option<D>, VectorStoreError> { + Ok(self.embeddings.get(id).cloned().map(|(doc, _)| doc)) } - async fn get_document_by_query( - &self, - _query: Self::Q, - ) -> Result<Option<DocumentEmbeddings>, VectorStoreError> { + async fn get_document_by_query(&self, _query: Self::Q) -> Result<Option<D>, VectorStoreError> { Ok(None) } } -impl InMemoryVectorStore { - pub fn index<M: EmbeddingModel>(self, model: M) -> InMemoryVectorIndex<M> { +impl<D: Serialize> InMemoryVectorStore<D> { + pub fn index<M: EmbeddingModel>(self, model: M) -> InMemoryVectorIndex<M, D> { InMemoryVectorIndex::new(model, self) } - pub fn iter(&self) -> impl Iterator<Item = (&String, &DocumentEmbeddings)> { + pub fn iter(&self) -> impl Iterator<Item = (&String, &(D, Vec<Embedding>))> { self.embeddings.iter() } @@ -144,54 +116,19 @@ impl InMemoryVectorStore { pub fn is_empty(&self) -> bool { self.embeddings.is_empty() } - - /// Uitilty method to create an InMemoryVectorStore from a list of embeddings. - pub async fn from_embeddings( - embeddings: Vec<DocumentEmbeddings>, - ) -> Result<Self, VectorStoreError> { - let mut store = Self::default(); - store.add_documents(embeddings).await?; - Ok(store) - } - - /// Create an InMemoryVectorStore from a list of documents. - /// The documents are serialized to JSON and embedded using the provided embedding model. - /// The resulting embeddings are stored in an InMemoryVectorStore created by the method. 
- pub async fn from_documents<M: EmbeddingModel, T: Serialize>( - embedding_model: M, - documents: &[(String, T)], - ) -> Result<Self, VectorStoreError> { - let embeddings = documents - .iter() - .fold( - EmbeddingsBuilder::new(embedding_model), - |builder, (id, doc)| { - builder.json_document( - id, - serde_json::to_value(doc).expect("Document should be serializable"), - vec![serde_json::to_string(doc).expect("Document should be serializable")], - ) - }, - ) - .build() - .await?; - - let store = Self::from_embeddings(embeddings).await?; - Ok(store) - } } -pub struct InMemoryVectorIndex<M: EmbeddingModel> { +pub struct InMemoryVectorIndex<M: EmbeddingModel, D: Serialize> { model: M, - pub store: InMemoryVectorStore, + pub store: InMemoryVectorStore<D>, } -impl<M: EmbeddingModel> InMemoryVectorIndex<M> { - pub fn new(model: M, store: InMemoryVectorStore) -> Self { +impl<M: EmbeddingModel, D: Serialize> InMemoryVectorIndex<M, D> { + pub fn new(model: M, store: InMemoryVectorStore<D>) -> Self { Self { model, store } } - pub fn iter(&self) -> impl Iterator<Item = (&String, &DocumentEmbeddings)> { + pub fn iter(&self) -> impl Iterator<Item = (&String, &(D, Vec<Embedding>))> { self.store.iter() } @@ -202,49 +139,11 @@ impl<M: EmbeddingModel> InMemoryVectorIndex<M> { pub fn is_empty(&self) -> bool { self.store.is_empty() } - - /// Create an InMemoryVectorIndex from a list of documents. - /// The documents are serialized to JSON and embedded using the provided embedding model. - /// The resulting embeddings are stored in an InMemoryVectorStore created by the method. - /// The InMemoryVectorIndex is then created from the store and the provided query model. - pub async fn from_documents<T: Serialize>( - embedding_model: M, - query_model: M, - documents: &[(String, T)], - ) -> Result<Self, VectorStoreError> { - let mut store = InMemoryVectorStore::default(); - - let embeddings = documents - .iter() - .fold( - EmbeddingsBuilder::new(embedding_model), - |builder, (id, doc)| { - builder.json_document( - id, - serde_json::to_value(doc).expect("Document should be serializable"), - vec![serde_json::to_string(doc).expect("Document should be serializable")], - ) - }, - ) - .build() - .await?; - - store.add_documents(embeddings).await?; - Ok(store.index(query_model)) - } - - /// Utility method to create an InMemoryVectorIndex from a list of embeddings - /// and an embedding model. 
- pub async fn from_embeddings( - query_model: M, - embeddings: Vec<DocumentEmbeddings>, - ) -> Result<Self, VectorStoreError> { - let store = InMemoryVectorStore::from_embeddings(embeddings).await?; - Ok(store.index(query_model)) - } } -impl<M: EmbeddingModel + std::marker::Sync> VectorStoreIndex for InMemoryVectorIndex<M> { +impl<M: EmbeddingModel + std::marker::Sync, D: Serialize + Sync + Send + Eq> VectorStoreIndex + for InMemoryVectorIndex<M, D> +{ async fn top_n<T: for<'a> Deserialize<'a>>( &self, query: &str, @@ -256,12 +155,14 @@ impl<M: EmbeddingModel + std::marker::Sync> VectorStoreIndex for InMemoryVectorI // Return n best docs.into_iter() - .map(|Reverse(RankingItem(distance, _, doc, _))| { - let doc_value = serde_json::to_value(doc).map_err(VectorStoreError::JsonError)?; + .map(|Reverse(RankingItem(distance, id, doc, _))| { Ok(( distance.0, - doc.id.clone(), - serde_json::from_value(doc_value).map_err(VectorStoreError::JsonError)?, + id.clone(), + serde_json::from_str( + &serde_json::to_string(doc).map_err(VectorStoreError::JsonError)?, + ) + .map_err(VectorStoreError::JsonError)?, )) }) .collect::<Result<Vec<_>, _>>() @@ -278,7 +179,7 @@ impl<M: EmbeddingModel + std::marker::Sync> VectorStoreIndex for InMemoryVectorI // Return n best docs.into_iter() - .map(|Reverse(RankingItem(distance, _, doc, _))| Ok((distance.0, doc.id.clone()))) + .map(|Reverse(RankingItem(distance, id, _, _))| Ok((distance.0, id.clone()))) .collect::<Result<Vec<_>, _>>() } } diff --git a/rig-core/src/vector_store/mod.rs b/rig-core/src/vector_store/mod.rs index b07d348a..8f89d5f1 100644 --- a/rig-core/src/vector_store/mod.rs +++ b/rig-core/src/vector_store/mod.rs @@ -1,8 +1,8 @@ use futures::future::BoxFuture; -use serde::Deserialize; +use serde::{Deserialize, Serialize}; use serde_json::Value; -use crate::embeddings::{DocumentEmbeddings, EmbeddingError}; +use crate::embeddings::{Embedding, EmbeddingError}; pub mod in_memory_store; @@ -20,33 +20,27 @@ pub enum VectorStoreError { } /// Trait for vector stores -pub trait VectorStore: Send + Sync { +pub trait VectorStore<D: Serialize>: Send + Sync { /// Query type for the vector store type Q; /// Add a list of documents to the vector store fn add_documents( &mut self, - documents: Vec<DocumentEmbeddings>, + documents: Vec<(String, D, Vec<Embedding>)>, ) -> impl std::future::Future<Output = Result<(), VectorStoreError>> + Send; /// Get the embeddings of a document by its id fn get_document_embeddings( &self, id: &str, - ) -> impl std::future::Future<Output = Result<Option<DocumentEmbeddings>, VectorStoreError>> + Send; - - /// Get the document by its id and deserialize it into the given type - fn get_document<T: for<'a> Deserialize<'a>>( - &self, - id: &str, - ) -> impl std::future::Future<Output = Result<Option<T>, VectorStoreError>> + Send; + ) -> impl std::future::Future<Output = Result<Option<D>, VectorStoreError>> + Send; /// Get the document by a query and deserialize it into the given type fn get_document_by_query( &self, query: Self::Q, - ) -> impl std::future::Future<Output = Result<Option<DocumentEmbeddings>, VectorStoreError>> + Send; + ) -> impl std::future::Future<Output = Result<Option<D>, VectorStoreError>> + Send; } /// Trait for vector store indexes diff --git a/rig-mongodb/examples/vector_search_mongodb.rs b/rig-mongodb/examples/vector_search_mongodb.rs index 18e0873a..ece3e093 100644 --- a/rig-mongodb/examples/vector_search_mongodb.rs +++ b/rig-mongodb/examples/vector_search_mongodb.rs @@ -18,7 +18,7 @@ struct FakeDefinition { definition: 
String, } -// Shape of the document to be stored in MongoDB. +// Shape of the document to be stored in MongoDB, with embeddings. #[derive(Serialize, Debug)] struct Document { #[serde(rename = "_id")] From 5407772c2433d13dd003f131f65d673af44d32b2 Mon Sep 17 00:00:00 2001 From: Garance <garance.mary@gmail.com> Date: Wed, 9 Oct 2024 11:48:49 -0400 Subject: [PATCH 10/91] refactor: remove DocumentEmbeddings from in memory vector store --- rig-core/examples/calculator_chatbot.rs | 20 +- rig-core/examples/rag.rs | 17 +- rig-core/examples/rag_dynamic_tools.rs | 20 +- rig-core/examples/vector_search.rs | 20 +- rig-core/examples/vector_search_cohere.rs | 17 +- rig-core/src/vector_store/in_memory_store.rs | 206 +++++------------- rig-core/src/vector_store/mod.rs | 32 +-- rig-mongodb/examples/vector_search_mongodb.rs | 9 +- rig-mongodb/src/lib.rs | 63 +----- 9 files changed, 134 insertions(+), 270 deletions(-) diff --git a/rig-core/examples/calculator_chatbot.rs b/rig-core/examples/calculator_chatbot.rs index 04d26dc3..90e94a93 100644 --- a/rig-core/examples/calculator_chatbot.rs +++ b/rig-core/examples/calculator_chatbot.rs @@ -2,10 +2,10 @@ use anyhow::Result; use rig::{ cli_chatbot::cli_chatbot, completion::ToolDefinition, - embeddings::EmbeddingsBuilder, + embeddings::{DocumentEmbeddings, EmbeddingsBuilder}, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, tool::{Tool, ToolEmbedding, ToolSet}, - vector_store::{in_memory_store::InMemoryVectorStore, VectorStore}, + vector_store::in_memory_store::InMemoryVectorStore, }; use serde::{Deserialize, Serialize}; use serde_json::json; @@ -251,9 +251,19 @@ async fn main() -> Result<(), anyhow::Error> { .build() .await?; - let mut store = InMemoryVectorStore::default(); - store.add_documents(embeddings).await?; - let index = store.index(embedding_model); + let vector_store = InMemoryVectorStore::default().add_documents( + embeddings + .into_iter() + .map( + |DocumentEmbeddings { + id, + document, + embeddings, + }| { (id, document, embeddings) }, + ) + .collect(), + )?; + let index = vector_store.index(embedding_model); // Create RAG agent with a single context prompt and a dynamic tool source let calculator_rag = openai_client diff --git a/rig-core/examples/rag.rs b/rig-core/examples/rag.rs index 3abd8ee9..b3363a43 100644 --- a/rig-core/examples/rag.rs +++ b/rig-core/examples/rag.rs @@ -2,9 +2,9 @@ use std::env; use rig::{ completion::Prompt, - embeddings::EmbeddingsBuilder, + embeddings::{DocumentEmbeddings, EmbeddingsBuilder}, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, - vector_store::{in_memory_store::InMemoryVectorStore, VectorStore}, + vector_store::in_memory_store::InMemoryVectorStore, }; #[tokio::main] @@ -25,7 +25,18 @@ async fn main() -> Result<(), anyhow::Error> { .build() .await?; - vector_store.add_documents(embeddings).await?; + let vector_store = vector_store.add_documents( + embeddings + .into_iter() + .map( + |DocumentEmbeddings { + id, + document, + embeddings, + }| { (id, document, embeddings) }, + ) + .collect(), + )?; // Create vector store index let index = vector_store.index(embedding_model); diff --git a/rig-core/examples/rag_dynamic_tools.rs b/rig-core/examples/rag_dynamic_tools.rs index 6e45730b..347a7bba 100644 --- a/rig-core/examples/rag_dynamic_tools.rs +++ b/rig-core/examples/rag_dynamic_tools.rs @@ -1,10 +1,10 @@ use anyhow::Result; use rig::{ completion::{Prompt, ToolDefinition}, - embeddings::EmbeddingsBuilder, + embeddings::{DocumentEmbeddings, EmbeddingsBuilder}, providers::openai::{Client, 
TEXT_EMBEDDING_ADA_002}, tool::{Tool, ToolEmbedding, ToolSet}, - vector_store::{in_memory_store::InMemoryVectorStore, VectorStore}, + vector_store::in_memory_store::InMemoryVectorStore, }; use serde::{Deserialize, Serialize}; use serde_json::json; @@ -150,9 +150,6 @@ async fn main() -> Result<(), anyhow::Error> { let embedding_model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); - // Create vector store, compute tool embeddings and load them in the store - let mut vector_store = InMemoryVectorStore::default(); - let toolset = ToolSet::builder() .dynamic_tool(Add) .dynamic_tool(Subtract) @@ -163,7 +160,18 @@ async fn main() -> Result<(), anyhow::Error> { .build() .await?; - vector_store.add_documents(embeddings).await?; + let vector_store = InMemoryVectorStore::default().add_documents( + embeddings + .into_iter() + .map( + |DocumentEmbeddings { + id, + document, + embeddings, + }| { (id, document, embeddings) }, + ) + .collect(), + )?; // Create vector store index let index = vector_store.index(embedding_model); diff --git a/rig-core/examples/vector_search.rs b/rig-core/examples/vector_search.rs index e0d68bef..5cef0a37 100644 --- a/rig-core/examples/vector_search.rs +++ b/rig-core/examples/vector_search.rs @@ -3,7 +3,10 @@ use std::env; use rig::{ embeddings::{DocumentEmbeddings, EmbeddingsBuilder}, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, - vector_store::{in_memory_store::InMemoryVectorIndex, VectorStoreIndex}, + vector_store::{ + in_memory_store::{InMemoryVectorIndex, InMemoryVectorStore}, + VectorStoreIndex, + }, }; #[tokio::main] @@ -21,7 +24,20 @@ async fn main() -> Result<(), anyhow::Error> { .build() .await?; - let index = InMemoryVectorIndex::from_embeddings(model, embeddings).await?; + let vector_store = InMemoryVectorStore::default().add_documents( + embeddings + .into_iter() + .map( + |DocumentEmbeddings { + id, + document, + embeddings, + }| { (id, document, embeddings) }, + ) + .collect(), + )?; + + let index = vector_store.index(model); let results = index .top_n::<DocumentEmbeddings>("What is a linglingdong?", 1) diff --git a/rig-core/examples/vector_search_cohere.rs b/rig-core/examples/vector_search_cohere.rs index a49ac231..13fd19ae 100644 --- a/rig-core/examples/vector_search_cohere.rs +++ b/rig-core/examples/vector_search_cohere.rs @@ -3,7 +3,7 @@ use std::env; use rig::{ embeddings::{DocumentEmbeddings, EmbeddingsBuilder}, providers::cohere::{Client, EMBED_ENGLISH_V3}, - vector_store::{in_memory_store::InMemoryVectorStore, VectorStore, VectorStoreIndex}, + vector_store::{in_memory_store::InMemoryVectorStore, VectorStoreIndex}, }; #[tokio::main] @@ -15,8 +15,6 @@ async fn main() -> Result<(), anyhow::Error> { let document_model = cohere_client.embedding_model(EMBED_ENGLISH_V3, "search_document"); let search_model = cohere_client.embedding_model(EMBED_ENGLISH_V3, "search_query"); - let mut vector_store = InMemoryVectorStore::default(); - let embeddings = EmbeddingsBuilder::new(document_model) .simple_document("doc0", "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets") .simple_document("doc1", "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.") @@ -24,7 +22,18 @@ async fn main() -> Result<(), anyhow::Error> { .build() .await?; - vector_store.add_documents(embeddings).await?; + let vector_store = InMemoryVectorStore::default().add_documents( + embeddings + .into_iter() + .map( + |DocumentEmbeddings { + id, + document, + embeddings, + }| 
{ (id, document, embeddings) }, + ) + .collect(), + )?; let index = vector_store.index(search_model); diff --git a/rig-core/src/vector_store/in_memory_store.rs b/rig-core/src/vector_store/in_memory_store.rs index a5db505f..61d22117 100644 --- a/rig-core/src/vector_store/in_memory_store.rs +++ b/rig-core/src/vector_store/in_memory_store.rs @@ -7,28 +7,27 @@ use std::{ use ordered_float::OrderedFloat; use serde::{Deserialize, Serialize}; -use super::{VectorStore, VectorStoreError, VectorStoreIndex}; -use crate::embeddings::{DocumentEmbeddings, Embedding, EmbeddingModel, EmbeddingsBuilder}; +use super::{VectorStoreError, VectorStoreIndex}; +use crate::embeddings::{Embedding, EmbeddingModel}; /// InMemoryVectorStore is a simple in-memory vector store that stores embeddings /// in-memory using a HashMap. -#[derive(Clone, Default, Deserialize, Serialize)] -pub struct InMemoryVectorStore { +#[derive(Clone, Default)] +pub struct InMemoryVectorStore<D: Serialize> { /// The embeddings are stored in a HashMap with the document ID as the key. - embeddings: HashMap<String, DocumentEmbeddings>, + embeddings: HashMap<String, (D, Vec<Embedding>)>, } -impl InMemoryVectorStore { +impl<D: Serialize + Eq> InMemoryVectorStore<D> { /// Implement vector search on InMemoryVectorStore. /// To be used by implementations of top_n and top_n_ids methods on VectorStoreIndex trait for InMemoryVectorStore. - fn vector_search(&self, prompt_embedding: &Embedding, n: usize) -> EmbeddingRanking { + fn vector_search(&self, prompt_embedding: &Embedding, n: usize) -> EmbeddingRanking<D> { // Sort documents by best embedding distance - let mut docs: EmbeddingRanking = BinaryHeap::new(); + let mut docs = BinaryHeap::new(); - for (id, doc_embeddings) in self.embeddings.iter() { + for (id, (doc, embeddings)) in self.embeddings.iter() { // Get the best context for the document given the prompt - if let Some((distance, embed_doc)) = doc_embeddings - .embeddings + if let Some((distance, embed_doc)) = embeddings .iter() .map(|embedding| { ( @@ -38,12 +37,7 @@ impl InMemoryVectorStore { }) .min_by(|a, b| a.0.cmp(&b.0)) { - docs.push(Reverse(RankingItem( - distance, - id, - doc_embeddings, - embed_doc, - ))); + docs.push(Reverse(RankingItem(distance, id, doc, embed_doc))); }; // If the heap size exceeds n, pop the least old element. 
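Each of the examples above repeats the same destructuring of `DocumentEmbeddings` into the `(id, document, embeddings)` tuples that the refactored `add_documents` takes. A small helper along these lines (hypothetical, not part of the patch; it assumes the pre-refactor `DocumentEmbeddings` with public `id`, `document: serde_json::Value`, and `embeddings` fields, exactly as the examples destructure them) would keep that conversion in one place:

```rust
use rig::embeddings::{DocumentEmbeddings, Embedding};

// Hypothetical helper: flatten `DocumentEmbeddings` into the
// `(id, document, embeddings)` tuples expected by the refactored store.
fn to_tuples(
    docs: Vec<DocumentEmbeddings>,
) -> Vec<(String, serde_json::Value, Vec<Embedding>)> {
    docs.into_iter()
        .map(
            |DocumentEmbeddings {
                 id,
                 document,
                 embeddings,
             }| (id, document, embeddings),
        )
        .collect()
}
```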
@@ -63,77 +57,59 @@ impl InMemoryVectorStore { docs } -} - -/// RankingItem(distance, document_id, document, embed_doc) -#[derive(Eq, PartialEq)] -struct RankingItem<'a>( - OrderedFloat<f64>, - &'a String, - &'a DocumentEmbeddings, - &'a String, -); -impl Ord for RankingItem<'_> { - fn cmp(&self, other: &Self) -> std::cmp::Ordering { - self.0.cmp(&other.0) - } -} - -impl PartialOrd for RankingItem<'_> { - fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> { - Some(self.cmp(other)) - } -} - -type EmbeddingRanking<'a> = BinaryHeap<Reverse<RankingItem<'a>>>; - -impl VectorStore for InMemoryVectorStore { - type Q = (); - - async fn add_documents( - &mut self, - documents: Vec<DocumentEmbeddings>, - ) -> Result<(), VectorStoreError> { - for doc in documents { - self.embeddings.insert(doc.id.clone(), doc); + pub fn add_documents( + mut self, + documents: Vec<(String, D, Vec<Embedding>)>, + ) -> Result<Self, VectorStoreError> { + for (id, doc, embeddings) in documents { + self.embeddings.insert(id, (doc, embeddings)); } - Ok(()) + Ok(self) } - async fn get_document<T: for<'a> Deserialize<'a>>( + /// Get the document by its id and deserialize it into the given type + pub fn get_document<T: for<'a> Deserialize<'a>>( &self, id: &str, ) -> Result<Option<T>, VectorStoreError> { Ok(self .embeddings .get(id) - .map(|document| serde_json::from_value(document.document.clone())) + .map(|(doc, _)| serde_json::from_str(&serde_json::to_string(doc)?)) .transpose()?) } - async fn get_document_embeddings( - &self, - id: &str, - ) -> Result<Option<DocumentEmbeddings>, VectorStoreError> { - Ok(self.embeddings.get(id).cloned()) + pub fn get_document_embeddings(&self, id: &str) -> Result<Option<&D>, VectorStoreError> { + Ok(self.embeddings.get(id).map(|(doc, _)| doc)) } +} - async fn get_document_by_query( - &self, - _query: Self::Q, - ) -> Result<Option<DocumentEmbeddings>, VectorStoreError> { - Ok(None) +/// RankingItem(distance, document_id, document, embed_doc) +#[derive(Eq, PartialEq)] +struct RankingItem<'a, D: Serialize>(OrderedFloat<f64>, &'a String, &'a D, &'a String); + +impl<D: Serialize + Eq> Ord for RankingItem<'_, D> { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + self.0.cmp(&other.0) + } +} + +impl<D: Serialize + Eq> PartialOrd for RankingItem<'_, D> { + fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> { + Some(self.cmp(other)) } } -impl InMemoryVectorStore { - pub fn index<M: EmbeddingModel>(self, model: M) -> InMemoryVectorIndex<M> { +type EmbeddingRanking<'a, D> = BinaryHeap<Reverse<RankingItem<'a, D>>>; + +impl<D: Serialize> InMemoryVectorStore<D> { + pub fn index<M: EmbeddingModel>(self, model: M) -> InMemoryVectorIndex<M, D> { InMemoryVectorIndex::new(model, self) } - pub fn iter(&self) -> impl Iterator<Item = (&String, &DocumentEmbeddings)> { + pub fn iter(&self) -> impl Iterator<Item = (&String, &(D, Vec<Embedding>))> { self.embeddings.iter() } @@ -144,54 +120,19 @@ impl InMemoryVectorStore { pub fn is_empty(&self) -> bool { self.embeddings.is_empty() } - - /// Uitilty method to create an InMemoryVectorStore from a list of embeddings. - pub async fn from_embeddings( - embeddings: Vec<DocumentEmbeddings>, - ) -> Result<Self, VectorStoreError> { - let mut store = Self::default(); - store.add_documents(embeddings).await?; - Ok(store) - } - - /// Create an InMemoryVectorStore from a list of documents. - /// The documents are serialized to JSON and embedded using the provided embedding model. 
- /// The resulting embeddings are stored in an InMemoryVectorStore created by the method. - pub async fn from_documents<M: EmbeddingModel, T: Serialize>( - embedding_model: M, - documents: &[(String, T)], - ) -> Result<Self, VectorStoreError> { - let embeddings = documents - .iter() - .fold( - EmbeddingsBuilder::new(embedding_model), - |builder, (id, doc)| { - builder.json_document( - id, - serde_json::to_value(doc).expect("Document should be serializable"), - vec![serde_json::to_string(doc).expect("Document should be serializable")], - ) - }, - ) - .build() - .await?; - - let store = Self::from_embeddings(embeddings).await?; - Ok(store) - } } -pub struct InMemoryVectorIndex<M: EmbeddingModel> { +pub struct InMemoryVectorIndex<M: EmbeddingModel, D: Serialize> { model: M, - pub store: InMemoryVectorStore, + pub store: InMemoryVectorStore<D>, } -impl<M: EmbeddingModel> InMemoryVectorIndex<M> { - pub fn new(model: M, store: InMemoryVectorStore) -> Self { +impl<M: EmbeddingModel, D: Serialize> InMemoryVectorIndex<M, D> { + pub fn new(model: M, store: InMemoryVectorStore<D>) -> Self { Self { model, store } } - pub fn iter(&self) -> impl Iterator<Item = (&String, &DocumentEmbeddings)> { + pub fn iter(&self) -> impl Iterator<Item = (&String, &(D, Vec<Embedding>))> { self.store.iter() } @@ -202,49 +143,11 @@ impl<M: EmbeddingModel> InMemoryVectorIndex<M> { pub fn is_empty(&self) -> bool { self.store.is_empty() } - - /// Create an InMemoryVectorIndex from a list of documents. - /// The documents are serialized to JSON and embedded using the provided embedding model. - /// The resulting embeddings are stored in an InMemoryVectorStore created by the method. - /// The InMemoryVectorIndex is then created from the store and the provided query model. - pub async fn from_documents<T: Serialize>( - embedding_model: M, - query_model: M, - documents: &[(String, T)], - ) -> Result<Self, VectorStoreError> { - let mut store = InMemoryVectorStore::default(); - - let embeddings = documents - .iter() - .fold( - EmbeddingsBuilder::new(embedding_model), - |builder, (id, doc)| { - builder.json_document( - id, - serde_json::to_value(doc).expect("Document should be serializable"), - vec![serde_json::to_string(doc).expect("Document should be serializable")], - ) - }, - ) - .build() - .await?; - - store.add_documents(embeddings).await?; - Ok(store.index(query_model)) - } - - /// Utility method to create an InMemoryVectorIndex from a list of embeddings - /// and an embedding model. 
- pub async fn from_embeddings( - query_model: M, - embeddings: Vec<DocumentEmbeddings>, - ) -> Result<Self, VectorStoreError> { - let store = InMemoryVectorStore::from_embeddings(embeddings).await?; - Ok(store.index(query_model)) - } } -impl<M: EmbeddingModel + std::marker::Sync> VectorStoreIndex for InMemoryVectorIndex<M> { +impl<M: EmbeddingModel + std::marker::Sync, D: Serialize + Sync + Send + Eq> VectorStoreIndex + for InMemoryVectorIndex<M, D> +{ async fn top_n<T: for<'a> Deserialize<'a>>( &self, query: &str, @@ -256,12 +159,11 @@ impl<M: EmbeddingModel + std::marker::Sync> VectorStoreIndex for InMemoryVectorI // Return n best docs.into_iter() - .map(|Reverse(RankingItem(distance, _, doc, _))| { - let doc_value = serde_json::to_value(doc).map_err(VectorStoreError::JsonError)?; + .map(|Reverse(RankingItem(distance, id, doc, _))| { Ok(( distance.0, - doc.id.clone(), - serde_json::from_value(doc_value).map_err(VectorStoreError::JsonError)?, + id.clone(), + serde_json::from_str(&serde_json::to_string(doc)?)?, )) }) .collect::<Result<Vec<_>, _>>() @@ -278,7 +180,7 @@ impl<M: EmbeddingModel + std::marker::Sync> VectorStoreIndex for InMemoryVectorI // Return n best docs.into_iter() - .map(|Reverse(RankingItem(distance, _, doc, _))| Ok((distance.0, doc.id.clone()))) + .map(|Reverse(RankingItem(distance, id, _, _))| Ok((distance.0, id.clone()))) .collect::<Result<Vec<_>, _>>() } } diff --git a/rig-core/src/vector_store/mod.rs b/rig-core/src/vector_store/mod.rs index b07d348a..396b5514 100644 --- a/rig-core/src/vector_store/mod.rs +++ b/rig-core/src/vector_store/mod.rs @@ -2,7 +2,7 @@ use futures::future::BoxFuture; use serde::Deserialize; use serde_json::Value; -use crate::embeddings::{DocumentEmbeddings, EmbeddingError}; +use crate::embeddings::EmbeddingError; pub mod in_memory_store; @@ -19,36 +19,6 @@ pub enum VectorStoreError { DatastoreError(#[from] Box<dyn std::error::Error + Send + Sync>), } -/// Trait for vector stores -pub trait VectorStore: Send + Sync { - /// Query type for the vector store - type Q; - - /// Add a list of documents to the vector store - fn add_documents( - &mut self, - documents: Vec<DocumentEmbeddings>, - ) -> impl std::future::Future<Output = Result<(), VectorStoreError>> + Send; - - /// Get the embeddings of a document by its id - fn get_document_embeddings( - &self, - id: &str, - ) -> impl std::future::Future<Output = Result<Option<DocumentEmbeddings>, VectorStoreError>> + Send; - - /// Get the document by its id and deserialize it into the given type - fn get_document<T: for<'a> Deserialize<'a>>( - &self, - id: &str, - ) -> impl std::future::Future<Output = Result<Option<T>, VectorStoreError>> + Send; - - /// Get the document by a query and deserialize it into the given type - fn get_document_by_query( - &self, - query: Self::Q, - ) -> impl std::future::Future<Output = Result<Option<DocumentEmbeddings>, VectorStoreError>> + Send; -} - /// Trait for vector store indexes pub trait VectorStoreIndex: Send + Sync { /// Get the top n documents based on the distance to the given query. 
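A minimal usage sketch of the in-memory API as it stands after this refactor, following the types shown in the diffs above. The embedding model, the stand-in document type, and the pre-computed `(id, document, embeddings)` tuples are assumptions supplied by the caller (for example from a provider client and `EmbeddingsBuilder`), not part of the patch:

```rust
use rig::embeddings::{Embedding, EmbeddingModel};
use rig::vector_store::{
    in_memory_store::InMemoryVectorStore, VectorStoreError, VectorStoreIndex,
};
use serde::{Deserialize, Serialize};

// Stand-in document type; any `Serialize`/`Deserialize`/`Eq`/`Default` type works here.
#[derive(Serialize, Deserialize, PartialEq, Eq, Default)]
struct Definition {
    word: String,
    text: String,
}

async fn best_match<M: EmbeddingModel + Sync>(
    model: M,
    docs: Vec<(String, Definition, Vec<Embedding>)>,
    query: &str,
) -> Result<Option<(f64, String, Definition)>, VectorStoreError> {
    // `add_documents` now consumes the store and returns it, so construction
    // chains straight into `index`.
    let index = InMemoryVectorStore::default()
        .add_documents(docs)?
        .index(model);

    // `top_n` returns `(distance, id, deserialized document)` triples.
    Ok(index.top_n::<Definition>(query, 1).await?.into_iter().next())
}
```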
diff --git a/rig-mongodb/examples/vector_search_mongodb.rs b/rig-mongodb/examples/vector_search_mongodb.rs index 3d062de3..87f1595f 100644 --- a/rig-mongodb/examples/vector_search_mongodb.rs +++ b/rig-mongodb/examples/vector_search_mongodb.rs @@ -4,7 +4,7 @@ use std::env; use rig::{ embeddings::{DocumentEmbeddings, EmbeddingsBuilder}, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, - vector_store::{VectorStore, VectorStoreIndex}, + vector_store::VectorStoreIndex, }; use rig_mongodb::{MongoDbVectorStore, SearchParams}; @@ -29,8 +29,6 @@ async fn main() -> Result<(), anyhow::Error> { .database("knowledgebase") .collection("context"); - let mut vector_store = MongoDbVectorStore::new(collection); - // Select the embedding model and generate our embeddings let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); @@ -41,12 +39,13 @@ async fn main() -> Result<(), anyhow::Error> { .build() .await?; - // Add embeddings to vector store - match vector_store.add_documents(embeddings).await { + match collection.insert_many(embeddings, None).await { Ok(_) => println!("Documents added successfully"), Err(e) => println!("Error adding documents: {:?}", e), } + let vector_store = MongoDbVectorStore::new(collection); + // Create a vector index on our vector store // IMPORTANT: Reuse the same model that was used to generate the embeddings let index = vector_store.index(model, "vector_index", SearchParams::default()); diff --git a/rig-mongodb/src/lib.rs b/rig-mongodb/src/lib.rs index 43869989..30dd9e95 100644 --- a/rig-mongodb/src/lib.rs +++ b/rig-mongodb/src/lib.rs @@ -3,7 +3,7 @@ use mongodb::bson::{self, doc}; use rig::{ embeddings::{DocumentEmbeddings, Embedding, EmbeddingModel}, - vector_store::{VectorStore, VectorStoreError, VectorStoreIndex}, + vector_store::{VectorStoreError, VectorStoreIndex}, }; use serde::Deserialize; @@ -16,67 +16,6 @@ fn mongodb_to_rig_error(e: mongodb::error::Error) -> VectorStoreError { VectorStoreError::DatastoreError(Box::new(e)) } -impl VectorStore for MongoDbVectorStore { - type Q = mongodb::bson::Document; - - async fn add_documents( - &mut self, - documents: Vec<DocumentEmbeddings>, - ) -> Result<(), VectorStoreError> { - self.collection - .insert_many(documents, None) - .await - .map_err(mongodb_to_rig_error)?; - Ok(()) - } - - async fn get_document_embeddings( - &self, - id: &str, - ) -> Result<Option<DocumentEmbeddings>, VectorStoreError> { - self.collection - .find_one(doc! { "_id": id }, None) - .await - .map_err(mongodb_to_rig_error) - } - - async fn get_document<T: for<'a> serde::Deserialize<'a>>( - &self, - id: &str, - ) -> Result<Option<T>, VectorStoreError> { - Ok(self - .collection - .clone_with_type::<String>() - .aggregate( - [ - doc! {"$match": { "_id": id}}, - doc! {"$project": { "document": 1 }}, - doc! {"$replaceRoot": { "newRoot": "$document" }}, - ], - None, - ) - .await - .map_err(mongodb_to_rig_error)? - .with_type::<String>() - .next() - .await - .transpose() - .map_err(mongodb_to_rig_error)? - .map(|doc| serde_json::from_str(&doc)) - .transpose()?) - } - - async fn get_document_by_query( - &self, - query: Self::Q, - ) -> Result<Option<DocumentEmbeddings>, VectorStoreError> { - self.collection - .find_one(query, None) - .await - .map_err(mongodb_to_rig_error) - } -} - impl MongoDbVectorStore { /// Create a new `MongoDbVectorStore` from a MongoDB collection. 
pub fn new(collection: mongodb::Collection<DocumentEmbeddings>) -> Self { From c61685e59a99c34639bc94864822c158005f4bba Mon Sep 17 00:00:00 2001 From: Garance <garance.mary@gmail.com> Date: Wed, 9 Oct 2024 11:53:08 -0400 Subject: [PATCH 11/91] refactor(examples): combine vector store with vector store index --- rig-core/examples/calculator_chatbot.rs | 27 ++++++++++--------- rig-core/examples/rag.rs | 32 ++++++++++------------ rig-core/examples/rag_dynamic_tools.rs | 29 ++++++++++---------- rig-core/examples/vector_search.rs | 33 +++++++++++------------ rig-core/examples/vector_search_cohere.rs | 28 +++++++++---------- 5 files changed, 71 insertions(+), 78 deletions(-) diff --git a/rig-core/examples/calculator_chatbot.rs b/rig-core/examples/calculator_chatbot.rs index 90e94a93..fb168a08 100644 --- a/rig-core/examples/calculator_chatbot.rs +++ b/rig-core/examples/calculator_chatbot.rs @@ -251,19 +251,20 @@ async fn main() -> Result<(), anyhow::Error> { .build() .await?; - let vector_store = InMemoryVectorStore::default().add_documents( - embeddings - .into_iter() - .map( - |DocumentEmbeddings { - id, - document, - embeddings, - }| { (id, document, embeddings) }, - ) - .collect(), - )?; - let index = vector_store.index(embedding_model); + let index = InMemoryVectorStore::default() + .add_documents( + embeddings + .into_iter() + .map( + |DocumentEmbeddings { + id, + document, + embeddings, + }| { (id, document, embeddings) }, + ) + .collect(), + )? + .index(embedding_model); // Create RAG agent with a single context prompt and a dynamic tool source let calculator_rag = openai_client diff --git a/rig-core/examples/rag.rs b/rig-core/examples/rag.rs index b3363a43..b4dee8a5 100644 --- a/rig-core/examples/rag.rs +++ b/rig-core/examples/rag.rs @@ -15,9 +15,6 @@ async fn main() -> Result<(), anyhow::Error> { let embedding_model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); - // Create vector store, compute embeddings and load them in the store - let mut vector_store = InMemoryVectorStore::default(); - let embeddings = EmbeddingsBuilder::new(embedding_model.clone()) .simple_document("doc0", "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets") .simple_document("doc1", "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.") @@ -25,21 +22,20 @@ async fn main() -> Result<(), anyhow::Error> { .build() .await?; - let vector_store = vector_store.add_documents( - embeddings - .into_iter() - .map( - |DocumentEmbeddings { - id, - document, - embeddings, - }| { (id, document, embeddings) }, - ) - .collect(), - )?; - - // Create vector store index - let index = vector_store.index(embedding_model); + let index = InMemoryVectorStore::default() + .add_documents( + embeddings + .into_iter() + .map( + |DocumentEmbeddings { + id, + document, + embeddings, + }| { (id, document, embeddings) }, + ) + .collect(), + )? 
+ .index(embedding_model); let rag_agent = openai_client.agent("gpt-4") .preamble(" diff --git a/rig-core/examples/rag_dynamic_tools.rs b/rig-core/examples/rag_dynamic_tools.rs index 347a7bba..cdf6b65e 100644 --- a/rig-core/examples/rag_dynamic_tools.rs +++ b/rig-core/examples/rag_dynamic_tools.rs @@ -160,21 +160,20 @@ async fn main() -> Result<(), anyhow::Error> { .build() .await?; - let vector_store = InMemoryVectorStore::default().add_documents( - embeddings - .into_iter() - .map( - |DocumentEmbeddings { - id, - document, - embeddings, - }| { (id, document, embeddings) }, - ) - .collect(), - )?; - - // Create vector store index - let index = vector_store.index(embedding_model); + let index = InMemoryVectorStore::default() + .add_documents( + embeddings + .into_iter() + .map( + |DocumentEmbeddings { + id, + document, + embeddings, + }| { (id, document, embeddings) }, + ) + .collect(), + )? + .index(embedding_model); // Create RAG agent with a single context prompt and a dynamic tool source let calculator_rag = openai_client diff --git a/rig-core/examples/vector_search.rs b/rig-core/examples/vector_search.rs index 5cef0a37..39b4b763 100644 --- a/rig-core/examples/vector_search.rs +++ b/rig-core/examples/vector_search.rs @@ -3,10 +3,7 @@ use std::env; use rig::{ embeddings::{DocumentEmbeddings, EmbeddingsBuilder}, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, - vector_store::{ - in_memory_store::{InMemoryVectorIndex, InMemoryVectorStore}, - VectorStoreIndex, - }, + vector_store::{in_memory_store::InMemoryVectorStore, VectorStoreIndex}, }; #[tokio::main] @@ -24,20 +21,20 @@ async fn main() -> Result<(), anyhow::Error> { .build() .await?; - let vector_store = InMemoryVectorStore::default().add_documents( - embeddings - .into_iter() - .map( - |DocumentEmbeddings { - id, - document, - embeddings, - }| { (id, document, embeddings) }, - ) - .collect(), - )?; - - let index = vector_store.index(model); + let index = InMemoryVectorStore::default() + .add_documents( + embeddings + .into_iter() + .map( + |DocumentEmbeddings { + id, + document, + embeddings, + }| { (id, document, embeddings) }, + ) + .collect(), + )? + .index(model); let results = index .top_n::<DocumentEmbeddings>("What is a linglingdong?", 1) diff --git a/rig-core/examples/vector_search_cohere.rs b/rig-core/examples/vector_search_cohere.rs index 13fd19ae..f463cc69 100644 --- a/rig-core/examples/vector_search_cohere.rs +++ b/rig-core/examples/vector_search_cohere.rs @@ -22,20 +22,20 @@ async fn main() -> Result<(), anyhow::Error> { .build() .await?; - let vector_store = InMemoryVectorStore::default().add_documents( - embeddings - .into_iter() - .map( - |DocumentEmbeddings { - id, - document, - embeddings, - }| { (id, document, embeddings) }, - ) - .collect(), - )?; - - let index = vector_store.index(search_model); + let index = InMemoryVectorStore::default() + .add_documents( + embeddings + .into_iter() + .map( + |DocumentEmbeddings { + id, + document, + embeddings, + }| { (id, document, embeddings) }, + ) + .collect(), + )? 
+ .index(search_model); let results = index .top_n::<DocumentEmbeddings>("What is a linglingdong?", 1) From a15d493e895dc0b74451a5f3708e65164fdb386d Mon Sep 17 00:00:00 2001 From: Garance <garance.mary@gmail.com> Date: Wed, 9 Oct 2024 12:01:27 -0400 Subject: [PATCH 12/91] docs: add and update docstrings --- rig-core/src/vector_store/in_memory_store.rs | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/rig-core/src/vector_store/in_memory_store.rs b/rig-core/src/vector_store/in_memory_store.rs index 61d22117..ba2e754b 100644 --- a/rig-core/src/vector_store/in_memory_store.rs +++ b/rig-core/src/vector_store/in_memory_store.rs @@ -14,7 +14,9 @@ use crate::embeddings::{Embedding, EmbeddingModel}; /// in-memory using a HashMap. #[derive(Clone, Default)] pub struct InMemoryVectorStore<D: Serialize> { - /// The embeddings are stored in a HashMap with the document ID as the key. + /// The embeddings are stored in a HashMap. + /// Hashmap key is the document id. + /// Hashmap value is a tuple of the serializable document and its corresponding embeddings. embeddings: HashMap<String, (D, Vec<Embedding>)>, } @@ -50,7 +52,7 @@ impl<D: Serialize + Eq> InMemoryVectorStore<D> { tracing::info!(target: "rig", "Selected documents: {}", docs.iter() - .map(|Reverse(RankingItem(distance, id, _, _))| format!("{} ({})", id, distance)) + .map(|Reverse(RankingItem(distance, id, _, embed_doc))| format!("{} ({}). Specific match: {}", id, distance, embed_doc)) .collect::<Vec<String>>() .join(", ") ); @@ -58,6 +60,8 @@ impl<D: Serialize + Eq> InMemoryVectorStore<D> { docs } + /// Add documents to the store. + /// Returns the store with the added documents. pub fn add_documents( mut self, documents: Vec<(String, D, Vec<Embedding>)>, @@ -69,7 +73,7 @@ impl<D: Serialize + Eq> InMemoryVectorStore<D> { Ok(self) } - /// Get the document by its id and deserialize it into the given type + /// Get the document by its id and deserialize it into the given type. pub fn get_document<T: for<'a> Deserialize<'a>>( &self, id: &str, @@ -81,12 +85,13 @@ impl<D: Serialize + Eq> InMemoryVectorStore<D> { .transpose()?) } + /// Get the document embeddings by its id. pub fn get_document_embeddings(&self, id: &str) -> Result<Option<&D>, VectorStoreError> { Ok(self.embeddings.get(id).map(|(doc, _)| doc)) } } -/// RankingItem(distance, document_id, document, embed_doc) +/// RankingItem(distance, document_id, serializable document, embeddings document) #[derive(Eq, PartialEq)] struct RankingItem<'a, D: Serialize>(OrderedFloat<f64>, &'a String, &'a D, &'a String); From 46b468070448aaff4201fa2ec7b90d1b83dbb9e7 Mon Sep 17 00:00:00 2001 From: Garance <garance.mary@gmail.com> Date: Wed, 9 Oct 2024 12:16:01 -0400 Subject: [PATCH 13/91] fix (examples): fix bugs in examples --- rig-core/examples/vector_search.rs | 4 ++-- rig-core/examples/vector_search_cohere.rs | 4 ++-- rig-mongodb/examples/vector_search_mongodb.rs | 4 +--- 3 files changed, 5 insertions(+), 7 deletions(-) diff --git a/rig-core/examples/vector_search.rs b/rig-core/examples/vector_search.rs index 39b4b763..45110606 100644 --- a/rig-core/examples/vector_search.rs +++ b/rig-core/examples/vector_search.rs @@ -37,10 +37,10 @@ async fn main() -> Result<(), anyhow::Error> { .index(model); let results = index - .top_n::<DocumentEmbeddings>("What is a linglingdong?", 1) + .top_n::<String>("What is a linglingdong?", 1) .await? 
.into_iter() - .map(|(score, id, doc)| (score, id, doc.document)) + .map(|(score, id, doc)| (score, id, doc)) .collect::<Vec<_>>(); println!("Results: {:?}", results); diff --git a/rig-core/examples/vector_search_cohere.rs b/rig-core/examples/vector_search_cohere.rs index f463cc69..1e0180d3 100644 --- a/rig-core/examples/vector_search_cohere.rs +++ b/rig-core/examples/vector_search_cohere.rs @@ -38,10 +38,10 @@ async fn main() -> Result<(), anyhow::Error> { .index(search_model); let results = index - .top_n::<DocumentEmbeddings>("What is a linglingdong?", 1) + .top_n::<String>("What is a linglingdong?", 1) .await? .into_iter() - .map(|(score, id, doc)| (score, id, doc.document)) + .map(|(score, id, doc)| (score, id, doc)) .collect::<Vec<_>>(); println!("Results: {:?}", results); diff --git a/rig-mongodb/examples/vector_search_mongodb.rs b/rig-mongodb/examples/vector_search_mongodb.rs index 87f1595f..00668091 100644 --- a/rig-mongodb/examples/vector_search_mongodb.rs +++ b/rig-mongodb/examples/vector_search_mongodb.rs @@ -44,11 +44,9 @@ async fn main() -> Result<(), anyhow::Error> { Err(e) => println!("Error adding documents: {:?}", e), } - let vector_store = MongoDbVectorStore::new(collection); - // Create a vector index on our vector store // IMPORTANT: Reuse the same model that was used to generate the embeddings - let index = vector_store.index(model, "vector_index", SearchParams::default()); + let index = MongoDbVectorStore::new(collection).index(model, "vector_index", SearchParams::default()); // Query the index let results = index From fe75da1b0b7b69cc05074457505978a6b9d8bbb0 Mon Sep 17 00:00:00 2001 From: Garance <garance.mary@gmail.com> Date: Wed, 9 Oct 2024 12:17:43 -0400 Subject: [PATCH 14/91] style: cargo fmt --- rig-mongodb/examples/vector_search_mongodb.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/rig-mongodb/examples/vector_search_mongodb.rs b/rig-mongodb/examples/vector_search_mongodb.rs index 00668091..0d31aaa2 100644 --- a/rig-mongodb/examples/vector_search_mongodb.rs +++ b/rig-mongodb/examples/vector_search_mongodb.rs @@ -46,7 +46,8 @@ async fn main() -> Result<(), anyhow::Error> { // Create a vector index on our vector store // IMPORTANT: Reuse the same model that was used to generate the embeddings - let index = MongoDbVectorStore::new(collection).index(model, "vector_index", SearchParams::default()); + let index = + MongoDbVectorStore::new(collection).index(model, "vector_index", SearchParams::default()); // Query the index let results = index From 6c7ab8d6f344a773add6aa6156091e9033e8725c Mon Sep 17 00:00:00 2001 From: Garance <garance.mary@gmail.com> Date: Wed, 9 Oct 2024 13:39:13 -0400 Subject: [PATCH 15/91] revert: revert vector store to main --- rig-core/examples/rag.rs | 48 +---- rig-core/examples/vector_search.rs | 59 +----- rig-core/examples/vector_search_cohere.rs | 57 +----- rig-core/src/vector_store/in_memory_store.rs | 179 ++++++++++++++----- rig-core/src/vector_store/mod.rs | 18 +- 5 files changed, 172 insertions(+), 189 deletions(-) diff --git a/rig-core/examples/rag.rs b/rig-core/examples/rag.rs index 55a99e63..3abd8ee9 100644 --- a/rig-core/examples/rag.rs +++ b/rig-core/examples/rag.rs @@ -6,15 +6,6 @@ use rig::{ providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, vector_store::{in_memory_store::InMemoryVectorStore, VectorStore}, }; -use rig_derive::Embed; -use serde::Serialize; - -#[derive(Embed, Clone, Serialize, Eq, PartialEq, Default)] -struct FakeDefinition { - id: String, - #[embed] - definitions: Vec<String>, -} 
#[tokio::main] async fn main() -> Result<(), anyhow::Error> { @@ -27,45 +18,14 @@ async fn main() -> Result<(), anyhow::Error> { // Create vector store, compute embeddings and load them in the store let mut vector_store = InMemoryVectorStore::default(); - let fake_definitions = vec![ - FakeDefinition { - id: "doc0".to_string(), - definitions: vec![ - "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets".to_string(), - "Definition of a *flurbo*: A unit of currency used in a bizarre or fantastical world, often associated with eccentric societies or sci-fi settings.".to_string() - ] - }, - FakeDefinition { - id: "doc1".to_string(), - definitions: vec![ - "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string(), - "Definition of a *glarb-glarb*: A mysterious, bubbling substance often found in swamps, alien planets, or under mysterious circumstances.".to_string() - ] - }, - FakeDefinition { - id: "doc2".to_string(), - definitions: vec![ - "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.".to_string() - ] - } - ]; - let embeddings = EmbeddingsBuilder::new(embedding_model.clone()) - .documents(fake_definitions) + .simple_document("doc0", "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets") + .simple_document("doc1", "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.") + .simple_document("doc2", "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.") .build() .await?; - vector_store - .add_documents( - embeddings - .into_iter() - .enumerate() - .map(|(i, (fake_definition, embeddings))| { - (format!("doc{i}"), fake_definition, embeddings) - }) - .collect(), - ) - .await?; + vector_store.add_documents(embeddings).await?; // Create vector store index let index = vector_store.index(embedding_model); diff --git a/rig-core/examples/vector_search.rs b/rig-core/examples/vector_search.rs index 095ca296..e0d68bef 100644 --- a/rig-core/examples/vector_search.rs +++ b/rig-core/examples/vector_search.rs @@ -1,19 +1,10 @@ use std::env; use rig::{ - embeddings::EmbeddingsBuilder, + embeddings::{DocumentEmbeddings, EmbeddingsBuilder}, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, - vector_store::{in_memory_store::InMemoryVectorStore, VectorStore, VectorStoreIndex}, + vector_store::{in_memory_store::InMemoryVectorIndex, VectorStoreIndex}, }; -use rig_derive::Embed; -use serde::{Deserialize, Serialize}; - -#[derive(Embed, Clone, Serialize, Default, Eq, PartialEq, Deserialize, Debug)] -struct FakeDefinition { - id: String, - #[embed] - definitions: Vec<String>, -} #[tokio::main] async fn main() -> Result<(), anyhow::Error> { @@ -23,54 +14,20 @@ async fn main() -> Result<(), anyhow::Error> { let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); - let fake_definitions = vec![ - FakeDefinition { - id: "doc0".to_string(), - definitions: vec![ - "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets".to_string(), - "Definition of a *flurbo*: A unit of currency used in a bizarre or fantastical world, often associated with eccentric societies or sci-fi settings.".to_string() - ] - }, - FakeDefinition { - id: "doc1".to_string(), - definitions: vec![ - "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors 
of the inhabitants of planet Jiro to farm the land.".to_string(), - "Definition of a *glarb-glarb*: A mysterious, bubbling substance often found in swamps, alien planets, or under mysterious circumstances.".to_string() - ] - }, - FakeDefinition { - id: "doc2".to_string(), - definitions: vec![ - "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.".to_string() - ] - } - ]; - let embeddings = EmbeddingsBuilder::new(model.clone()) - .documents(fake_definitions) + .simple_document("doc0", "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets") + .simple_document("doc1", "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.") + .simple_document("doc2", "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.") .build() .await?; - let mut store = InMemoryVectorStore::default(); - store - .add_documents( - embeddings - .into_iter() - .enumerate() - .map(|(i, (fake_definition, embeddings))| { - (format!("doc{i}"), fake_definition, embeddings) - }) - .collect(), - ) - .await?; - - let index = store.index(model); + let index = InMemoryVectorIndex::from_embeddings(model, embeddings).await?; let results = index - .top_n::<FakeDefinition>("What is a linglingdong?", 1) + .top_n::<DocumentEmbeddings>("What is a linglingdong?", 1) .await? .into_iter() - .map(|(score, id, doc)| (score, id, doc)) + .map(|(score, id, doc)| (score, id, doc.document)) .collect::<Vec<_>>(); println!("Results: {:?}", results); diff --git a/rig-core/examples/vector_search_cohere.rs b/rig-core/examples/vector_search_cohere.rs index 6be094b4..a49ac231 100644 --- a/rig-core/examples/vector_search_cohere.rs +++ b/rig-core/examples/vector_search_cohere.rs @@ -1,19 +1,10 @@ use std::env; use rig::{ - embeddings::EmbeddingsBuilder, + embeddings::{DocumentEmbeddings, EmbeddingsBuilder}, providers::cohere::{Client, EMBED_ENGLISH_V3}, vector_store::{in_memory_store::InMemoryVectorStore, VectorStore, VectorStoreIndex}, }; -use rig_derive::Embed; -use serde::{Deserialize, Serialize}; - -#[derive(Embed, Clone, Serialize, Default, Eq, PartialEq, Deserialize, Debug)] -struct FakeDefinition { - id: String, - #[embed] - definitions: Vec<String>, -} #[tokio::main] async fn main() -> Result<(), anyhow::Error> { @@ -24,54 +15,24 @@ async fn main() -> Result<(), anyhow::Error> { let document_model = cohere_client.embedding_model(EMBED_ENGLISH_V3, "search_document"); let search_model = cohere_client.embedding_model(EMBED_ENGLISH_V3, "search_query"); - let fake_definitions = vec![ - FakeDefinition { - id: "doc0".to_string(), - definitions: vec![ - "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets".to_string(), - "Definition of a *flurbo*: A unit of currency used in a bizarre or fantastical world, often associated with eccentric societies or sci-fi settings.".to_string() - ] - }, - FakeDefinition { - id: "doc1".to_string(), - definitions: vec![ - "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string(), - "Definition of a *glarb-glarb*: A mysterious, bubbling substance often found in swamps, alien planets, or under mysterious circumstances.".to_string() - ] - }, - FakeDefinition { - id: "doc2".to_string(), - definitions: vec![ - "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to 
describe humans.".to_string() - ] - } - ]; + let mut vector_store = InMemoryVectorStore::default(); let embeddings = EmbeddingsBuilder::new(document_model) - .documents(fake_definitions) + .simple_document("doc0", "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets") + .simple_document("doc1", "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.") + .simple_document("doc2", "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.") .build() .await?; - let mut store = InMemoryVectorStore::default(); - store - .add_documents( - embeddings - .into_iter() - .enumerate() - .map(|(i, (fake_definition, embeddings))| { - (format!("doc{i}"), fake_definition, embeddings) - }) - .collect(), - ) - .await?; + vector_store.add_documents(embeddings).await?; - let index = store.index(search_model); + let index = vector_store.index(search_model); let results = index - .top_n::<FakeDefinition>("What is a linglingdong?", 1) + .top_n::<DocumentEmbeddings>("What is a linglingdong?", 1) .await? .into_iter() - .map(|(score, id, doc)| (score, id, doc)) + .map(|(score, id, doc)| (score, id, doc.document)) .collect::<Vec<_>>(); println!("Results: {:?}", results); diff --git a/rig-core/src/vector_store/in_memory_store.rs b/rig-core/src/vector_store/in_memory_store.rs index 5ed98671..a5db505f 100644 --- a/rig-core/src/vector_store/in_memory_store.rs +++ b/rig-core/src/vector_store/in_memory_store.rs @@ -8,26 +8,27 @@ use ordered_float::OrderedFloat; use serde::{Deserialize, Serialize}; use super::{VectorStore, VectorStoreError, VectorStoreIndex}; -use crate::embeddings::{Embedding, EmbeddingModel}; +use crate::embeddings::{DocumentEmbeddings, Embedding, EmbeddingModel, EmbeddingsBuilder}; /// InMemoryVectorStore is a simple in-memory vector store that stores embeddings /// in-memory using a HashMap. -#[derive(Clone, Default)] -pub struct InMemoryVectorStore<D: Serialize> { +#[derive(Clone, Default, Deserialize, Serialize)] +pub struct InMemoryVectorStore { /// The embeddings are stored in a HashMap with the document ID as the key. - embeddings: HashMap<String, (D, Vec<Embedding>)>, + embeddings: HashMap<String, DocumentEmbeddings>, } -impl<D: Serialize + Eq> InMemoryVectorStore<D> { +impl InMemoryVectorStore { /// Implement vector search on InMemoryVectorStore. /// To be used by implementations of top_n and top_n_ids methods on VectorStoreIndex trait for InMemoryVectorStore. - fn vector_search(&self, prompt_embedding: &Embedding, n: usize) -> EmbeddingRanking<D> { + fn vector_search(&self, prompt_embedding: &Embedding, n: usize) -> EmbeddingRanking { // Sort documents by best embedding distance - let mut docs = BinaryHeap::new(); + let mut docs: EmbeddingRanking = BinaryHeap::new(); - for (id, (doc, embeddings)) in self.embeddings.iter() { + for (id, doc_embeddings) in self.embeddings.iter() { // Get the best context for the document given the prompt - if let Some((distance, embed_doc)) = embeddings + if let Some((distance, embed_doc)) = doc_embeddings + .embeddings .iter() .map(|embedding| { ( @@ -37,7 +38,12 @@ impl<D: Serialize + Eq> InMemoryVectorStore<D> { }) .min_by(|a, b| a.0.cmp(&b.0)) { - docs.push(Reverse(RankingItem(distance, id, doc, embed_doc))); + docs.push(Reverse(RankingItem( + distance, + id, + doc_embeddings, + embed_doc, + ))); }; // If the heap size exceeds n, pop the least old element. 
@@ -61,51 +67,73 @@ impl<D: Serialize + Eq> InMemoryVectorStore<D> { /// RankingItem(distance, document_id, document, embed_doc) #[derive(Eq, PartialEq)] -struct RankingItem<'a, D: Serialize>(OrderedFloat<f64>, &'a String, &'a D, &'a String); - -impl<D: Serialize + Eq> Ord for RankingItem<'_, D> { +struct RankingItem<'a>( + OrderedFloat<f64>, + &'a String, + &'a DocumentEmbeddings, + &'a String, +); + +impl Ord for RankingItem<'_> { fn cmp(&self, other: &Self) -> std::cmp::Ordering { self.0.cmp(&other.0) } } -impl<D: Serialize + Eq> PartialOrd for RankingItem<'_, D> { +impl PartialOrd for RankingItem<'_> { fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> { Some(self.cmp(other)) } } -type EmbeddingRanking<'a, D> = BinaryHeap<Reverse<RankingItem<'a, D>>>; +type EmbeddingRanking<'a> = BinaryHeap<Reverse<RankingItem<'a>>>; -impl<D: Serialize + Send + Sync + Clone> VectorStore<D> for InMemoryVectorStore<D> { +impl VectorStore for InMemoryVectorStore { type Q = (); async fn add_documents( &mut self, - documents: Vec<(String, D, Vec<Embedding>)>, + documents: Vec<DocumentEmbeddings>, ) -> Result<(), VectorStoreError> { - for (id, doc, embeddings) in documents { - self.embeddings.insert(id, (doc, embeddings)); + for doc in documents { + self.embeddings.insert(doc.id.clone(), doc); } Ok(()) } - async fn get_document_embeddings(&self, id: &str) -> Result<Option<D>, VectorStoreError> { - Ok(self.embeddings.get(id).cloned().map(|(doc, _)| doc)) + async fn get_document<T: for<'a> Deserialize<'a>>( + &self, + id: &str, + ) -> Result<Option<T>, VectorStoreError> { + Ok(self + .embeddings + .get(id) + .map(|document| serde_json::from_value(document.document.clone())) + .transpose()?) + } + + async fn get_document_embeddings( + &self, + id: &str, + ) -> Result<Option<DocumentEmbeddings>, VectorStoreError> { + Ok(self.embeddings.get(id).cloned()) } - async fn get_document_by_query(&self, _query: Self::Q) -> Result<Option<D>, VectorStoreError> { + async fn get_document_by_query( + &self, + _query: Self::Q, + ) -> Result<Option<DocumentEmbeddings>, VectorStoreError> { Ok(None) } } -impl<D: Serialize> InMemoryVectorStore<D> { - pub fn index<M: EmbeddingModel>(self, model: M) -> InMemoryVectorIndex<M, D> { +impl InMemoryVectorStore { + pub fn index<M: EmbeddingModel>(self, model: M) -> InMemoryVectorIndex<M> { InMemoryVectorIndex::new(model, self) } - pub fn iter(&self) -> impl Iterator<Item = (&String, &(D, Vec<Embedding>))> { + pub fn iter(&self) -> impl Iterator<Item = (&String, &DocumentEmbeddings)> { self.embeddings.iter() } @@ -116,19 +144,54 @@ impl<D: Serialize> InMemoryVectorStore<D> { pub fn is_empty(&self) -> bool { self.embeddings.is_empty() } + + /// Uitilty method to create an InMemoryVectorStore from a list of embeddings. + pub async fn from_embeddings( + embeddings: Vec<DocumentEmbeddings>, + ) -> Result<Self, VectorStoreError> { + let mut store = Self::default(); + store.add_documents(embeddings).await?; + Ok(store) + } + + /// Create an InMemoryVectorStore from a list of documents. + /// The documents are serialized to JSON and embedded using the provided embedding model. + /// The resulting embeddings are stored in an InMemoryVectorStore created by the method. 
+ pub async fn from_documents<M: EmbeddingModel, T: Serialize>( + embedding_model: M, + documents: &[(String, T)], + ) -> Result<Self, VectorStoreError> { + let embeddings = documents + .iter() + .fold( + EmbeddingsBuilder::new(embedding_model), + |builder, (id, doc)| { + builder.json_document( + id, + serde_json::to_value(doc).expect("Document should be serializable"), + vec![serde_json::to_string(doc).expect("Document should be serializable")], + ) + }, + ) + .build() + .await?; + + let store = Self::from_embeddings(embeddings).await?; + Ok(store) + } } -pub struct InMemoryVectorIndex<M: EmbeddingModel, D: Serialize> { +pub struct InMemoryVectorIndex<M: EmbeddingModel> { model: M, - pub store: InMemoryVectorStore<D>, + pub store: InMemoryVectorStore, } -impl<M: EmbeddingModel, D: Serialize> InMemoryVectorIndex<M, D> { - pub fn new(model: M, store: InMemoryVectorStore<D>) -> Self { +impl<M: EmbeddingModel> InMemoryVectorIndex<M> { + pub fn new(model: M, store: InMemoryVectorStore) -> Self { Self { model, store } } - pub fn iter(&self) -> impl Iterator<Item = (&String, &(D, Vec<Embedding>))> { + pub fn iter(&self) -> impl Iterator<Item = (&String, &DocumentEmbeddings)> { self.store.iter() } @@ -139,11 +202,49 @@ impl<M: EmbeddingModel, D: Serialize> InMemoryVectorIndex<M, D> { pub fn is_empty(&self) -> bool { self.store.is_empty() } + + /// Create an InMemoryVectorIndex from a list of documents. + /// The documents are serialized to JSON and embedded using the provided embedding model. + /// The resulting embeddings are stored in an InMemoryVectorStore created by the method. + /// The InMemoryVectorIndex is then created from the store and the provided query model. + pub async fn from_documents<T: Serialize>( + embedding_model: M, + query_model: M, + documents: &[(String, T)], + ) -> Result<Self, VectorStoreError> { + let mut store = InMemoryVectorStore::default(); + + let embeddings = documents + .iter() + .fold( + EmbeddingsBuilder::new(embedding_model), + |builder, (id, doc)| { + builder.json_document( + id, + serde_json::to_value(doc).expect("Document should be serializable"), + vec![serde_json::to_string(doc).expect("Document should be serializable")], + ) + }, + ) + .build() + .await?; + + store.add_documents(embeddings).await?; + Ok(store.index(query_model)) + } + + /// Utility method to create an InMemoryVectorIndex from a list of embeddings + /// and an embedding model. 
+ pub async fn from_embeddings( + query_model: M, + embeddings: Vec<DocumentEmbeddings>, + ) -> Result<Self, VectorStoreError> { + let store = InMemoryVectorStore::from_embeddings(embeddings).await?; + Ok(store.index(query_model)) + } } -impl<M: EmbeddingModel + std::marker::Sync, D: Serialize + Sync + Send + Eq> VectorStoreIndex - for InMemoryVectorIndex<M, D> -{ +impl<M: EmbeddingModel + std::marker::Sync> VectorStoreIndex for InMemoryVectorIndex<M> { async fn top_n<T: for<'a> Deserialize<'a>>( &self, query: &str, @@ -155,14 +256,12 @@ impl<M: EmbeddingModel + std::marker::Sync, D: Serialize + Sync + Send + Eq> Vec // Return n best docs.into_iter() - .map(|Reverse(RankingItem(distance, id, doc, _))| { + .map(|Reverse(RankingItem(distance, _, doc, _))| { + let doc_value = serde_json::to_value(doc).map_err(VectorStoreError::JsonError)?; Ok(( distance.0, - id.clone(), - serde_json::from_str( - &serde_json::to_string(doc).map_err(VectorStoreError::JsonError)?, - ) - .map_err(VectorStoreError::JsonError)?, + doc.id.clone(), + serde_json::from_value(doc_value).map_err(VectorStoreError::JsonError)?, )) }) .collect::<Result<Vec<_>, _>>() @@ -179,7 +278,7 @@ impl<M: EmbeddingModel + std::marker::Sync, D: Serialize + Sync + Send + Eq> Vec // Return n best docs.into_iter() - .map(|Reverse(RankingItem(distance, id, _, _))| Ok((distance.0, id.clone()))) + .map(|Reverse(RankingItem(distance, _, doc, _))| Ok((distance.0, doc.id.clone()))) .collect::<Result<Vec<_>, _>>() } } diff --git a/rig-core/src/vector_store/mod.rs b/rig-core/src/vector_store/mod.rs index 8f89d5f1..b07d348a 100644 --- a/rig-core/src/vector_store/mod.rs +++ b/rig-core/src/vector_store/mod.rs @@ -1,8 +1,8 @@ use futures::future::BoxFuture; -use serde::{Deserialize, Serialize}; +use serde::Deserialize; use serde_json::Value; -use crate::embeddings::{Embedding, EmbeddingError}; +use crate::embeddings::{DocumentEmbeddings, EmbeddingError}; pub mod in_memory_store; @@ -20,27 +20,33 @@ pub enum VectorStoreError { } /// Trait for vector stores -pub trait VectorStore<D: Serialize>: Send + Sync { +pub trait VectorStore: Send + Sync { /// Query type for the vector store type Q; /// Add a list of documents to the vector store fn add_documents( &mut self, - documents: Vec<(String, D, Vec<Embedding>)>, + documents: Vec<DocumentEmbeddings>, ) -> impl std::future::Future<Output = Result<(), VectorStoreError>> + Send; /// Get the embeddings of a document by its id fn get_document_embeddings( &self, id: &str, - ) -> impl std::future::Future<Output = Result<Option<D>, VectorStoreError>> + Send; + ) -> impl std::future::Future<Output = Result<Option<DocumentEmbeddings>, VectorStoreError>> + Send; + + /// Get the document by its id and deserialize it into the given type + fn get_document<T: for<'a> Deserialize<'a>>( + &self, + id: &str, + ) -> impl std::future::Future<Output = Result<Option<T>, VectorStoreError>> + Send; /// Get the document by a query and deserialize it into the given type fn get_document_by_query( &self, query: Self::Q, - ) -> impl std::future::Future<Output = Result<Option<D>, VectorStoreError>> + Send; + ) -> impl std::future::Future<Output = Result<Option<DocumentEmbeddings>, VectorStoreError>> + Send; } /// Trait for vector store indexes From bb712e3fdabfffbf4a4e6c35fb2eec951ff5c83f Mon Sep 17 00:00:00 2001 From: Garance <garance.mary@gmail.com> Date: Wed, 9 Oct 2024 15:49:00 -0400 Subject: [PATCH 16/91] docs: update emebddings builder docstrings --- rig-core/src/embeddings.rs | 166 ++++++++++++++++++++++++++++++------- 1 file 
changed, 136 insertions(+), 30 deletions(-) diff --git a/rig-core/src/embeddings.rs b/rig-core/src/embeddings.rs index 458c6eb3..d68ae565 100644 --- a/rig-core/src/embeddings.rs +++ b/rig-core/src/embeddings.rs @@ -8,31 +8,66 @@ //! struct, which allows users to build collections of document embeddings using different embedding //! models and document sources. //! -//! The module also defines the [Embedding] struct, which represents a single document embedding, -//! and the [DocumentEmbeddings] struct, which represents a document along with its associated -//! embeddings. These structs are used to store and manipulate collections of document embeddings. +//! The module also defines the [Embedding] struct, which represents a single document embedding. +//! +//! The module also defines the [Embeddable] trait, which represents types that can be embedded. +//! Only types that implement the Embeddable trait can be used with the EmbeddingsBuilder. //! //! Finally, the module defines the [EmbeddingError] enum, which represents various errors that //! can occur during embedding generation or processing. //! //! # Example //! ```rust -//! use rig::providers::openai::{Client, self}; -//! use rig::embeddings::{EmbeddingModel, EmbeddingsBuilder}; +//! use std::env; +//! +//! use rig::{ +//! embeddings::EmbeddingsBuilder, +//! providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, +//! }; +//! use rig_derive::Embed; //! -//! // Initialize the OpenAI client -//! let openai = Client::new("your-openai-api-key"); +//! #[derive(Embed)] +//! struct FakeDefinition { +//! id: String, +//! word: String, +//! #[embed] +//! definitions: Vec<String>, +//! } +//! // Create OpenAI client +//! let openai_api_key = env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set"); +//! let openai_client = Client::new(&openai_api_key); //! -//! // Create an instance of the `text-embedding-ada-002` model -//! let embedding_model = openai.embedding_model(openai::TEXT_EMBEDDING_ADA_002); +//! let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); //! -//! // Create an embeddings builder and add documents -//! let embeddings = EmbeddingsBuilder::new(embedding_model) -//! .simple_document("doc1", "This is the first document.") -//! .simple_document("doc2", "This is the second document.") -//! .build() -//! .await -//! .expect("Failed to build embeddings."); +//! let embeddings = EmbeddingsBuilder::new(model.clone()) +//! .documents(vec![ +//! FakeDefinition { +//! id: "doc0".to_string(), +//! word: "flurbo".to_string(), +//! definitions: vec![ +//! "A green alien that lives on cold planets.".to_string(), +//! "A fictional digital currency that originated in the animated series Rick and Morty.".to_string() +//! ] +//! }, +//! FakeDefinition { +//! id: "doc1".to_string(), +//! word: "glarb-glarb".to_string(), +//! definitions: vec![ +//! "An ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string(), +//! "A fictional creature found in the distant, swampy marshlands of the planet Glibbo in the Andromeda galaxy.".to_string() +//! ] +//! }, +//! FakeDefinition { +//! id: "doc2".to_string(), +//! word: "linglingdong".to_string(), +//! definitions: vec![ +//! "A term used by inhabitants of the sombrero galaxy to describe humans.".to_string(), +//! "A rare, mystical instrument crafted by the ancient monks of the Nebulon Mountain Ranges on the planet Quarm.".to_string() +//! ] +//! }, +//! ]) +//! .build() +//! .await?; //! //! // Use the generated embeddings //! // ... 
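For reference (not part of the patch): the `#[derive(Embed)]` shown in the docstring above is expected to expand to roughly the following impl for `FakeDefinition`, since `definitions` is the only `#[embed]`-tagged field and its type is `Vec<String>`:

impl Embeddable for FakeDefinition {
    type Kind = ManyEmbedding;

    fn embeddable(&self) -> Vec<String> {
        // One embedding target per entry in the tagged `definitions` field.
        self.definitions.embeddable()
    }
}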
@@ -102,7 +137,7 @@ pub trait EmbeddingModel: Clone + Sync + Send { /// Struct that holds a single document and its embedding. #[derive(Clone, Default, Deserialize, Serialize, Debug)] pub struct Embedding { - /// The document that was embedded + /// The document that was embedded. Used for debugging. pub document: String, /// The embedding vector pub vec: Vec<f64>, @@ -159,6 +194,7 @@ impl<M: EmbeddingModel, D: Embeddable<Kind = K>, K: EmbeddingKind> EmbeddingsBui } } + /// Add a document that implements `Embeddable` to the builder. pub fn document(mut self, document: D) -> Self { let embed_targets = document.embeddable(); @@ -166,6 +202,7 @@ impl<M: EmbeddingModel, D: Embeddable<Kind = K>, K: EmbeddingKind> EmbeddingsBui self } + /// Add many documents that implement `Embeddable` to the builder. pub fn documents(mut self, documents: Vec<D>) -> EmbeddingsBuilder<M, D, D::Kind> { documents.into_iter().for_each(|doc| { let embed_targets = doc.embeddable(); @@ -180,7 +217,11 @@ impl<M: EmbeddingModel, D: Embeddable<Kind = K>, K: EmbeddingKind> EmbeddingsBui impl<M: EmbeddingModel, D: Embeddable + Send + Sync + Clone> EmbeddingsBuilder<M, D, ManyEmbedding> { + /// Generate embeddings for all documents in the builder. + /// The method only applies when documents in the builder each contain multiple embedding targets. + /// Returns a vector of tuples, where the first element is the document and the second element is the vector of embeddings. pub async fn build(&self) -> Result<Vec<(D, Vec<Embedding>)>, EmbeddingError> { + // Use this for reference later to merge a document back with its embeddings. let documents_map = self .documents .clone() @@ -189,22 +230,19 @@ impl<M: EmbeddingModel, D: Embeddable + Send + Sync + Clone> .map(|(id, (document, _))| (id, document)) .collect::<HashMap<_, _>>(); - let embeddings = stream::iter(self.documents.clone().into_iter().enumerate()) - // Flatten the documents + let embeddings = stream::iter(self.documents.iter().enumerate()) + // Merge the embedding targets of each document into a single list of embedding targets. .flat_map(|(i, (_, embed_targets))| { - stream::iter( - embed_targets - .into_iter() - .map(move |target| (i, target.clone())), - ) + stream::iter(embed_targets.iter().map(move |target| (i, target.clone()))) }) - // Chunk them into N (the emebdding API limit per request) + // Chunk them into N (the emebdding API limit per request). .chunks(M::MAX_DOCUMENTS) - // Generate the embeddings + // Generate the embeddings for a chunk at a time. .map(|docs| async { - let (documents, embed_targets): (Vec<_>, Vec<_>) = docs.into_iter().unzip(); + let (document_indices, embed_targets): (Vec<_>, Vec<_>) = docs.into_iter().unzip(); + Ok::<_, EmbeddingError>( - documents + document_indices .into_iter() .zip(self.model.embed_documents(embed_targets).await?.into_iter()) .collect::<Vec<_>>(), @@ -213,8 +251,6 @@ impl<M: EmbeddingModel, D: Embeddable + Send + Sync + Clone> .boxed() // Parallelize the embeddings generation over 10 concurrent requests .buffer_unordered(max(1, 1024 / M::MAX_DOCUMENTS)) - // .try_collect::<Vec<_>>() - // .await; .try_fold( HashMap::new(), |mut acc: HashMap<_, Vec<_>>, embeddings| async move { @@ -242,6 +278,9 @@ impl<M: EmbeddingModel, D: Embeddable + Send + Sync + Clone> impl<M: EmbeddingModel, D: Embeddable + Send + Sync + Clone> EmbeddingsBuilder<M, D, SingleEmbedding> { + /// Generate embeddings for all documents in the builder. + /// The method only applies when documents in the builder each contain a single embedding target. 
+ /// Returns a vector of tuples, where the first element is the document and the second element is the embedding. pub async fn build(&self) -> Result<Vec<(D, Embedding)>, EmbeddingError> { let embeddings = stream::iter( self.documents @@ -274,6 +313,9 @@ impl<M: EmbeddingModel, D: Embeddable + Send + Sync + Clone> } } +////////////////////////////////////////////////////// +/// Implementations of Embeddable for common types /// +////////////////////////////////////////////////////// impl Embeddable for String { type Kind = SingleEmbedding; @@ -282,6 +324,22 @@ impl Embeddable for String { } } +impl Embeddable for i8 { + type Kind = SingleEmbedding; + + fn embeddable(&self) -> Vec<String> { + vec![self.to_string()] + } +} + +impl Embeddable for i16 { + type Kind = SingleEmbedding; + + fn embeddable(&self) -> Vec<String> { + vec![self.to_string()] + } +} + impl Embeddable for i32 { type Kind = SingleEmbedding; @@ -290,6 +348,54 @@ impl Embeddable for i32 { } } +impl Embeddable for i64 { + type Kind = SingleEmbedding; + + fn embeddable(&self) -> Vec<String> { + vec![self.to_string()] + } +} + +impl Embeddable for i128 { + type Kind = SingleEmbedding; + + fn embeddable(&self) -> Vec<String> { + vec![self.to_string()] + } +} + +impl Embeddable for f32 { + type Kind = SingleEmbedding; + + fn embeddable(&self) -> Vec<String> { + vec![self.to_string()] + } +} + +impl Embeddable for f64 { + type Kind = SingleEmbedding; + + fn embeddable(&self) -> Vec<String> { + vec![self.to_string()] + } +} + +impl Embeddable for bool { + type Kind = SingleEmbedding; + + fn embeddable(&self) -> Vec<String> { + vec![self.to_string()] + } +} + +impl Embeddable for char { + type Kind = SingleEmbedding; + + fn embeddable(&self) -> Vec<String> { + vec![self.to_string()] + } +} + impl<T: Embeddable> Embeddable for Vec<T> { type Kind = ManyEmbedding; From efa2b65427fddcab60734fa023335c38022047bf Mon Sep 17 00:00:00 2001 From: Garance <garance.mary@gmail.com> Date: Thu, 10 Oct 2024 14:41:57 -0400 Subject: [PATCH 17/91] refactor: derive macro --- Cargo.lock | 1 + rig-core/examples/rag.rs | 2 +- rig-core/examples/vector_search.rs | 2 +- rig-core/examples/vector_search_cohere.rs | 7 +- rig-core/rig-core-derive/Cargo.toml | 1 + rig-core/rig-core-derive/src/custom.rs | 98 ++++++++++ rig-core/rig-core-derive/src/embeddable.rs | 183 ++++++------------ rig-core/rig-core-derive/src/lib.rs | 5 + rig-core/src/embeddings.rs | 30 ++- rig-lancedb/examples/fixtures/lib.rs | 2 +- rig-mongodb/examples/vector_search_mongodb.rs | 2 +- 11 files changed, 197 insertions(+), 136 deletions(-) create mode 100644 rig-core/rig-core-derive/src/custom.rs diff --git a/Cargo.lock b/Cargo.lock index 311b2862..c47f71b8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4012,6 +4012,7 @@ name = "rig-derive" version = "0.1.0" dependencies = [ "indoc", + "proc-macro2", "quote", "syn 2.0.79", ] diff --git a/rig-core/examples/rag.rs b/rig-core/examples/rag.rs index d24e293f..82e9d4fd 100644 --- a/rig-core/examples/rag.rs +++ b/rig-core/examples/rag.rs @@ -2,7 +2,7 @@ use std::{env, vec}; use rig::{ completion::Prompt, - embeddings::EmbeddingsBuilder, + embeddings::{Embeddable, EmbeddingsBuilder, ManyEmbedding}, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, vector_store::in_memory_store::InMemoryVectorStore, }; diff --git a/rig-core/examples/vector_search.rs b/rig-core/examples/vector_search.rs index 8e23d274..523719ac 100644 --- a/rig-core/examples/vector_search.rs +++ b/rig-core/examples/vector_search.rs @@ -1,7 +1,7 @@ use std::env; use rig::{ - 
embeddings::EmbeddingsBuilder, + embeddings::{Embeddable, EmbeddingsBuilder, ManyEmbedding}, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, vector_store::{in_memory_store::InMemoryVectorStore, VectorStoreIndex}, }; diff --git a/rig-core/examples/vector_search_cohere.rs b/rig-core/examples/vector_search_cohere.rs index 94ca4e4b..f9f84175 100644 --- a/rig-core/examples/vector_search_cohere.rs +++ b/rig-core/examples/vector_search_cohere.rs @@ -1,7 +1,7 @@ use std::env; use rig::{ - embeddings::EmbeddingsBuilder, + embeddings::{Embeddable, EmbeddingsBuilder, ManyEmbedding}, providers::cohere::{Client, EMBED_ENGLISH_V3}, vector_store::{in_memory_store::InMemoryVectorStore, VectorStoreIndex}, }; @@ -69,7 +69,10 @@ async fn main() -> Result<(), anyhow::Error> { .index(search_model); let results = index - .top_n::<FakeDefinition>("Which instrument is found in the Nebulon Mountain Ranges?", 1) + .top_n::<FakeDefinition>( + "Which instrument is found in the Nebulon Mountain Ranges?", + 1, + ) .await? .into_iter() .map(|(score, id, doc)| (score, id, doc.word)) diff --git a/rig-core/rig-core-derive/Cargo.toml b/rig-core/rig-core-derive/Cargo.toml index 008f492b..1ab5e5ac 100644 --- a/rig-core/rig-core-derive/Cargo.toml +++ b/rig-core/rig-core-derive/Cargo.toml @@ -5,6 +5,7 @@ edition = "2021" [dependencies] indoc = "2.0.5" +proc-macro2 = { version = "1.0.87", features = ["proc-macro"] } quote = "1.0.37" syn = { version = "2.0.79", features = ["full"]} diff --git a/rig-core/rig-core-derive/src/custom.rs b/rig-core/rig-core-derive/src/custom.rs new file mode 100644 index 00000000..eb6478bd --- /dev/null +++ b/rig-core/rig-core-derive/src/custom.rs @@ -0,0 +1,98 @@ +use quote::ToTokens; +use syn::{meta::ParseNestedMeta, spanned::Spanned, ExprPath}; + +use crate::EMBED; + +const EMBED_WITH: &str = "embed_with"; + +pub(crate) trait CustomAttributeParser { + // Determine if field is tagged with an #[embed(embed_with = "...")] attribute. + fn is_custom(&self) -> syn::Result<bool>; + + // Get the "..." part of the #[embed(embed_with = "...")] attribute. + // Ex: If attribute is tagged with #[embed(embed_with = "my_embed")], returns "my_embed". + fn expand_tag(&self) -> syn::Result<syn::ExprPath>; +} + +impl CustomAttributeParser for syn::Attribute { + fn is_custom(&self) -> syn::Result<bool> { + // Check that the attribute is a list. + match &self.meta { + syn::Meta::List(meta) => { + if meta.tokens.is_empty() { + return Ok(false); + } + } + _ => return Ok(false), + }; + + // Check the first attribute tag (the first "embed") + if !self.path().is_ident(EMBED) { + return Ok(false); + } + + self.parse_nested_meta(|meta| { + // Parse the meta attribute as an expression. Need this to compile. 
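+            // (The value must be consumed here: if the `= "..."` part is left unparsed,
+            // `parse_nested_meta` reports leftover tokens even for a well-formed
+            // `embed_with = "..."` attribute.)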
+ meta.value()?.parse::<syn::Expr>()?; + + if meta.path.is_ident(EMBED_WITH) { + Ok(()) + } else { + let path = meta.path.to_token_stream().to_string().replace(' ', ""); + Err(syn::Error::new_spanned( + meta.path, + format_args!("unknown embedding field attribute `{}`", path), + )) + } + })?; + + Ok(true) + } + + fn expand_tag(&self) -> syn::Result<syn::ExprPath> { + let mut custom_func_path = None; + + self.parse_nested_meta(|meta| match function_path(&meta) { + Ok(path) => { + custom_func_path = Some(path); + Ok(()) + } + Err(e) => Err(e), + })?; + + Ok(custom_func_path.unwrap()) + } +} + +fn function_path(meta: &ParseNestedMeta<'_>) -> syn::Result<ExprPath> { + // #[embed(embed_with = "...")] + let expr = meta.value()?.parse::<syn::Expr>().unwrap(); + let mut value = &expr; + while let syn::Expr::Group(e) = value { + value = &e.expr; + } + let string = if let syn::Expr::Lit(syn::ExprLit { + lit: syn::Lit::Str(lit_str), + .. + }) = value + { + let suffix = lit_str.suffix(); + if !suffix.is_empty() { + return Err(syn::Error::new( + lit_str.span(), + format!("unexpected suffix `{}` on string literal", suffix), + )); + } + lit_str.clone() + } else { + return Err(syn::Error::new( + value.span(), + format!( + "expected {} attribute to be a string: `{} = \"...\"`", + EMBED_WITH, EMBED_WITH + ), + )); + }; + + string.parse() +} diff --git a/rig-core/rig-core-derive/src/embeddable.rs b/rig-core/rig-core-derive/src/embeddable.rs index 8546f84a..cff035d1 100644 --- a/rig-core/rig-core-derive/src/embeddable.rs +++ b/rig-core/rig-core-derive/src/embeddable.rs @@ -1,55 +1,49 @@ -use proc_macro::TokenStream; +use proc_macro2::TokenStream; use quote::quote; -use syn::{ - meta::ParseNestedMeta, parse_quote, parse_str, punctuated::Punctuated, spanned::Spanned, - Attribute, DataStruct, ExprPath, Meta, Token, -}; +use syn::{parse_quote, parse_str, Attribute, DataStruct, Meta}; -const EMBED: &str = "embed"; -const EMBED_WITH: &str = "embed_with"; +use crate::{custom::CustomAttributeParser, EMBED}; const VEC_TYPE: &str = "Vec"; +const MANY_EMBEDDING: &str = "ManyEmbedding"; +const SINGLE_EMBEDDING: &str = "SingleEmbedding"; -pub fn expand_derive_embedding(input: &mut syn::DeriveInput) -> TokenStream { +pub fn expand_derive_embedding(input: &mut syn::DeriveInput) -> syn::Result<TokenStream> { let name = &input.ident; - let (func_calls, embed_kind) = - match &input.data { - syn::Data::Struct(data_struct) => { - // Handles fields tagged with #[embed] - let mut function_calls = data_struct - .basic_embed_fields() - .map(|field| { - add_struct_bounds(&mut input.generics, &field.ty); + let (embed_targets, embed_kind) = match &input.data { + syn::Data::Struct(data_struct) => { + // Handles fields tagged with #[embed] + let mut inner_embed_targets = data_struct + .basic_embed_fields() + .map(|field| { + add_struct_bounds(&mut input.generics, &field.ty); - let field_name = field.ident; + let field_name = field.ident; - quote! { - self.#field_name.embeddable() - } - }) - .collect::<Vec<_>>(); + quote! { + self.#field_name.embeddable() + } + }) + .collect::<Vec<_>>(); - // Handles fields tagged with #[embed(embed_with = "...")] - function_calls.extend(data_struct.custom_embed_fields().unwrap().map( - |(field, _)| { - let field_name = field.ident; + // Handles fields tagged with #[embed(embed_with = "...")] + inner_embed_targets.extend(data_struct.custom_embed_fields()?.map(|(field, _)| { + let field_name = field.ident; - quote! { - embeddable(&self.#field_name) - } - }, - )); + quote! 
{ + embeddable(&self.#field_name) + } + })); - (function_calls, data_struct.embed_kind().unwrap()) - } - _ => panic!("Embeddable can only be derived for structs"), - }; + (inner_embed_targets, data_struct.embed_kind()?) + } + _ => panic!("Embeddable trait can only be derived for structs"), + }; // Import the paths to the custom functions. let custom_func_paths = match &input.data { syn::Data::Struct(data_struct) => data_struct - .custom_embed_fields() - .unwrap() + .custom_embed_fields()? .map(|(_, custom_func_path)| { quote! { use #custom_func_path::embeddable; @@ -59,11 +53,17 @@ pub fn expand_derive_embedding(input: &mut syn::DeriveInput) -> TokenStream { _ => vec![], }; + // If there are no fields tagged with #[embed] or #[embed(embed_with = "...")], return an empty TokenStream. + // ie. do not implement Embeddable trait for the struct. + if embed_targets.is_empty() { + return Ok(TokenStream::new()); + } + let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl(); let gen = quote! { - use rig::embeddings::Embeddable; - use rig::embeddings::#embed_kind; + // Note: we do NOT import the Embeddable trait here because if there are multiple structs in the same file + // that derive Embed, there will be import conflicts. #(#custom_func_paths);* @@ -72,14 +72,14 @@ pub fn expand_derive_embedding(input: &mut syn::DeriveInput) -> TokenStream { fn embeddable(&self) -> Vec<String> { vec![ - #(#func_calls),* + #(#embed_targets),* ].into_iter().flatten().collect() } } }; eprintln!("Generated code:\n{}", gen); - gen.into() + Ok(gen) } // Adds bounds to where clause that force all fields tagged with #[embed] to implement the Embeddable trait. @@ -91,36 +91,35 @@ fn add_struct_bounds(generics: &mut syn::Generics, field_type: &syn::Type) { }); } -fn embed_kind(field: &syn::Field) -> Result<syn::Expr, syn::Error> { +fn embed_kind(field: &syn::Field) -> syn::Result<syn::Expr> { match &field.ty { syn::Type::Path(path) => { if path.path.segments.first().unwrap().ident == VEC_TYPE { - parse_str("ManyEmbedding") + parse_str(MANY_EMBEDDING) } else { - parse_str("SingleEmbedding") + parse_str(SINGLE_EMBEDDING) } } - _ => parse_str("SingleEmbedding"), + _ => parse_str(SINGLE_EMBEDDING), } } -trait AttributeParser { +trait StructParser { /// Finds and returns fields with simple #[embed] attribute tags only. fn basic_embed_fields(&self) -> impl Iterator<Item = syn::Field>; /// Finds and returns fields with #[embed(embed_with = "...")] attribute tags only. /// Also returns the attribute in question. - fn custom_embed_fields( - &self, - ) -> Result<impl Iterator<Item = (syn::Field, syn::ExprPath)>, syn::Error>; + fn custom_embed_fields(&self) + -> syn::Result<impl Iterator<Item = (syn::Field, syn::ExprPath)>>; /// If the total number of fields tagged with #[embed] or #[embed(embed_with = "...")] is 1, /// returns the kind of embedding that field should be. /// If the total number of fields tagged with #[embed] or #[embed(embed_with = "...")] is greater than 1, /// return ManyEmbedding. 
- fn embed_kind(&self) -> Result<syn::Expr, syn::Error> { + fn embed_kind(&self) -> syn::Result<syn::Expr> { let fields = self .basic_embed_fields() - .chain(self.custom_embed_fields().unwrap().map(|(f, _)| f)) + .chain(self.custom_embed_fields()?.map(|(f, _)| f)) .collect::<Vec<_>>(); if fields.len() == 1 { @@ -131,7 +130,7 @@ trait AttributeParser { } } -impl AttributeParser for DataStruct { +impl StructParser for DataStruct { fn basic_embed_fields(&self) -> impl Iterator<Item = syn::Field> { self.fields.clone().into_iter().filter(|field| { field @@ -150,42 +149,7 @@ impl AttributeParser for DataStruct { fn custom_embed_fields( &self, - ) -> Result<impl Iterator<Item = (syn::Field, syn::ExprPath)>, syn::Error> { - // Determine if field is tagged with #[embed(embed_with = "...")] attribute. - fn is_custom_embed(attribute: &syn::Attribute) -> Result<bool, syn::Error> { - let is_custom_embed = match attribute.meta { - Meta::List(_) => attribute - .parse_args_with(Punctuated::<Meta, Token![=]>::parse_terminated)? - .into_iter() - .any(|meta| meta.path().is_ident(EMBED_WITH)), - _ => false, - }; - - Ok(attribute.path().is_ident(EMBED) && is_custom_embed) - } - - // Get the "..." part of the #[embed(embed_with = "...")] attribute. - // Ex: If attribute is tagged with #[embed(embed_with = "my_embed")], returns "my_embed". - fn expand_tag(attribute: &syn::Attribute) -> Result<syn::ExprPath, syn::Error> { - let mut custom_func_path = None; - - attribute.parse_nested_meta(|meta| { - custom_func_path = Some(meta.function_path()?); - Ok(()) - })?; - - match custom_func_path { - Some(path) => Ok(path), - None => Err(syn::Error::new( - attribute.span(), - format!( - "expected {} attribute to have format: `#[embed(embed_with = \"...\")]`", - EMBED_WITH - ), - )), - } - } - + ) -> syn::Result<impl Iterator<Item = (syn::Field, syn::ExprPath)>> { Ok(self .fields .clone() @@ -196,8 +160,8 @@ impl AttributeParser for DataStruct { .clone() .into_iter() .map(|attribute| { - if is_custom_embed(&attribute)? { - Ok::<_, syn::Error>(Some((field.clone(), expand_tag(&attribute)?))) + if attribute.is_custom()? { + Ok::<_, syn::Error>(Some((field.clone(), attribute.expand_tag()?))) } else { Ok(None) } @@ -210,42 +174,3 @@ impl AttributeParser for DataStruct { .flatten()) } } - -trait CustomFunction { - fn function_path(&self) -> Result<ExprPath, syn::Error>; -} - -impl CustomFunction for ParseNestedMeta<'_> { - fn function_path(&self) -> Result<ExprPath, syn::Error> { - // #[embed(embed_with = "...")] - let expr = self.value().unwrap().parse::<syn::Expr>().unwrap(); - let mut value = &expr; - while let syn::Expr::Group(e) = value { - value = &e.expr; - } - let string = if let syn::Expr::Lit(syn::ExprLit { - lit: syn::Lit::Str(lit_str), - .. 
- }) = value - { - let suffix = lit_str.suffix(); - if !suffix.is_empty() { - return Err(syn::Error::new( - lit_str.span(), - format!("unexpected suffix `{}` on string literal", suffix), - )); - } - lit_str.clone() - } else { - return Err(syn::Error::new( - value.span(), - format!( - "expected {} attribute to be a string: `{} = \"...\"`", - EMBED_WITH, EMBED_WITH - ), - )); - }; - - string.parse() - } -} diff --git a/rig-core/rig-core-derive/src/lib.rs b/rig-core/rig-core-derive/src/lib.rs index 19f6a845..ad07592e 100644 --- a/rig-core/rig-core-derive/src/lib.rs +++ b/rig-core/rig-core-derive/src/lib.rs @@ -2,8 +2,11 @@ extern crate proc_macro; use proc_macro::TokenStream; use syn::{parse_macro_input, DeriveInput}; +mod custom; mod embeddable; +pub(crate) const EMBED: &str = "embed"; + // https://doc.rust-lang.org/book/ch19-06-macros.html#how-to-write-a-custom-derive-macro // https://doc.rust-lang.org/reference/procedural-macros.html @@ -12,4 +15,6 @@ pub fn derive_embedding_trait(item: TokenStream) -> TokenStream { let mut input = parse_macro_input!(item as DeriveInput); embeddable::expand_derive_embedding(&mut input) + .unwrap_or_else(syn::Error::into_compile_error) + .into() } diff --git a/rig-core/src/embeddings.rs b/rig-core/src/embeddings.rs index d68ae565..8e943982 100644 --- a/rig-core/src/embeddings.rs +++ b/rig-core/src/embeddings.rs @@ -240,7 +240,7 @@ impl<M: EmbeddingModel, D: Embeddable + Send + Sync + Clone> // Generate the embeddings for a chunk at a time. .map(|docs| async { let (document_indices, embed_targets): (Vec<_>, Vec<_>) = docs.into_iter().unzip(); - + Ok::<_, EmbeddingError>( document_indices .into_iter() @@ -403,3 +403,31 @@ impl<T: Embeddable> Embeddable for Vec<T> { self.iter().flat_map(|i| i.embeddable()).collect() } } + +#[cfg(test)] +mod tests { + use super::{Embeddable, SingleEmbedding}; + + use rig_derive::Embed; + use serde::Serialize; + + // #[derive(Serialize)] + // struct FakeDefinition2 { + // id: String, + // #[serde(test = "")] + // definition: String, + // } + + #[derive(Embed)] + struct FakeDefinition { + id: String, + #[embed(something = "a")] + definition: String, + } + + #[test] + fn test_missing_embed_fields() {} + + #[test] + fn test_empty_custom_function() {} +} diff --git a/rig-lancedb/examples/fixtures/lib.rs b/rig-lancedb/examples/fixtures/lib.rs index e8ebead9..747e23a8 100644 --- a/rig-lancedb/examples/fixtures/lib.rs +++ b/rig-lancedb/examples/fixtures/lib.rs @@ -2,7 +2,7 @@ use std::sync::Arc; use arrow_array::{types::Float64Type, ArrayRef, FixedSizeListArray, RecordBatch, StringArray}; use lancedb::arrow::arrow_schema::{DataType, Field, Fields, Schema}; -use rig::embeddings::Embedding; +use rig::embeddings::{Embedding, Embeddable, SingleEmbedding}; use rig_derive::Embed; use serde::Deserialize; diff --git a/rig-mongodb/examples/vector_search_mongodb.rs b/rig-mongodb/examples/vector_search_mongodb.rs index d4593b47..f740a29d 100644 --- a/rig-mongodb/examples/vector_search_mongodb.rs +++ b/rig-mongodb/examples/vector_search_mongodb.rs @@ -5,7 +5,7 @@ use serde::{Deserialize, Serialize}; use std::env; use rig::{ - embeddings::EmbeddingsBuilder, providers::openai::Client, vector_store::VectorStoreIndex, + embeddings::{EmbeddingsBuilder, Embeddable, SingleEmbedding}, providers::openai::Client, vector_store::VectorStoreIndex, }; use rig_mongodb::{MongoDbVectorStore, SearchParams}; From 5684c90e9718d7d9077eb5912e40b7fb89355b63 Mon Sep 17 00:00:00 2001 From: Garance <garance.mary@gmail.com> Date: Thu, 10 Oct 2024 15:25:52 -0400 Subject: 
[PATCH 18/91] tests: add unit tests on in memory store --- rig-core/src/vector_store/in_memory_store.rs | 148 ++++++++++++++++++- 1 file changed, 143 insertions(+), 5 deletions(-) diff --git a/rig-core/src/vector_store/in_memory_store.rs b/rig-core/src/vector_store/in_memory_store.rs index 1f71e9ed..40ad018e 100644 --- a/rig-core/src/vector_store/in_memory_store.rs +++ b/rig-core/src/vector_store/in_memory_store.rs @@ -84,11 +84,6 @@ impl<D: Serialize + Eq> InMemoryVectorStore<D> { .map(|(doc, _)| serde_json::from_str(&serde_json::to_string(doc)?)) .transpose()?) } - - /// Get the document embeddings by its id. - pub fn get_document_embeddings(&self, id: &str) -> Result<Option<&D>, VectorStoreError> { - Ok(self.embeddings.get(id).map(|(doc, _)| doc)) - } } /// RankingItem(distance, document_id, serializable document, embeddings document) @@ -189,3 +184,146 @@ impl<M: EmbeddingModel + std::marker::Sync, D: Serialize + Sync + Send + Eq> Vec .collect::<Result<Vec<_>, _>>() } } + +#[cfg(test)] +mod tests { + use std::cmp::Reverse; + + use crate::embeddings::Embedding; + + use super::{InMemoryVectorStore, RankingItem}; + + #[test] + fn test_single_embedding() { + let index = InMemoryVectorStore::default() + .add_documents(vec![ + ( + "doc1".to_string(), + "glarb-garb", + vec![Embedding { + document: "glarb-garb".to_string(), + vec: vec![0.1, 0.1, 0.5], + }], + ), + ( + "doc2".to_string(), + "marble-marble", + vec![Embedding { + document: "marble-marble".to_string(), + vec: vec![0.7, -0.3, 0.0], + }], + ), + ( + "doc3".to_string(), + "flumb-flumb", + vec![Embedding { + document: "flumb-flumb".to_string(), + vec: vec![0.3, 0.7, 0.1], + }], + ), + ]) + .unwrap(); + + let ranking = index.vector_search( + &Embedding { + document: "glarby-glarble".to_string(), + vec: vec![0.0, 0.1, 0.6], + }, + 1, + ); + + assert_eq!( + ranking + .into_iter() + .map(|Reverse(RankingItem(distance, id, doc, _))| { + ( + distance.0, + id.clone(), + serde_json::from_str(&serde_json::to_string(doc).unwrap()).unwrap(), + ) + }) + .collect::<Vec<(_, _, String)>>(), + vec![( + 0.034444444444444444, + "doc1".to_string(), + "glarb-garb".to_string() + )] + ) + } + + #[test] + fn test_multiple_embeddings() { + let index = InMemoryVectorStore::default() + .add_documents(vec![ + ( + "doc1".to_string(), + "glarb-garb", + vec![ + Embedding { + document: "glarb-garb".to_string(), + vec: vec![0.1, 0.1, 0.5], + }, + Embedding { + document: "don't-choose-me".to_string(), + vec: vec![-0.5, 0.9, 0.1], + }, + ], + ), + ( + "doc2".to_string(), + "marble-marble", + vec![ + Embedding { + document: "marble-marble".to_string(), + vec: vec![0.7, -0.3, 0.0], + }, + Embedding { + document: "sandwich".to_string(), + vec: vec![0.5, 0.5, -0.7], + }, + ], + ), + ( + "doc3".to_string(), + "flumb-flumb", + vec![ + Embedding { + document: "flumb-flumb".to_string(), + vec: vec![0.3, 0.7, 0.1], + }, + Embedding { + document: "banana".to_string(), + vec: vec![0.1, -0.5, -0.5], + }, + ], + ), + ]) + .unwrap(); + + let ranking = index.vector_search( + &Embedding { + document: "glarby-glarble".to_string(), + vec: vec![0.0, 0.1, 0.6], + }, + 1, + ); + + assert_eq!( + ranking + .into_iter() + .map(|Reverse(RankingItem(distance, id, doc, _))| { + ( + distance.0, + id.clone(), + serde_json::from_str(&serde_json::to_string(doc).unwrap()).unwrap(), + ) + }) + .collect::<Vec<(_, _, String)>>(), + vec![( + 0.034444444444444444, + "doc1".to_string(), + "glarb-garb".to_string() + )] + ) + } +} From 82d9f0c879d863f19e58fb2215a42e1cc613d066 Mon Sep 17 00:00:00 2001 From: 
Garance <garance.mary@gmail.com> Date: Thu, 10 Oct 2024 15:31:55 -0400 Subject: [PATCH 19/91] fic(ci): asterix on pull request sto accomodate for epic branches --- .github/workflows/ci.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index e00d68cd..ac7831c2 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -4,7 +4,7 @@ name: Lint & Test on: pull_request: branches: - - main + - "*" push: branches: - main From bf7316b219ea870986478fd5db9978c4e041afa1 Mon Sep 17 00:00:00 2001 From: Garance <garance.mary@gmail.com> Date: Thu, 10 Oct 2024 15:37:56 -0400 Subject: [PATCH 20/91] fix(ci): double asterix --- .github/workflows/ci.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index ac7831c2..cf3409ff 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -4,7 +4,7 @@ name: Lint & Test on: pull_request: branches: - - "*" + - "**" push: branches: - main From 83251640c1ef23ee97c5e7336fd690589075d284 Mon Sep 17 00:00:00 2001 From: Garance <garance.mary@gmail.com> Date: Thu, 10 Oct 2024 20:08:45 -0400 Subject: [PATCH 21/91] feat: add error type on embeddable trait --- rig-core/rig-core-derive/src/custom.rs | 68 ++++++------- rig-core/rig-core-derive/src/embeddable.rs | 54 ++++++---- rig-core/src/embeddings.rs | 99 +++++++++++-------- .../examples/vector_search_local_ann.rs | 2 +- 4 files changed, 127 insertions(+), 96 deletions(-) diff --git a/rig-core/rig-core-derive/src/custom.rs b/rig-core/rig-core-derive/src/custom.rs index eb6478bd..77f321f6 100644 --- a/rig-core/rig-core-derive/src/custom.rs +++ b/rig-core/rig-core-derive/src/custom.rs @@ -1,5 +1,5 @@ use quote::ToTokens; -use syn::{meta::ParseNestedMeta, spanned::Spanned, ExprPath}; +use syn::{meta::ParseNestedMeta, ExprPath}; use crate::EMBED; @@ -50,6 +50,39 @@ impl CustomAttributeParser for syn::Attribute { } fn expand_tag(&self) -> syn::Result<syn::ExprPath> { + fn function_path(meta: &ParseNestedMeta<'_>) -> syn::Result<ExprPath> { + // #[embed(embed_with = "...")] + let expr = meta.value()?.parse::<syn::Expr>().unwrap(); + let mut value = &expr; + while let syn::Expr::Group(e) = value { + value = &e.expr; + } + let string = if let syn::Expr::Lit(syn::ExprLit { + lit: syn::Lit::Str(lit_str), + .. + }) = value + { + let suffix = lit_str.suffix(); + if !suffix.is_empty() { + return Err(syn::Error::new_spanned( + lit_str, + format!("unexpected suffix `{}` on string literal", suffix), + )); + } + lit_str.clone() + } else { + return Err(syn::Error::new_spanned( + value, + format!( + "expected {} attribute to be a string: `{} = \"...\"`", + EMBED_WITH, EMBED_WITH + ), + )); + }; + + string.parse() + } + let mut custom_func_path = None; self.parse_nested_meta(|meta| match function_path(&meta) { @@ -63,36 +96,3 @@ impl CustomAttributeParser for syn::Attribute { Ok(custom_func_path.unwrap()) } } - -fn function_path(meta: &ParseNestedMeta<'_>) -> syn::Result<ExprPath> { - // #[embed(embed_with = "...")] - let expr = meta.value()?.parse::<syn::Expr>().unwrap(); - let mut value = &expr; - while let syn::Expr::Group(e) = value { - value = &e.expr; - } - let string = if let syn::Expr::Lit(syn::ExprLit { - lit: syn::Lit::Str(lit_str), - .. 
- }) = value - { - let suffix = lit_str.suffix(); - if !suffix.is_empty() { - return Err(syn::Error::new( - lit_str.span(), - format!("unexpected suffix `{}` on string literal", suffix), - )); - } - lit_str.clone() - } else { - return Err(syn::Error::new( - value.span(), - format!( - "expected {} attribute to be a string: `{} = \"...\"`", - EMBED_WITH, EMBED_WITH - ), - )); - }; - - string.parse() -} diff --git a/rig-core/rig-core-derive/src/embeddable.rs b/rig-core/rig-core-derive/src/embeddable.rs index cff035d1..f5c6e781 100644 --- a/rig-core/rig-core-derive/src/embeddable.rs +++ b/rig-core/rig-core-derive/src/embeddable.rs @@ -10,10 +10,10 @@ const SINGLE_EMBEDDING: &str = "SingleEmbedding"; pub fn expand_derive_embedding(input: &mut syn::DeriveInput) -> syn::Result<TokenStream> { let name = &input.ident; - let (embed_targets, embed_kind) = match &input.data { + let (embed_targets, custom_embed_targets, embed_kind) = match &input.data { syn::Data::Struct(data_struct) => { // Handles fields tagged with #[embed] - let mut inner_embed_targets = data_struct + let embed_targets = data_struct .basic_embed_fields() .map(|field| { add_struct_bounds(&mut input.generics, &field.ty); @@ -21,25 +21,38 @@ pub fn expand_derive_embedding(input: &mut syn::DeriveInput) -> syn::Result<Toke let field_name = field.ident; quote! { - self.#field_name.embeddable() + self.#field_name } }) .collect::<Vec<_>>(); // Handles fields tagged with #[embed(embed_with = "...")] - inner_embed_targets.extend(data_struct.custom_embed_fields()?.map(|(field, _)| { - let field_name = field.ident; + let custom_embed_targets = data_struct + .custom_embed_fields()? + .map(|(field, _)| { + let field_name = field.ident; - quote! { - embeddable(&self.#field_name) - } - })); + quote! { + self.#field_name + } + }) + .collect::<Vec<_>>(); - (inner_embed_targets, data_struct.embed_kind()?) + ( + embed_targets, + custom_embed_targets, + data_struct.embed_kind()?, + ) } _ => panic!("Embeddable trait can only be derived for structs"), }; + // If there are no fields tagged with #[embed] or #[embed(embed_with = "...")], return an empty TokenStream. + // ie. do not implement Embeddable trait for the struct. + if embed_targets.is_empty() && custom_embed_targets.is_empty() { + return Ok(TokenStream::new()); + } + // Import the paths to the custom functions. let custom_func_paths = match &input.data { syn::Data::Struct(data_struct) => data_struct @@ -53,12 +66,6 @@ pub fn expand_derive_embedding(input: &mut syn::DeriveInput) -> syn::Result<Toke _ => vec![], }; - // If there are no fields tagged with #[embed] or #[embed(embed_with = "...")], return an empty TokenStream. - // ie. do not implement Embeddable trait for the struct. - if embed_targets.is_empty() { - return Ok(TokenStream::new()); - } - let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl(); let gen = quote! { @@ -69,11 +76,18 @@ pub fn expand_derive_embedding(input: &mut syn::DeriveInput) -> syn::Result<Toke impl #impl_generics Embeddable for #name #ty_generics #where_clause { type Kind = #embed_kind; + type Error = EmbeddingGenerationError; + + fn embeddable(&self) -> Result<Vec<String>, Self::Error> { + vec![#(#embed_targets.clone()),*].embeddable() + + // let custom_embed_targets = vec![#( embeddable( #embed_targets ); ),*] + // .iter() + // .collect::<Result<Vec<_>, _>>()? 
+ // .into_iter() + // .flatten(); - fn embeddable(&self) -> Vec<String> { - vec![ - #(#embed_targets),* - ].into_iter().flatten().collect() + // Ok(embed_targets.chain(custom_embed_targets).collect()) } } }; diff --git a/rig-core/src/embeddings.rs b/rig-core/src/embeddings.rs index 8e943982..7e25a69a 100644 --- a/rig-core/src/embeddings.rs +++ b/rig-core/src/embeddings.rs @@ -172,9 +172,17 @@ impl EmbeddingKind for SingleEmbedding {} pub struct ManyEmbedding; impl EmbeddingKind for ManyEmbedding {} +#[derive(Debug, thiserror::Error)] +pub enum EmbeddingGenerationError { + #[error("SerdeError: {0}")] + SerdeError(#[from] serde_json::Error), +} + pub trait Embeddable { type Kind: EmbeddingKind; - fn embeddable(&self) -> Vec<String>; + type Error: std::error::Error; + + fn embeddable(&self) -> Result<Vec<String>, Self::Error>; } /// Builder for creating a collection of embeddings @@ -195,22 +203,22 @@ impl<M: EmbeddingModel, D: Embeddable<Kind = K>, K: EmbeddingKind> EmbeddingsBui } /// Add a document that implements `Embeddable` to the builder. - pub fn document(mut self, document: D) -> Self { - let embed_targets = document.embeddable(); + pub fn document(mut self, document: D) -> Result<Self, D::Error> { + let embed_targets = document.embeddable()?; self.documents.push((document, embed_targets)); - self + Ok(self) } /// Add many documents that implement `Embeddable` to the builder. - pub fn documents(mut self, documents: Vec<D>) -> EmbeddingsBuilder<M, D, D::Kind> { - documents.into_iter().for_each(|doc| { - let embed_targets = doc.embeddable(); + pub fn documents(mut self, documents: Vec<D>) -> Result<Self, D::Error> { + for doc in documents.into_iter() { + let embed_targets = doc.embeddable()?; self.documents.push((doc, embed_targets)); - }); + } - self + Ok(self) } } @@ -318,110 +326,119 @@ impl<M: EmbeddingModel, D: Embeddable + Send + Sync + Clone> ////////////////////////////////////////////////////// impl Embeddable for String { type Kind = SingleEmbedding; + type Error = EmbeddingGenerationError; - fn embeddable(&self) -> Vec<String> { - vec![self.clone()] + fn embeddable(&self) -> Result<Vec<String>, Self::Error> { + Ok(vec![self.clone()]) } } impl Embeddable for i8 { type Kind = SingleEmbedding; + type Error = EmbeddingGenerationError; - fn embeddable(&self) -> Vec<String> { - vec![self.to_string()] + fn embeddable(&self) -> Result<Vec<String>, Self::Error> { + Ok(vec![self.to_string()]) } } impl Embeddable for i16 { type Kind = SingleEmbedding; + type Error = EmbeddingGenerationError; - fn embeddable(&self) -> Vec<String> { - vec![self.to_string()] + fn embeddable(&self) -> Result<Vec<String>, Self::Error> { + Ok(vec![self.to_string()]) } } impl Embeddable for i32 { type Kind = SingleEmbedding; + type Error = EmbeddingGenerationError; - fn embeddable(&self) -> Vec<String> { - vec![self.to_string()] + fn embeddable(&self) -> Result<Vec<String>, Self::Error> { + Ok(vec![self.to_string()]) } } impl Embeddable for i64 { type Kind = SingleEmbedding; + type Error = EmbeddingGenerationError; - fn embeddable(&self) -> Vec<String> { - vec![self.to_string()] + fn embeddable(&self) -> Result<Vec<String>, Self::Error> { + Ok(vec![self.to_string()]) } } impl Embeddable for i128 { type Kind = SingleEmbedding; + type Error = EmbeddingGenerationError; - fn embeddable(&self) -> Vec<String> { - vec![self.to_string()] + fn embeddable(&self) -> Result<Vec<String>, Self::Error> { + Ok(vec![self.to_string()]) } } impl Embeddable for f32 { type Kind = SingleEmbedding; + type Error = 
EmbeddingGenerationError; - fn embeddable(&self) -> Vec<String> { - vec![self.to_string()] + fn embeddable(&self) -> Result<Vec<String>, Self::Error> { + Ok(vec![self.to_string()]) } } impl Embeddable for f64 { type Kind = SingleEmbedding; + type Error = EmbeddingGenerationError; - fn embeddable(&self) -> Vec<String> { - vec![self.to_string()] + fn embeddable(&self) -> Result<Vec<String>, Self::Error> { + Ok(vec![self.to_string()]) } } impl Embeddable for bool { type Kind = SingleEmbedding; + type Error = EmbeddingGenerationError; - fn embeddable(&self) -> Vec<String> { - vec![self.to_string()] + fn embeddable(&self) -> Result<Vec<String>, Self::Error> { + Ok(vec![self.to_string()]) } } impl Embeddable for char { type Kind = SingleEmbedding; + type Error = EmbeddingGenerationError; - fn embeddable(&self) -> Vec<String> { - vec![self.to_string()] + fn embeddable(&self) -> Result<Vec<String>, Self::Error> { + Ok(vec![self.to_string()]) } } impl<T: Embeddable> Embeddable for Vec<T> { type Kind = ManyEmbedding; + type Error = T::Error; - fn embeddable(&self) -> Vec<String> { - self.iter().flat_map(|i| i.embeddable()).collect() + fn embeddable(&self) -> Result<Vec<String>, Self::Error> { + Ok(self + .iter() + .map(|i| i.embeddable()) + .collect::<Result<Vec<_>, _>>()? + .into_iter() + .flatten() + .collect()) } } #[cfg(test)] mod tests { - use super::{Embeddable, SingleEmbedding}; + use super::{Embeddable, SingleEmbedding, EmbeddingGenerationError}; use rig_derive::Embed; - use serde::Serialize; - - // #[derive(Serialize)] - // struct FakeDefinition2 { - // id: String, - // #[serde(test = "")] - // definition: String, - // } #[derive(Embed)] struct FakeDefinition { id: String, - #[embed(something = "a")] + #[embed] definition: String, } diff --git a/rig-lancedb/examples/vector_search_local_ann.rs b/rig-lancedb/examples/vector_search_local_ann.rs index af82ad49..7a9c8cc1 100644 --- a/rig-lancedb/examples/vector_search_local_ann.rs +++ b/rig-lancedb/examples/vector_search_local_ann.rs @@ -27,7 +27,7 @@ async fn main() -> Result<(), anyhow::Error> { // Generate embeddings for the test data. let embeddings = EmbeddingsBuilder::new(model.clone()) - .documents(fake_definitions()) + .documents(fake_definitions())? // Note: need at least 256 rows in order to create an index so copy the definition 256 times for testing purposes. 
.documents( (0..256) From de022c49c5030f5dc9f0ef8cb861f8ee7c9a577c Mon Sep 17 00:00:00 2001 From: Garance <garance.mary@gmail.com> Date: Fri, 11 Oct 2024 11:09:41 -0400 Subject: [PATCH 22/91] refactor: move embeddings to its own module and seperate embeddable --- rig-core/examples/rag.rs | 8 +- rig-core/examples/vector_search.rs | 8 +- rig-core/examples/vector_search_cohere.rs | 8 +- rig-core/rig-core-derive/src/lib.rs | 2 +- .../embeddable.rs} | 130 +++--------------- rig-core/src/embeddings/embedding.rs | 99 +++++++++++++ rig-core/src/embeddings/mod.rs | 7 + rig-core/src/lib.rs | 4 + rig-core/src/providers/cohere.rs | 12 +- rig-core/src/providers/openai.rs | 18 ++- rig-core/src/vector_store/in_memory_store.rs | 2 +- rig-core/src/vector_store/mod.rs | 2 +- rig-lancedb/examples/fixtures/lib.rs | 9 +- .../examples/vector_search_local_ann.rs | 4 +- .../examples/vector_search_local_enn.rs | 4 +- rig-lancedb/examples/vector_search_s3_ann.rs | 6 +- rig-lancedb/src/lib.rs | 2 +- rig-mongodb/examples/vector_search_mongodb.rs | 10 +- rig-mongodb/src/lib.rs | 2 +- 19 files changed, 182 insertions(+), 155 deletions(-) rename rig-core/src/{embeddings.rs => embeddings/embeddable.rs} (73%) create mode 100644 rig-core/src/embeddings/embedding.rs create mode 100644 rig-core/src/embeddings/mod.rs diff --git a/rig-core/examples/rag.rs b/rig-core/examples/rag.rs index 82e9d4fd..9ff579b0 100644 --- a/rig-core/examples/rag.rs +++ b/rig-core/examples/rag.rs @@ -2,16 +2,16 @@ use std::{env, vec}; use rig::{ completion::Prompt, - embeddings::{Embeddable, EmbeddingsBuilder, ManyEmbedding}, + embeddings::embeddable::{EmbeddingGenerationError, EmbeddingsBuilder, ManyEmbedding}, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, vector_store::in_memory_store::InMemoryVectorStore, + Embeddable, }; -use rig_derive::Embed; use serde::Serialize; // Shape of data that needs to be RAG'ed. // The definition field will be used to generate embeddings. -#[derive(Embed, Clone, Debug, Serialize, Eq, PartialEq, Default)] +#[derive(Embeddable, Clone, Debug, Serialize, Eq, PartialEq, Default)] struct FakeDefinition { id: String, #[embed] @@ -49,7 +49,7 @@ async fn main() -> Result<(), anyhow::Error> { "Definition of a *linglingdong*: A rare, mystical instrument crafted by the ancient monks of the Nebulon Mountain Ranges on the planet Quarm.".to_string() ] }, - ]) + ])? .build() .await?; diff --git a/rig-core/examples/vector_search.rs b/rig-core/examples/vector_search.rs index 523719ac..fe79a3f1 100644 --- a/rig-core/examples/vector_search.rs +++ b/rig-core/examples/vector_search.rs @@ -1,16 +1,16 @@ use std::env; use rig::{ - embeddings::{Embeddable, EmbeddingsBuilder, ManyEmbedding}, + embeddings::embeddable::{EmbeddingGenerationError, EmbeddingsBuilder, ManyEmbedding}, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, vector_store::{in_memory_store::InMemoryVectorStore, VectorStoreIndex}, + Embeddable, }; -use rig_derive::Embed; use serde::{Deserialize, Serialize}; // Shape of data that needs to be RAG'ed. // The definition field will be used to generate embeddings. -#[derive(Embed, Clone, Deserialize, Debug, Serialize, Eq, PartialEq, Default)] +#[derive(Embeddable, Clone, Deserialize, Debug, Serialize, Eq, PartialEq, Default)] struct FakeDefinition { id: String, word: String, @@ -52,7 +52,7 @@ async fn main() -> Result<(), anyhow::Error> { "A rare, mystical instrument crafted by the ancient monks of the Nebulon Mountain Ranges on the planet Quarm.".to_string() ] }, - ]) + ])? 
.build() .await?; diff --git a/rig-core/examples/vector_search_cohere.rs b/rig-core/examples/vector_search_cohere.rs index f9f84175..9432df39 100644 --- a/rig-core/examples/vector_search_cohere.rs +++ b/rig-core/examples/vector_search_cohere.rs @@ -1,16 +1,16 @@ use std::env; use rig::{ - embeddings::{Embeddable, EmbeddingsBuilder, ManyEmbedding}, + embeddings::embeddable::{EmbeddingGenerationError, EmbeddingsBuilder, ManyEmbedding}, providers::cohere::{Client, EMBED_ENGLISH_V3}, vector_store::{in_memory_store::InMemoryVectorStore, VectorStoreIndex}, + Embeddable, }; -use rig_derive::Embed; use serde::{Deserialize, Serialize}; // Shape of data that needs to be RAG'ed. // The definition field will be used to generate embeddings. -#[derive(Embed, Clone, Deserialize, Debug, Serialize, Eq, PartialEq, Default)] +#[derive(Embeddable, Clone, Deserialize, Debug, Serialize, Eq, PartialEq, Default)] struct FakeDefinition { id: String, word: String, @@ -53,7 +53,7 @@ async fn main() -> Result<(), anyhow::Error> { "A rare, mystical instrument crafted by the ancient monks of the Nebulon Mountain Ranges on the planet Quarm.".to_string() ] }, - ]) + ])? .build() .await?; diff --git a/rig-core/rig-core-derive/src/lib.rs b/rig-core/rig-core-derive/src/lib.rs index ad07592e..fba92f1c 100644 --- a/rig-core/rig-core-derive/src/lib.rs +++ b/rig-core/rig-core-derive/src/lib.rs @@ -10,7 +10,7 @@ pub(crate) const EMBED: &str = "embed"; // https://doc.rust-lang.org/book/ch19-06-macros.html#how-to-write-a-custom-derive-macro // https://doc.rust-lang.org/reference/procedural-macros.html -#[proc_macro_derive(Embed, attributes(embed))] +#[proc_macro_derive(Embeddable, attributes(embed))] pub fn derive_embedding_trait(item: TokenStream) -> TokenStream { let mut input = parse_macro_input!(item as DeriveInput); diff --git a/rig-core/src/embeddings.rs b/rig-core/src/embeddings/embeddable.rs similarity index 73% rename from rig-core/src/embeddings.rs rename to rig-core/src/embeddings/embeddable.rs index 7e25a69a..08e8b316 100644 --- a/rig-core/src/embeddings.rs +++ b/rig-core/src/embeddings/embeddable.rs @@ -1,20 +1,6 @@ -//! This module provides functionality for working with embeddings and embedding models. -//! Embeddings are numerical representations of documents or other objects, typically used in -//! natural language processing (NLP) tasks such as text classification, information retrieval, -//! and document similarity. -//! -//! The module defines the [EmbeddingModel] trait, which represents an embedding model that can -//! generate embeddings for documents. It also provides an implementation of the [EmbeddingsBuilder] -//! struct, which allows users to build collections of document embeddings using different embedding -//! models and document sources. -//! -//! The module also defines the [Embedding] struct, which represents a single document embedding. -//! -//! The module also defines the [Embeddable] trait, which represents types that can be embedded. -//! Only types that implement the Embeddable trait can be used with the EmbeddingsBuilder. -//! -//! Finally, the module defines the [EmbeddingError] enum, which represents various errors that -//! can occur during embedding generation or processing. +//! The module defines the [Embeddable] trait, which must be implemented for types that can be embedded. +//! The module defines the [EmbeddingsBuilder] struct which accumulates objects to be embedded and generates the embeddings for each object when built. +//! 
Only types that implement the [Embeddable] trait can be added to the [EmbeddingsBuilder]. //! //! # Example //! ```rust @@ -73,102 +59,18 @@ //! // ... //! ``` -use std::{cmp::max, collections::HashMap, marker::PhantomData}; - +use super::embedding::{Embedding, EmbeddingError, EmbeddingModel}; use futures::{stream, StreamExt, TryStreamExt}; -use serde::{Deserialize, Serialize}; - -#[derive(Debug, thiserror::Error)] -pub enum EmbeddingError { - /// Http error (e.g.: connection error, timeout, etc.) - #[error("HttpError: {0}")] - HttpError(#[from] reqwest::Error), - - /// Json error (e.g.: serialization, deserialization) - #[error("JsonError: {0}")] - JsonError(#[from] serde_json::Error), - - /// Error processing the document for embedding - #[error("DocumentError: {0}")] - DocumentError(String), - - /// Error parsing the completion response - #[error("ResponseError: {0}")] - ResponseError(String), - - /// Error returned by the embedding model provider - #[error("ProviderError: {0}")] - ProviderError(String), -} - -/// Trait for embedding models that can generate embeddings for documents. -pub trait EmbeddingModel: Clone + Sync + Send { - /// The maximum number of documents that can be embedded in a single request. - const MAX_DOCUMENTS: usize; - - /// The number of dimensions in the embedding vector. - fn ndims(&self) -> usize; - - /// Embed a single document - fn embed_document( - &self, - document: &str, - ) -> impl std::future::Future<Output = Result<Embedding, EmbeddingError>> + Send - where - Self: Sync, - { - async { - Ok(self - .embed_documents(vec![document.to_string()]) - .await? - .first() - .cloned() - .expect("One embedding should be present")) - } - } - - /// Embed multiple documents in a single request - fn embed_documents( - &self, - documents: Vec<String>, - ) -> impl std::future::Future<Output = Result<Vec<Embedding>, EmbeddingError>> + Send; -} - -/// Struct that holds a single document and its embedding. -#[derive(Clone, Default, Deserialize, Serialize, Debug)] -pub struct Embedding { - /// The document that was embedded. Used for debugging. - pub document: String, - /// The embedding vector - pub vec: Vec<f64>, -} - -impl PartialEq for Embedding { - fn eq(&self, other: &Self) -> bool { - self.document == other.document - } -} - -impl Eq for Embedding {} - -impl Embedding { - pub fn distance(&self, other: &Self) -> f64 { - let dot_product: f64 = self - .vec - .iter() - .zip(other.vec.iter()) - .map(|(x, y)| x * y) - .sum(); - - let product_of_lengths = (self.vec.len() * other.vec.len()) as f64; - - dot_product / product_of_lengths - } -} +use std::{cmp::max, collections::HashMap, marker::PhantomData}; +/// The associated type `Kind` on the trait `Embeddable` must implement this trait. pub trait EmbeddingKind {} + +/// Used for structs that contain a single embedding target. pub struct SingleEmbedding; impl EmbeddingKind for SingleEmbedding {} + +/// Used for structs that contain many embedding targets. pub struct ManyEmbedding; impl EmbeddingKind for ManyEmbedding {} @@ -178,6 +80,9 @@ pub enum EmbeddingGenerationError { SerdeError(#[from] serde_json::Error), } +/// Trait for types that can be embedded. +/// The `embeddable` method returns a list of strings for which embeddings will be generated by the embeddings builder. +/// If the type `Kind` is `SingleEmbedding`, the list of strings contains a single item, otherwise, the list can contain many items. 
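+/// For example, the `String` implementation further down in this module uses `Kind = SingleEmbedding`,
+/// while the `Vec<T: Embeddable>` implementation uses `Kind = ManyEmbedding`.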
pub trait Embeddable { type Kind: EmbeddingKind; type Error: std::error::Error; @@ -185,7 +90,7 @@ pub trait Embeddable { fn embeddable(&self) -> Result<Vec<String>, Self::Error>; } -/// Builder for creating a collection of embeddings +/// Builder for creating a collection of embeddings. pub struct EmbeddingsBuilder<M: EmbeddingModel, D: Embeddable, K: EmbeddingKind> { kind: PhantomData<K>, model: M, @@ -431,11 +336,10 @@ impl<T: Embeddable> Embeddable for Vec<T> { #[cfg(test)] mod tests { - use super::{Embeddable, SingleEmbedding, EmbeddingGenerationError}; - - use rig_derive::Embed; + use super::{Embeddable, EmbeddingGenerationError, SingleEmbedding}; + use rig_derive::Embeddable; - #[derive(Embed)] + #[derive(Embeddable)] struct FakeDefinition { id: String, #[embed] diff --git a/rig-core/src/embeddings/embedding.rs b/rig-core/src/embeddings/embedding.rs new file mode 100644 index 00000000..ff284a05 --- /dev/null +++ b/rig-core/src/embeddings/embedding.rs @@ -0,0 +1,99 @@ +//! The module defines the [EmbeddingModel] trait, which represents an embedding model that can +//! generate embeddings for documents. It also provides an implementation of the [EmbeddingsBuilder] +//! struct, which allows users to build collections of document embeddings using different embedding +//! models and document sources. +//! +//! The module also defines the [Embedding] struct, which represents a single document embedding. +//! +//! Finally, the module defines the [EmbeddingError] enum, which represents various errors that +//! can occur during embedding generation or processing. + +use serde::{Deserialize, Serialize}; + +#[derive(Debug, thiserror::Error)] +pub enum EmbeddingError { + /// Http error (e.g.: connection error, timeout, etc.) + #[error("HttpError: {0}")] + HttpError(#[from] reqwest::Error), + + /// Json error (e.g.: serialization, deserialization) + #[error("JsonError: {0}")] + JsonError(#[from] serde_json::Error), + + /// Error processing the document for embedding + #[error("DocumentError: {0}")] + DocumentError(String), + + /// Error parsing the completion response + #[error("ResponseError: {0}")] + ResponseError(String), + + /// Error returned by the embedding model provider + #[error("ProviderError: {0}")] + ProviderError(String), +} + +/// Trait for embedding models that can generate embeddings for documents. +pub trait EmbeddingModel: Clone + Sync + Send { + /// The maximum number of documents that can be embedded in a single request. + const MAX_DOCUMENTS: usize; + + /// The number of dimensions in the embedding vector. + fn ndims(&self) -> usize; + + /// Embed a single document + fn embed_document( + &self, + document: &str, + ) -> impl std::future::Future<Output = Result<Embedding, EmbeddingError>> + Send + where + Self: Sync, + { + async { + Ok(self + .embed_documents(vec![document.to_string()]) + .await? + .first() + .cloned() + .expect("One embedding should be present")) + } + } + + /// Embed multiple documents in a single request + fn embed_documents( + &self, + documents: Vec<String>, + ) -> impl std::future::Future<Output = Result<Vec<Embedding>, EmbeddingError>> + Send; +} + +/// Struct that holds a single document and its embedding. +#[derive(Clone, Default, Deserialize, Serialize, Debug)] +pub struct Embedding { + /// The document that was embedded. Used for debugging. 
+ pub document: String, + /// The embedding vector + pub vec: Vec<f64>, +} + +impl PartialEq for Embedding { + fn eq(&self, other: &Self) -> bool { + self.document == other.document + } +} + +impl Eq for Embedding {} + +impl Embedding { + pub fn distance(&self, other: &Self) -> f64 { + let dot_product: f64 = self + .vec + .iter() + .zip(other.vec.iter()) + .map(|(x, y)| x * y) + .sum(); + + let product_of_lengths = (self.vec.len() * other.vec.len()) as f64; + + dot_product / product_of_lengths + } +} diff --git a/rig-core/src/embeddings/mod.rs b/rig-core/src/embeddings/mod.rs new file mode 100644 index 00000000..33526769 --- /dev/null +++ b/rig-core/src/embeddings/mod.rs @@ -0,0 +1,7 @@ +//! This module provides functionality for working with embeddings. +//! Embeddings are numerical representations of documents or other objects, typically used in +//! natural language processing (NLP) tasks such as text classification, information retrieval, +//! and document similarity. + +pub mod embeddable; +pub mod embedding; diff --git a/rig-core/src/lib.rs b/rig-core/src/lib.rs index 86c25209..b7f0615e 100644 --- a/rig-core/src/lib.rs +++ b/rig-core/src/lib.rs @@ -75,3 +75,7 @@ pub mod json_utils; pub mod providers; pub mod tool; pub mod vector_store; + +// Export Embeddable trait and Embeddable together. +pub use embeddings::embeddable::Embeddable; +pub use rig_derive::Embeddable; diff --git a/rig-core/src/providers/cohere.rs b/rig-core/src/providers/cohere.rs index c93709c7..33464939 100644 --- a/rig-core/src/providers/cohere.rs +++ b/rig-core/src/providers/cohere.rs @@ -13,7 +13,11 @@ use std::collections::HashMap; use crate::{ agent::AgentBuilder, completion::{self, CompletionError}, - embeddings::{self, Embeddable, EmbeddingError, EmbeddingsBuilder}, + embeddings::{ + self, + embeddable::{Embeddable, EmbeddingsBuilder}, + embedding::EmbeddingError, + }, extractor::ExtractorBuilder, json_utils, }; @@ -187,7 +191,7 @@ pub struct EmbeddingModel { ndims: usize, } -impl embeddings::EmbeddingModel for EmbeddingModel { +impl embeddings::embedding::EmbeddingModel for EmbeddingModel { const MAX_DOCUMENTS: usize = 96; fn ndims(&self) -> usize { @@ -197,7 +201,7 @@ impl embeddings::EmbeddingModel for EmbeddingModel { async fn embed_documents( &self, documents: Vec<String>, - ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> { + ) -> Result<Vec<embeddings::embedding::Embedding>, EmbeddingError> { let response = self .client .post("/v1/embed") @@ -226,7 +230,7 @@ impl embeddings::EmbeddingModel for EmbeddingModel { .embeddings .into_iter() .zip(documents.into_iter()) - .map(|(embedding, document)| embeddings::Embedding { + .map(|(embedding, document)| embeddings::embedding::Embedding { document, vec: embedding, }) diff --git a/rig-core/src/providers/openai.rs b/rig-core/src/providers/openai.rs index 6bd1711f..d23a46bf 100644 --- a/rig-core/src/providers/openai.rs +++ b/rig-core/src/providers/openai.rs @@ -11,7 +11,11 @@ use crate::{ agent::AgentBuilder, completion::{self, CompletionError, CompletionRequest}, - embeddings::{self, Embeddable, EmbeddingError}, + embeddings::{ + self, + embeddable::{Embeddable, EmbeddingsBuilder}, + embedding::{Embedding, EmbeddingError}, + }, extractor::ExtractorBuilder, json_utils, }; @@ -121,11 +125,11 @@ impl Client { /// .await /// .expect("Failed to embed documents"); /// ``` - pub fn embeddings<D: Embeddable>( + pub fn embeddings<T: Embeddable>( &self, model: &str, - ) -> embeddings::EmbeddingsBuilder<EmbeddingModel, D, D::Kind> { - 
embeddings::EmbeddingsBuilder::new(self.embedding_model(model)) + ) -> EmbeddingsBuilder<EmbeddingModel, T, T::Kind> { + EmbeddingsBuilder::new(self.embedding_model(model)) } /// Create a completion model with the given name. @@ -235,7 +239,7 @@ pub struct EmbeddingModel { ndims: usize, } -impl embeddings::EmbeddingModel for EmbeddingModel { +impl embeddings::embedding::EmbeddingModel for EmbeddingModel { const MAX_DOCUMENTS: usize = 1024; fn ndims(&self) -> usize { @@ -245,7 +249,7 @@ impl embeddings::EmbeddingModel for EmbeddingModel { async fn embed_documents( &self, documents: Vec<String>, - ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> { + ) -> Result<Vec<Embedding>, EmbeddingError> { let response = self .client .post("/v1/embeddings") @@ -271,7 +275,7 @@ impl embeddings::EmbeddingModel for EmbeddingModel { .data .into_iter() .zip(documents.into_iter()) - .map(|(embedding, document)| embeddings::Embedding { + .map(|(embedding, document)| Embedding { document, vec: embedding.embedding, }) diff --git a/rig-core/src/vector_store/in_memory_store.rs b/rig-core/src/vector_store/in_memory_store.rs index ba2e754b..e8d8bb70 100644 --- a/rig-core/src/vector_store/in_memory_store.rs +++ b/rig-core/src/vector_store/in_memory_store.rs @@ -8,7 +8,7 @@ use ordered_float::OrderedFloat; use serde::{Deserialize, Serialize}; use super::{VectorStoreError, VectorStoreIndex}; -use crate::embeddings::{Embedding, EmbeddingModel}; +use crate::embeddings::embedding::{Embedding, EmbeddingModel}; /// InMemoryVectorStore is a simple in-memory vector store that stores embeddings /// in-memory using a HashMap. diff --git a/rig-core/src/vector_store/mod.rs b/rig-core/src/vector_store/mod.rs index 396b5514..6f112b81 100644 --- a/rig-core/src/vector_store/mod.rs +++ b/rig-core/src/vector_store/mod.rs @@ -2,7 +2,7 @@ use futures::future::BoxFuture; use serde::Deserialize; use serde_json::Value; -use crate::embeddings::EmbeddingError; +use crate::embeddings::embedding::EmbeddingError; pub mod in_memory_store; diff --git a/rig-lancedb/examples/fixtures/lib.rs b/rig-lancedb/examples/fixtures/lib.rs index 747e23a8..bf7f2a33 100644 --- a/rig-lancedb/examples/fixtures/lib.rs +++ b/rig-lancedb/examples/fixtures/lib.rs @@ -2,11 +2,14 @@ use std::sync::Arc; use arrow_array::{types::Float64Type, ArrayRef, FixedSizeListArray, RecordBatch, StringArray}; use lancedb::arrow::arrow_schema::{DataType, Field, Fields, Schema}; -use rig::embeddings::{Embedding, Embeddable, SingleEmbedding}; -use rig_derive::Embed; +use rig::embeddings::{ + embeddable::{EmbeddingGenerationError, SingleEmbedding}, + embedding::Embedding, +}; +use rig::Embeddable; use serde::Deserialize; -#[derive(Embed, Clone, Deserialize, Debug)] +#[derive(Embeddable, Clone, Deserialize, Debug)] pub struct FakeDefinition { id: String, #[embed] diff --git a/rig-lancedb/examples/vector_search_local_ann.rs b/rig-lancedb/examples/vector_search_local_ann.rs index 7a9c8cc1..e29a2b77 100644 --- a/rig-lancedb/examples/vector_search_local_ann.rs +++ b/rig-lancedb/examples/vector_search_local_ann.rs @@ -5,7 +5,7 @@ use fixture::{as_record_batch, fake_definition, fake_definitions, schema, FakeDe use lancedb::index::vector::IvfPqIndexBuilder; use rig::vector_store::VectorStoreIndex; use rig::{ - embeddings::{EmbeddingModel, EmbeddingsBuilder}, + embeddings::{embeddable::EmbeddingsBuilder, embedding::EmbeddingModel}, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, }; use rig_lancedb::{LanceDbVectorStore, SearchParams}; @@ -33,7 +33,7 @@ async fn main() -> Result<(), 
anyhow::Error> { (0..256) .map(|i| fake_definition(format!("doc{}", i))) .collect(), - ) + )? .build() .await?; diff --git a/rig-lancedb/examples/vector_search_local_enn.rs b/rig-lancedb/examples/vector_search_local_enn.rs index 1bf69481..7ae8c757 100644 --- a/rig-lancedb/examples/vector_search_local_enn.rs +++ b/rig-lancedb/examples/vector_search_local_enn.rs @@ -3,7 +3,7 @@ use std::{env, sync::Arc}; use arrow_array::RecordBatchIterator; use fixture::{as_record_batch, fake_definitions, schema}; use rig::{ - embeddings::{EmbeddingModel, EmbeddingsBuilder}, + embeddings::{embeddable::EmbeddingsBuilder, embedding::EmbeddingModel}, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, vector_store::VectorStoreIndexDyn, }; @@ -23,7 +23,7 @@ async fn main() -> Result<(), anyhow::Error> { // Generate embeddings for the test data. let embeddings = EmbeddingsBuilder::new(model.clone()) - .documents(fake_definitions()) + .documents(fake_definitions())? .build() .await?; diff --git a/rig-lancedb/examples/vector_search_s3_ann.rs b/rig-lancedb/examples/vector_search_s3_ann.rs index 17c4cd7f..0381429a 100644 --- a/rig-lancedb/examples/vector_search_s3_ann.rs +++ b/rig-lancedb/examples/vector_search_s3_ann.rs @@ -4,7 +4,7 @@ use arrow_array::RecordBatchIterator; use fixture::{as_record_batch, fake_definition, fake_definitions, schema, FakeDefinition}; use lancedb::{index::vector::IvfPqIndexBuilder, DistanceType}; use rig::{ - embeddings::{EmbeddingModel, EmbeddingsBuilder}, + embeddings::{embeddable::EmbeddingsBuilder, embedding::EmbeddingModel}, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, vector_store::VectorStoreIndex, }; @@ -33,13 +33,13 @@ async fn main() -> Result<(), anyhow::Error> { // Generate embeddings for the test data. let embeddings = EmbeddingsBuilder::new(model.clone()) - .documents(fake_definitions()) + .documents(fake_definitions())? // Note: need at least 256 rows in order to create an index so copy the definition 256 times for testing purposes. .documents( (0..256) .map(|i| fake_definition(format!("doc{}", i))) .collect(), - ) + )? .build() .await?; diff --git a/rig-lancedb/src/lib.rs b/rig-lancedb/src/lib.rs index 1e8b344a..edcc51e5 100644 --- a/rig-lancedb/src/lib.rs +++ b/rig-lancedb/src/lib.rs @@ -3,7 +3,7 @@ use lancedb::{ DistanceType, }; use rig::{ - embeddings::EmbeddingModel, + embeddings::embedding::EmbeddingModel, vector_store::{VectorStoreError, VectorStoreIndex}, }; use serde::Deserialize; diff --git a/rig-mongodb/examples/vector_search_mongodb.rs b/rig-mongodb/examples/vector_search_mongodb.rs index f740a29d..dab94eb9 100644 --- a/rig-mongodb/examples/vector_search_mongodb.rs +++ b/rig-mongodb/examples/vector_search_mongodb.rs @@ -1,17 +1,19 @@ use mongodb::{bson::doc, options::ClientOptions, Client as MongoClient, Collection}; use rig::providers::openai::TEXT_EMBEDDING_ADA_002; -use rig_derive::Embed; use serde::{Deserialize, Serialize}; use std::env; +use rig::Embeddable; use rig::{ - embeddings::{EmbeddingsBuilder, Embeddable, SingleEmbedding}, providers::openai::Client, vector_store::VectorStoreIndex, + embeddings::embeddable::{EmbeddingGenerationError, EmbeddingsBuilder, SingleEmbedding}, + providers::openai::Client, + vector_store::VectorStoreIndex, }; use rig_mongodb::{MongoDbVectorStore, SearchParams}; // Shape of data that needs to be RAG'ed. // The definition field will be used to generate embeddings. 
-#[derive(Embed, Clone, Deserialize, Debug)] +#[derive(Embeddable, Clone, Deserialize, Debug)] struct FakeDefinition { id: String, #[embed] @@ -67,7 +69,7 @@ async fn main() -> Result<(), anyhow::Error> { ]; let embeddings = EmbeddingsBuilder::new(model.clone()) - .documents(fake_definitions) + .documents(fake_definitions)? .build() .await?; diff --git a/rig-mongodb/src/lib.rs b/rig-mongodb/src/lib.rs index 5ce33105..4778e454 100644 --- a/rig-mongodb/src/lib.rs +++ b/rig-mongodb/src/lib.rs @@ -2,7 +2,7 @@ use futures::StreamExt; use mongodb::bson::{self, doc}; use rig::{ - embeddings::{Embedding, EmbeddingModel}, + embeddings::embedding::{Embedding, EmbeddingModel}, vector_store::{VectorStoreError, VectorStoreIndex}, }; use serde::Deserialize; From 220d9fc392542e4a698e69d27d79c36dcff94d14 Mon Sep 17 00:00:00 2001 From: Garance <garance.mary@gmail.com> Date: Fri, 11 Oct 2024 17:01:31 -0400 Subject: [PATCH 23/91] refactor: split up macro into more files, fix all imports --- rig-core/examples/rag.rs | 2 +- rig-core/examples/vector_search.rs | 2 +- rig-core/examples/vector_search_cohere.rs | 2 +- rig-core/rig-core-derive/src/basic.rs | 29 ++ rig-core/rig-core-derive/src/custom.rs | 31 +- rig-core/rig-core-derive/src/embeddable.rs | 258 ++++++-------- rig-core/rig-core-derive/src/lib.rs | 1 + rig-core/src/embeddings/builder.rs | 206 +++++++++++ rig-core/src/embeddings/embeddable.rs | 335 +++++++----------- rig-core/src/embeddings/mod.rs | 1 + rig-core/src/providers/cohere.rs | 4 +- rig-core/src/providers/openai.rs | 3 +- rig-lancedb/examples/fixtures/lib.rs | 5 +- .../examples/vector_search_local_ann.rs | 2 +- .../examples/vector_search_local_enn.rs | 2 +- rig-lancedb/examples/vector_search_s3_ann.rs | 2 +- rig-mongodb/examples/vector_search_mongodb.rs | 28 +- 17 files changed, 528 insertions(+), 385 deletions(-) create mode 100644 rig-core/rig-core-derive/src/basic.rs create mode 100644 rig-core/src/embeddings/builder.rs diff --git a/rig-core/examples/rag.rs b/rig-core/examples/rag.rs index 9ff579b0..43270a7b 100644 --- a/rig-core/examples/rag.rs +++ b/rig-core/examples/rag.rs @@ -2,7 +2,7 @@ use std::{env, vec}; use rig::{ completion::Prompt, - embeddings::embeddable::{EmbeddingGenerationError, EmbeddingsBuilder, ManyEmbedding}, + embeddings::builder::EmbeddingsBuilder, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, vector_store::in_memory_store::InMemoryVectorStore, Embeddable, diff --git a/rig-core/examples/vector_search.rs b/rig-core/examples/vector_search.rs index fe79a3f1..a97ef8b0 100644 --- a/rig-core/examples/vector_search.rs +++ b/rig-core/examples/vector_search.rs @@ -1,7 +1,7 @@ use std::env; use rig::{ - embeddings::embeddable::{EmbeddingGenerationError, EmbeddingsBuilder, ManyEmbedding}, + embeddings::builder::EmbeddingsBuilder, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, vector_store::{in_memory_store::InMemoryVectorStore, VectorStoreIndex}, Embeddable, diff --git a/rig-core/examples/vector_search_cohere.rs b/rig-core/examples/vector_search_cohere.rs index 9432df39..16ddb775 100644 --- a/rig-core/examples/vector_search_cohere.rs +++ b/rig-core/examples/vector_search_cohere.rs @@ -1,7 +1,7 @@ use std::env; use rig::{ - embeddings::embeddable::{EmbeddingGenerationError, EmbeddingsBuilder, ManyEmbedding}, + embeddings::builder::EmbeddingsBuilder, providers::cohere::{Client, EMBED_ENGLISH_V3}, vector_store::{in_memory_store::InMemoryVectorStore, VectorStoreIndex}, Embeddable, diff --git a/rig-core/rig-core-derive/src/basic.rs b/rig-core/rig-core-derive/src/basic.rs 
new file mode 100644 index 00000000..7ac5bb47 --- /dev/null +++ b/rig-core/rig-core-derive/src/basic.rs @@ -0,0 +1,29 @@ +use syn::{parse_quote, Attribute, DataStruct, Meta}; + +use crate::EMBED; + +/// Finds and returns fields with simple #[embed] attribute tags only. +pub(crate) fn basic_embed_fields(data_struct: &DataStruct) -> impl Iterator<Item = syn::Field> { + data_struct.fields.clone().into_iter().filter(|field| { + field + .attrs + .clone() + .into_iter() + .any(|attribute| match attribute { + Attribute { + meta: Meta::Path(path), + .. + } => path.is_ident(EMBED), + _ => false, + }) + }) +} + +// Adds bounds to where clause that force all fields tagged with #[embed] to implement the Embeddable trait. +pub(crate) fn add_struct_bounds(generics: &mut syn::Generics, field_type: &syn::Type) { + let where_clause = generics.make_where_clause(); + + where_clause.predicates.push(parse_quote! { + #field_type: Embeddable + }); +} diff --git a/rig-core/rig-core-derive/src/custom.rs b/rig-core/rig-core-derive/src/custom.rs index 77f321f6..8926aebf 100644 --- a/rig-core/rig-core-derive/src/custom.rs +++ b/rig-core/rig-core-derive/src/custom.rs @@ -5,7 +5,36 @@ use crate::EMBED; const EMBED_WITH: &str = "embed_with"; -pub(crate) trait CustomAttributeParser { +/// Finds and returns fields with #[embed(embed_with = "...")] attribute tags only. +/// Also returns the attribute in question. +pub(crate) fn custom_embed_fields( + data_struct: &syn::DataStruct, +) -> syn::Result<impl Iterator<Item = (syn::Field, syn::ExprPath)>> { + Ok(data_struct + .fields + .clone() + .into_iter() + .map(|field| { + field + .attrs + .clone() + .into_iter() + .map(|attribute| { + if attribute.is_custom()? { + Ok::<_, syn::Error>(Some((field.clone(), attribute.expand_tag()?))) + } else { + Ok(None) + } + }) + .collect::<Result<Vec<_>, _>>() + }) + .collect::<Result<Vec<_>, _>>()? + .into_iter() + .flatten() + .flatten()) +} + +trait CustomAttributeParser { // Determine if field is tagged with an #[embed(embed_with = "...")] attribute. fn is_custom(&self) -> syn::Result<bool>; diff --git a/rig-core/rig-core-derive/src/embeddable.rs b/rig-core/rig-core-derive/src/embeddable.rs index f5c6e781..914945ec 100644 --- a/rig-core/rig-core-derive/src/embeddable.rs +++ b/rig-core/rig-core-derive/src/embeddable.rs @@ -1,8 +1,11 @@ use proc_macro2::TokenStream; use quote::quote; -use syn::{parse_quote, parse_str, Attribute, DataStruct, Meta}; +use syn::{parse_str, DataStruct}; -use crate::{custom::CustomAttributeParser, EMBED}; +use crate::{ + basic::{add_struct_bounds, basic_embed_fields}, + custom::custom_embed_fields, +}; const VEC_TYPE: &str = "Vec"; const MANY_EMBEDDING: &str = "ManyEmbedding"; const SINGLE_EMBEDDING: &str = "SingleEmbedding"; @@ -10,41 +13,74 @@ const SINGLE_EMBEDDING: &str = "SingleEmbedding"; pub fn expand_derive_embedding(input: &mut syn::DeriveInput) -> syn::Result<TokenStream> { let name = &input.ident; - let (embed_targets, custom_embed_targets, embed_kind) = match &input.data { - syn::Data::Struct(data_struct) => { - // Handles fields tagged with #[embed] - let embed_targets = data_struct - .basic_embed_fields() - .map(|field| { - add_struct_bounds(&mut input.generics, &field.ty); - - let field_name = field.ident; - - quote! { - self.#field_name - } - }) - .collect::<Vec<_>>(); - - // Handles fields tagged with #[embed(embed_with = "...")] - let custom_embed_targets = data_struct - .custom_embed_fields()? - .map(|(field, _)| { - let field_name = field.ident; - - quote! 
{ - self.#field_name - } - }) - .collect::<Vec<_>>(); - - ( - embed_targets, - custom_embed_targets, - data_struct.embed_kind()?, - ) + // Handles fields tagged with #[embed] + let embed_targets = match &input.data { + syn::Data::Struct(data_struct) => basic_embed_fields(data_struct) + .map(|field| { + add_struct_bounds(&mut input.generics, &field.ty); + + let field_name = field.ident; + + quote! { + self.#field_name + } + }) + .collect::<Vec<_>>(), + _ => { + return Err(syn::Error::new_spanned( + name, + "Embeddable derive macro should only be used on structs", + )) + } + }; + + let embed_targets_quote = if !embed_targets.is_empty() { + quote! { + vec![#(#embed_targets.embeddable()),*] + .into_iter() + .collect::<Result<Vec<_>, _>>()? + .into_iter() + .flatten() + .collect::<Vec<_>>() + } + } else { + quote! { + vec![] + } + }; + + // Handles fields tagged with #[embed(embed_with = "...")] + let custom_embed_targets = match &input.data { + syn::Data::Struct(data_struct) => custom_embed_fields(data_struct)? + .map(|(field, custom_func_path)| { + let field_name = field.ident; + + quote! { + #custom_func_path(self.#field_name.clone()) + } + }) + .collect::<Vec<_>>(), + _ => { + return Err(syn::Error::new_spanned( + name, + "Embeddable derive macro should only be used on structs", + )) + } + }; + + let custom_embed_targets_quote = if !custom_embed_targets.is_empty() { + quote! { + vec![#(#custom_embed_targets),*] + .into_iter() + .collect::<Result<Vec<_>, _>>()? + .into_iter() + .flatten() + .collect::<Vec<_>>() + } + } else { + quote! { + vec![] } - _ => panic!("Embeddable trait can only be derived for structs"), }; // If there are no fields tagged with #[embed] or #[embed(embed_with = "...")], return an empty TokenStream. @@ -53,41 +89,34 @@ pub fn expand_derive_embedding(input: &mut syn::DeriveInput) -> syn::Result<Toke return Ok(TokenStream::new()); } - // Import the paths to the custom functions. - let custom_func_paths = match &input.data { - syn::Data::Struct(data_struct) => data_struct - .custom_embed_fields()? - .map(|(_, custom_func_path)| { - quote! { - use #custom_func_path::embeddable; - } - }) - .collect::<Vec<_>>(), - _ => vec![], + // Determine whether the Embeddable::Kind should be SinleEmbedding or ManyEmbedding. + let embed_kind = match &input.data { + syn::Data::Struct(data_struct) => embed_kind(data_struct)?, + _ => { + return Err(syn::Error::new_spanned( + name, + "Embeddable derive macro should only be used on structs", + )) + } }; let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl(); let gen = quote! { - // Note: we do NOT import the Embeddable trait here because if there are multiple structs in the same file - // that derive Embed, there will be import conflicts. - - #(#custom_func_paths);* + // Note: Embeddable trait is imported with the macro. impl #impl_generics Embeddable for #name #ty_generics #where_clause { - type Kind = #embed_kind; - type Error = EmbeddingGenerationError; + type Kind = rig::embeddings::embeddable::#embed_kind; + type Error = rig::embeddings::embeddable::EmbeddableError; fn embeddable(&self) -> Result<Vec<String>, Self::Error> { - vec![#(#embed_targets.clone()),*].embeddable() + let mut embed_targets = #embed_targets_quote; + + let custom_embed_targets = #custom_embed_targets_quote; - // let custom_embed_targets = vec![#( embeddable( #embed_targets ); ),*] - // .iter() - // .collect::<Result<Vec<_>, _>>()? 
- // .into_iter() - // .flatten(); + embed_targets.extend(custom_embed_targets); - // Ok(embed_targets.chain(custom_embed_targets).collect()) + Ok(embed_targets) } } }; @@ -96,95 +125,30 @@ pub fn expand_derive_embedding(input: &mut syn::DeriveInput) -> syn::Result<Toke Ok(gen) } -// Adds bounds to where clause that force all fields tagged with #[embed] to implement the Embeddable trait. -fn add_struct_bounds(generics: &mut syn::Generics, field_type: &syn::Type) { - let where_clause = generics.make_where_clause(); - - where_clause.predicates.push(parse_quote! { - #field_type: Embeddable - }); -} - -fn embed_kind(field: &syn::Field) -> syn::Result<syn::Expr> { - match &field.ty { - syn::Type::Path(path) => { - if path.path.segments.first().unwrap().ident == VEC_TYPE { - parse_str(MANY_EMBEDDING) - } else { - parse_str(SINGLE_EMBEDDING) +/// If the total number of fields tagged with #[embed] or #[embed(embed_with = "...")] is 1, +/// returns the kind of embedding that field should be. +/// If the total number of fields tagged with #[embed] or #[embed(embed_with = "...")] is greater than 1, +/// return ManyEmbedding. +fn embed_kind(data_struct: &DataStruct) -> syn::Result<syn::Expr> { + fn embed_kind(field: &syn::Field) -> syn::Result<syn::Expr> { + match &field.ty { + syn::Type::Path(path) => { + if path.path.segments.first().unwrap().ident == VEC_TYPE { + parse_str(MANY_EMBEDDING) + } else { + parse_str(SINGLE_EMBEDDING) + } } + _ => parse_str(SINGLE_EMBEDDING), } - _ => parse_str(SINGLE_EMBEDDING), - } -} - -trait StructParser { - /// Finds and returns fields with simple #[embed] attribute tags only. - fn basic_embed_fields(&self) -> impl Iterator<Item = syn::Field>; - /// Finds and returns fields with #[embed(embed_with = "...")] attribute tags only. - /// Also returns the attribute in question. - fn custom_embed_fields(&self) - -> syn::Result<impl Iterator<Item = (syn::Field, syn::ExprPath)>>; - - /// If the total number of fields tagged with #[embed] or #[embed(embed_with = "...")] is 1, - /// returns the kind of embedding that field should be. - /// If the total number of fields tagged with #[embed] or #[embed(embed_with = "...")] is greater than 1, - /// return ManyEmbedding. - fn embed_kind(&self) -> syn::Result<syn::Expr> { - let fields = self - .basic_embed_fields() - .chain(self.custom_embed_fields()?.map(|(f, _)| f)) - .collect::<Vec<_>>(); - - if fields.len() == 1 { - fields.iter().map(embed_kind).next().unwrap() - } else { - parse_str("ManyEmbedding") - } - } -} - -impl StructParser for DataStruct { - fn basic_embed_fields(&self) -> impl Iterator<Item = syn::Field> { - self.fields.clone().into_iter().filter(|field| { - field - .attrs - .clone() - .into_iter() - .any(|attribute| match attribute { - Attribute { - meta: Meta::Path(path), - .. - } => path.is_ident(EMBED), - _ => false, - }) - }) } - - fn custom_embed_fields( - &self, - ) -> syn::Result<impl Iterator<Item = (syn::Field, syn::ExprPath)>> { - Ok(self - .fields - .clone() - .into_iter() - .map(|field| { - field - .attrs - .clone() - .into_iter() - .map(|attribute| { - if attribute.is_custom()? { - Ok::<_, syn::Error>(Some((field.clone(), attribute.expand_tag()?))) - } else { - Ok(None) - } - }) - .collect::<Result<Vec<_>, _>>() - }) - .collect::<Result<Vec<_>, _>>()? 
- .into_iter() - .flatten() - .flatten()) + let fields = basic_embed_fields(data_struct) + .chain(custom_embed_fields(data_struct)?.map(|(f, _)| f)) + .collect::<Vec<_>>(); + + if fields.len() == 1 { + fields.iter().map(embed_kind).next().unwrap() + } else { + parse_str(MANY_EMBEDDING) } } diff --git a/rig-core/rig-core-derive/src/lib.rs b/rig-core/rig-core-derive/src/lib.rs index fba92f1c..d28a0d78 100644 --- a/rig-core/rig-core-derive/src/lib.rs +++ b/rig-core/rig-core-derive/src/lib.rs @@ -2,6 +2,7 @@ extern crate proc_macro; use proc_macro::TokenStream; use syn::{parse_macro_input, DeriveInput}; +mod basic; mod custom; mod embeddable; diff --git a/rig-core/src/embeddings/builder.rs b/rig-core/src/embeddings/builder.rs new file mode 100644 index 00000000..01a99fd4 --- /dev/null +++ b/rig-core/src/embeddings/builder.rs @@ -0,0 +1,206 @@ +//! The module defines the [EmbeddingsBuilder] struct which accumulates objects to be embedded and generates the embeddings for each object when built. +//! Only types that implement the [Embeddable] trait can be added to the [EmbeddingsBuilder]. +//! +//! # Example +//! ```rust +//! use std::env; +//! +//! use rig::{ +//! embeddings::EmbeddingsBuilder, +//! providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, +//! }; +//! use rig_derive::Embed; +//! +//! #[derive(Embed)] +//! struct FakeDefinition { +//! id: String, +//! word: String, +//! #[embed] +//! definitions: Vec<String>, +//! } +//! // Create OpenAI client +//! let openai_api_key = env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set"); +//! let openai_client = Client::new(&openai_api_key); +//! +//! let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); +//! +//! let embeddings = EmbeddingsBuilder::new(model.clone()) +//! .documents(vec![ +//! FakeDefinition { +//! id: "doc0".to_string(), +//! word: "flurbo".to_string(), +//! definitions: vec![ +//! "A green alien that lives on cold planets.".to_string(), +//! "A fictional digital currency that originated in the animated series Rick and Morty.".to_string() +//! ] +//! }, +//! FakeDefinition { +//! id: "doc1".to_string(), +//! word: "glarb-glarb".to_string(), +//! definitions: vec![ +//! "An ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string(), +//! "A fictional creature found in the distant, swampy marshlands of the planet Glibbo in the Andromeda galaxy.".to_string() +//! ] +//! }, +//! FakeDefinition { +//! id: "doc2".to_string(), +//! word: "linglingdong".to_string(), +//! definitions: vec![ +//! "A term used by inhabitants of the sombrero galaxy to describe humans.".to_string(), +//! "A rare, mystical instrument crafted by the ancient monks of the Nebulon Mountain Ranges on the planet Quarm.".to_string() +//! ] +//! }, +//! ]) +//! .build() +//! .await?; +//! +//! // Use the generated embeddings +//! // ... +//! ``` + +use std::{cmp::max, collections::HashMap, marker::PhantomData}; + +use futures::{stream, StreamExt, TryStreamExt}; + +use crate::Embeddable; + +use super::{ + embeddable::{EmbeddingKind, ManyEmbedding, SingleEmbedding}, + embedding::{Embedding, EmbeddingError, EmbeddingModel}, +}; + +/// Builder for creating a collection of embeddings. 
+pub struct EmbeddingsBuilder<M: EmbeddingModel, D: Embeddable, K: EmbeddingKind> { + kind: PhantomData<K>, + model: M, + documents: Vec<(D, Vec<String>)>, +} + +impl<M: EmbeddingModel, D: Embeddable<Kind = K>, K: EmbeddingKind> EmbeddingsBuilder<M, D, K> { + /// Create a new embedding builder with the given embedding model + pub fn new(model: M) -> Self { + Self { + kind: PhantomData, + model, + documents: vec![], + } + } + + /// Add a document that implements `Embeddable` to the builder. + pub fn document(mut self, document: D) -> Result<Self, D::Error> { + let embed_targets = document.embeddable()?; + + self.documents.push((document, embed_targets)); + Ok(self) + } + + /// Add many documents that implement `Embeddable` to the builder. + pub fn documents(mut self, documents: Vec<D>) -> Result<Self, D::Error> { + for doc in documents.into_iter() { + let embed_targets = doc.embeddable()?; + + self.documents.push((doc, embed_targets)); + } + + Ok(self) + } +} + +impl<M: EmbeddingModel, D: Embeddable + Send + Sync + Clone> + EmbeddingsBuilder<M, D, ManyEmbedding> +{ + /// Generate embeddings for all documents in the builder. + /// The method only applies when documents in the builder each contain multiple embedding targets. + /// Returns a vector of tuples, where the first element is the document and the second element is the vector of embeddings. + pub async fn build(&self) -> Result<Vec<(D, Vec<Embedding>)>, EmbeddingError> { + // Use this for reference later to merge a document back with its embeddings. + let documents_map = self + .documents + .clone() + .into_iter() + .enumerate() + .map(|(id, (document, _))| (id, document)) + .collect::<HashMap<_, _>>(); + + let embeddings = stream::iter(self.documents.iter().enumerate()) + // Merge the embedding targets of each document into a single list of embedding targets. + .flat_map(|(i, (_, embed_targets))| { + stream::iter(embed_targets.iter().map(move |target| (i, target.clone()))) + }) + // Chunk them into N (the emebdding API limit per request). + .chunks(M::MAX_DOCUMENTS) + // Generate the embeddings for a chunk at a time. + .map(|docs| async { + let (document_indices, embed_targets): (Vec<_>, Vec<_>) = docs.into_iter().unzip(); + + Ok::<_, EmbeddingError>( + document_indices + .into_iter() + .zip(self.model.embed_documents(embed_targets).await?.into_iter()) + .collect::<Vec<_>>(), + ) + }) + .boxed() + // Parallelize the embeddings generation over 10 concurrent requests + .buffer_unordered(max(1, 1024 / M::MAX_DOCUMENTS)) + .try_fold( + HashMap::new(), + |mut acc: HashMap<_, Vec<_>>, embeddings| async move { + embeddings.into_iter().for_each(|(i, embedding)| { + acc.entry(i).or_default().push(embedding); + }); + + Ok(acc) + }, + ) + .await? + .iter() + .fold(vec![], |mut acc, (i, embeddings_vec)| { + acc.push(( + documents_map.get(i).cloned().unwrap(), + embeddings_vec.clone(), + )); + acc + }); + + Ok(embeddings) + } +} + +impl<M: EmbeddingModel, D: Embeddable + Send + Sync + Clone> + EmbeddingsBuilder<M, D, SingleEmbedding> +{ + /// Generate embeddings for all documents in the builder. + /// The method only applies when documents in the builder each contain a single embedding target. + /// Returns a vector of tuples, where the first element is the document and the second element is the embedding. 
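The `buffer_unordered(max(1, 1024 / M::MAX_DOCUMENTS))` calls in these `build` methods derive the request concurrency from the provider's per-request document limit. A quick sanity check of that arithmetic, using the `MAX_DOCUMENTS` values the providers in this series declare (96 for Cohere, 1024 for OpenAI); the helper function is only for illustration:

```rust
// Mirror of the concurrency cap used by EmbeddingsBuilder::build.
fn concurrency(max_documents: usize) -> usize {
    std::cmp::max(1, 1024 / max_documents)
}

fn main() {
    // Cohere: 1024 / 96 = 10 (integer division), matching the "10 concurrent requests" comment.
    assert_eq!(concurrency(96), 10);
    // OpenAI: one request already covers 1024 documents, so only one request is in flight.
    assert_eq!(concurrency(1024), 1);
}
```

So the "10 concurrent requests" comment holds for Cohere-sized batches, but the pipeline effectively degrades to sequential requests for models with very large per-request limits.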
+ pub async fn build(&self) -> Result<Vec<(D, Embedding)>, EmbeddingError> { + let embeddings = stream::iter( + self.documents + .clone() + .into_iter() + .map(|(document, embed_target)| (document, embed_target.first().cloned().unwrap())), + ) + // Chunk them into N (the emebdding API limit per request) + .chunks(M::MAX_DOCUMENTS) + // Generate the embeddings + .map(|docs| async { + let (documents, embed_targets): (Vec<_>, Vec<_>) = docs.into_iter().unzip(); + Ok::<_, EmbeddingError>( + documents + .into_iter() + .zip(self.model.embed_documents(embed_targets).await?.into_iter()) + .collect::<Vec<_>>(), + ) + }) + .boxed() + // Parallelize the embeddings generation over 10 concurrent requests + .buffer_unordered(max(1, 1024 / M::MAX_DOCUMENTS)) + .try_fold(vec![], |mut acc, embeddings| async move { + acc.extend(embeddings); + Ok(acc) + }) + .await?; + + Ok(embeddings) + } +} diff --git a/rig-core/src/embeddings/embeddable.rs b/rig-core/src/embeddings/embeddable.rs index 08e8b316..80e5b689 100644 --- a/rig-core/src/embeddings/embeddable.rs +++ b/rig-core/src/embeddings/embeddable.rs @@ -1,67 +1,4 @@ //! The module defines the [Embeddable] trait, which must be implemented for types that can be embedded. -//! The module defines the [EmbeddingsBuilder] struct which accumulates objects to be embedded and generates the embeddings for each object when built. -//! Only types that implement the [Embeddable] trait can be added to the [EmbeddingsBuilder]. -//! -//! # Example -//! ```rust -//! use std::env; -//! -//! use rig::{ -//! embeddings::EmbeddingsBuilder, -//! providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, -//! }; -//! use rig_derive::Embed; -//! -//! #[derive(Embed)] -//! struct FakeDefinition { -//! id: String, -//! word: String, -//! #[embed] -//! definitions: Vec<String>, -//! } -//! // Create OpenAI client -//! let openai_api_key = env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set"); -//! let openai_client = Client::new(&openai_api_key); -//! -//! let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); -//! -//! let embeddings = EmbeddingsBuilder::new(model.clone()) -//! .documents(vec![ -//! FakeDefinition { -//! id: "doc0".to_string(), -//! word: "flurbo".to_string(), -//! definitions: vec![ -//! "A green alien that lives on cold planets.".to_string(), -//! "A fictional digital currency that originated in the animated series Rick and Morty.".to_string() -//! ] -//! }, -//! FakeDefinition { -//! id: "doc1".to_string(), -//! word: "glarb-glarb".to_string(), -//! definitions: vec![ -//! "An ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string(), -//! "A fictional creature found in the distant, swampy marshlands of the planet Glibbo in the Andromeda galaxy.".to_string() -//! ] -//! }, -//! FakeDefinition { -//! id: "doc2".to_string(), -//! word: "linglingdong".to_string(), -//! definitions: vec![ -//! "A term used by inhabitants of the sombrero galaxy to describe humans.".to_string(), -//! "A rare, mystical instrument crafted by the ancient monks of the Nebulon Mountain Ranges on the planet Quarm.".to_string() -//! ] -//! }, -//! ]) -//! .build() -//! .await?; -//! -//! // Use the generated embeddings -//! // ... -//! ``` - -use super::embedding::{Embedding, EmbeddingError, EmbeddingModel}; -use futures::{stream, StreamExt, TryStreamExt}; -use std::{cmp::max, collections::HashMap, marker::PhantomData}; /// The associated type `Kind` on the trait `Embeddable` must implement this trait. 
pub trait EmbeddingKind {} @@ -75,7 +12,7 @@ pub struct ManyEmbedding; impl EmbeddingKind for ManyEmbedding {} #[derive(Debug, thiserror::Error)] -pub enum EmbeddingGenerationError { +pub enum EmbeddableError { #[error("SerdeError: {0}")] SerdeError(#[from] serde_json::Error), } @@ -90,148 +27,12 @@ pub trait Embeddable { fn embeddable(&self) -> Result<Vec<String>, Self::Error>; } -/// Builder for creating a collection of embeddings. -pub struct EmbeddingsBuilder<M: EmbeddingModel, D: Embeddable, K: EmbeddingKind> { - kind: PhantomData<K>, - model: M, - documents: Vec<(D, Vec<String>)>, -} - -impl<M: EmbeddingModel, D: Embeddable<Kind = K>, K: EmbeddingKind> EmbeddingsBuilder<M, D, K> { - /// Create a new embedding builder with the given embedding model - pub fn new(model: M) -> Self { - Self { - kind: PhantomData, - model, - documents: vec![], - } - } - - /// Add a document that implements `Embeddable` to the builder. - pub fn document(mut self, document: D) -> Result<Self, D::Error> { - let embed_targets = document.embeddable()?; - - self.documents.push((document, embed_targets)); - Ok(self) - } - - /// Add many documents that implement `Embeddable` to the builder. - pub fn documents(mut self, documents: Vec<D>) -> Result<Self, D::Error> { - for doc in documents.into_iter() { - let embed_targets = doc.embeddable()?; - - self.documents.push((doc, embed_targets)); - } - - Ok(self) - } -} - -impl<M: EmbeddingModel, D: Embeddable + Send + Sync + Clone> - EmbeddingsBuilder<M, D, ManyEmbedding> -{ - /// Generate embeddings for all documents in the builder. - /// The method only applies when documents in the builder each contain multiple embedding targets. - /// Returns a vector of tuples, where the first element is the document and the second element is the vector of embeddings. - pub async fn build(&self) -> Result<Vec<(D, Vec<Embedding>)>, EmbeddingError> { - // Use this for reference later to merge a document back with its embeddings. - let documents_map = self - .documents - .clone() - .into_iter() - .enumerate() - .map(|(id, (document, _))| (id, document)) - .collect::<HashMap<_, _>>(); - - let embeddings = stream::iter(self.documents.iter().enumerate()) - // Merge the embedding targets of each document into a single list of embedding targets. - .flat_map(|(i, (_, embed_targets))| { - stream::iter(embed_targets.iter().map(move |target| (i, target.clone()))) - }) - // Chunk them into N (the emebdding API limit per request). - .chunks(M::MAX_DOCUMENTS) - // Generate the embeddings for a chunk at a time. - .map(|docs| async { - let (document_indices, embed_targets): (Vec<_>, Vec<_>) = docs.into_iter().unzip(); - - Ok::<_, EmbeddingError>( - document_indices - .into_iter() - .zip(self.model.embed_documents(embed_targets).await?.into_iter()) - .collect::<Vec<_>>(), - ) - }) - .boxed() - // Parallelize the embeddings generation over 10 concurrent requests - .buffer_unordered(max(1, 1024 / M::MAX_DOCUMENTS)) - .try_fold( - HashMap::new(), - |mut acc: HashMap<_, Vec<_>>, embeddings| async move { - embeddings.into_iter().for_each(|(i, embedding)| { - acc.entry(i).or_default().push(embedding); - }); - - Ok(acc) - }, - ) - .await? - .iter() - .fold(vec![], |mut acc, (i, embeddings_vec)| { - acc.push(( - documents_map.get(i).cloned().unwrap(), - embeddings_vec.clone(), - )); - acc - }); - - Ok(embeddings) - } -} - -impl<M: EmbeddingModel, D: Embeddable + Send + Sync + Clone> - EmbeddingsBuilder<M, D, SingleEmbedding> -{ - /// Generate embeddings for all documents in the builder. 
- /// The method only applies when documents in the builder each contain a single embedding target. - /// Returns a vector of tuples, where the first element is the document and the second element is the embedding. - pub async fn build(&self) -> Result<Vec<(D, Embedding)>, EmbeddingError> { - let embeddings = stream::iter( - self.documents - .clone() - .into_iter() - .map(|(document, embed_target)| (document, embed_target.first().cloned().unwrap())), - ) - // Chunk them into N (the emebdding API limit per request) - .chunks(M::MAX_DOCUMENTS) - // Generate the embeddings - .map(|docs| async { - let (documents, embed_targets): (Vec<_>, Vec<_>) = docs.into_iter().unzip(); - Ok::<_, EmbeddingError>( - documents - .into_iter() - .zip(self.model.embed_documents(embed_targets).await?.into_iter()) - .collect::<Vec<_>>(), - ) - }) - .boxed() - // Parallelize the embeddings generation over 10 concurrent requests - .buffer_unordered(max(1, 1024 / M::MAX_DOCUMENTS)) - .try_fold(vec![], |mut acc, embeddings| async move { - acc.extend(embeddings); - Ok(acc) - }) - .await?; - - Ok(embeddings) - } -} - ////////////////////////////////////////////////////// /// Implementations of Embeddable for common types /// ////////////////////////////////////////////////////// impl Embeddable for String { type Kind = SingleEmbedding; - type Error = EmbeddingGenerationError; + type Error = EmbeddableError; fn embeddable(&self) -> Result<Vec<String>, Self::Error> { Ok(vec![self.clone()]) @@ -240,7 +41,7 @@ impl Embeddable for String { impl Embeddable for i8 { type Kind = SingleEmbedding; - type Error = EmbeddingGenerationError; + type Error = EmbeddableError; fn embeddable(&self) -> Result<Vec<String>, Self::Error> { Ok(vec![self.to_string()]) @@ -249,7 +50,7 @@ impl Embeddable for i8 { impl Embeddable for i16 { type Kind = SingleEmbedding; - type Error = EmbeddingGenerationError; + type Error = EmbeddableError; fn embeddable(&self) -> Result<Vec<String>, Self::Error> { Ok(vec![self.to_string()]) @@ -258,7 +59,7 @@ impl Embeddable for i16 { impl Embeddable for i32 { type Kind = SingleEmbedding; - type Error = EmbeddingGenerationError; + type Error = EmbeddableError; fn embeddable(&self) -> Result<Vec<String>, Self::Error> { Ok(vec![self.to_string()]) @@ -267,7 +68,7 @@ impl Embeddable for i32 { impl Embeddable for i64 { type Kind = SingleEmbedding; - type Error = EmbeddingGenerationError; + type Error = EmbeddableError; fn embeddable(&self) -> Result<Vec<String>, Self::Error> { Ok(vec![self.to_string()]) @@ -276,7 +77,7 @@ impl Embeddable for i64 { impl Embeddable for i128 { type Kind = SingleEmbedding; - type Error = EmbeddingGenerationError; + type Error = EmbeddableError; fn embeddable(&self) -> Result<Vec<String>, Self::Error> { Ok(vec![self.to_string()]) @@ -285,7 +86,7 @@ impl Embeddable for i128 { impl Embeddable for f32 { type Kind = SingleEmbedding; - type Error = EmbeddingGenerationError; + type Error = EmbeddableError; fn embeddable(&self) -> Result<Vec<String>, Self::Error> { Ok(vec![self.to_string()]) @@ -294,7 +95,7 @@ impl Embeddable for f32 { impl Embeddable for f64 { type Kind = SingleEmbedding; - type Error = EmbeddingGenerationError; + type Error = EmbeddableError; fn embeddable(&self) -> Result<Vec<String>, Self::Error> { Ok(vec![self.to_string()]) @@ -303,7 +104,7 @@ impl Embeddable for f64 { impl Embeddable for bool { type Kind = SingleEmbedding; - type Error = EmbeddingGenerationError; + type Error = EmbeddableError; fn embeddable(&self) -> Result<Vec<String>, Self::Error> { 
Ok(vec![self.to_string()]) @@ -312,7 +113,7 @@ impl Embeddable for bool { impl Embeddable for char { type Kind = SingleEmbedding; - type Error = EmbeddingGenerationError; + type Error = EmbeddableError; fn embeddable(&self) -> Result<Vec<String>, Self::Error> { Ok(vec![self.to_string()]) @@ -336,19 +137,127 @@ impl<T: Embeddable> Embeddable for Vec<T> { #[cfg(test)] mod tests { - use super::{Embeddable, EmbeddingGenerationError, SingleEmbedding}; + use crate as rig; + use rig::embeddings::embeddable::{Embeddable, EmbeddableError}; use rig_derive::Embeddable; + use serde::Serialize; + + fn serialize(definition: Definition) -> Result<Vec<String>, EmbeddableError> { + Ok(vec![ + serde_json::to_string(&definition).map_err(EmbeddableError::SerdeError)? + ]) + } #[derive(Embeddable)] struct FakeDefinition { id: String, + word: String, + #[embed(embed_with = "serialize")] + definition: Definition, + } + + #[derive(Serialize, Clone)] + struct Definition { + word: String, + link: String, + speech: String, + } + + #[test] + fn test_custom_embed() { + let fake_definition = FakeDefinition { + id: "doc1".to_string(), + word: "house".to_string(), + definition: Definition { + speech: "noun".to_string(), + word: "a building in which people live; residence for human beings.".to_string(), + link: "https://www.dictionary.com/browse/house".to_string(), + }, + }; + + assert_eq!( + fake_definition.embeddable().unwrap(), + vec!["{\"word\":\"a building in which people live; residence for human beings.\",\"link\":\"https://www.dictionary.com/browse/house\",\"speech\":\"noun\"}".to_string()] + ) + } + + #[derive(Embeddable)] + struct FakeDefinition2 { + id: String, + word: String, #[embed] definition: String, } #[test] - fn test_missing_embed_fields() {} + fn test_simple_embed() { + let fake_definition = FakeDefinition2 { + id: "doc1".to_string(), + word: "house".to_string(), + definition: "a building in which people live; residence for human beings.".to_string(), + }; + + assert_eq!( + fake_definition.embeddable().unwrap(), + vec!["a building in which people live; residence for human beings.".to_string()] + ); + + assert!(false) + } + + #[derive(Embeddable)] + struct Company { + id: String, + company: String, + #[embed] + employee_ages: Vec<i32>, + } + + #[test] + fn test_multiple_embed() { + let company = Company { + id: "doc1".to_string(), + company: "Google".to_string(), + employee_ages: vec![25, 30, 35, 40], + }; + + assert_eq!( + company.embeddable().unwrap(), + vec![ + "25".to_string(), + "30".to_string(), + "35".to_string(), + "40".to_string() + ] + ); + } + + #[derive(Embeddable)] + struct Company2 { + id: String, + #[embed] + company: String, + #[embed] + employee_ages: Vec<i32>, + } #[test] - fn test_empty_custom_function() {} + fn test_many_embed() { + let company = Company2 { + id: "doc1".to_string(), + company: "Google".to_string(), + employee_ages: vec![25, 30, 35, 40], + }; + + assert_eq!( + company.embeddable().unwrap(), + vec![ + "Google".to_string(), + "25".to_string(), + "30".to_string(), + "35".to_string(), + "40".to_string() + ] + ); + } } diff --git a/rig-core/src/embeddings/mod.rs b/rig-core/src/embeddings/mod.rs index 33526769..37e720cb 100644 --- a/rig-core/src/embeddings/mod.rs +++ b/rig-core/src/embeddings/mod.rs @@ -3,5 +3,6 @@ //! natural language processing (NLP) tasks such as text classification, information retrieval, //! and document similarity. 
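The custom-function tests above pin down the contract that `#[embed(embed_with = "...")]` expects: a free function that takes the tagged field's value (the generated code clones it) and returns `Result<Vec<String>, EmbeddableError>`. A minimal sketch of such a helper, with the `Metadata` type invented for illustration:

```rust
use rig::embeddings::embeddable::EmbeddableError;
use serde::Serialize;

// Hypothetical field type; any Serialize type can be handled the same way.
#[derive(Serialize, Clone)]
struct Metadata {
    title: String,
    tags: Vec<String>,
}

// Same shape as the `serialize` helper in the tests above: the whole value
// is serialized into a single embedding target.
fn embed_metadata(metadata: Metadata) -> Result<Vec<String>, EmbeddableError> {
    Ok(vec![
        serde_json::to_string(&metadata).map_err(EmbeddableError::SerdeError)?,
    ])
}
```

Returning more than one string from such a helper is also allowed; the derive flattens every helper's output into the struct's combined list of embedding targets.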
+pub mod builder; pub mod embeddable; pub mod embedding; diff --git a/rig-core/src/providers/cohere.rs b/rig-core/src/providers/cohere.rs index 33464939..fa798205 100644 --- a/rig-core/src/providers/cohere.rs +++ b/rig-core/src/providers/cohere.rs @@ -14,9 +14,7 @@ use crate::{ agent::AgentBuilder, completion::{self, CompletionError}, embeddings::{ - self, - embeddable::{Embeddable, EmbeddingsBuilder}, - embedding::EmbeddingError, + self, builder::EmbeddingsBuilder, embeddable::Embeddable, embedding::EmbeddingError, }, extractor::ExtractorBuilder, json_utils, diff --git a/rig-core/src/providers/openai.rs b/rig-core/src/providers/openai.rs index d23a46bf..d21bf8fb 100644 --- a/rig-core/src/providers/openai.rs +++ b/rig-core/src/providers/openai.rs @@ -13,7 +13,8 @@ use crate::{ completion::{self, CompletionError, CompletionRequest}, embeddings::{ self, - embeddable::{Embeddable, EmbeddingsBuilder}, + builder::EmbeddingsBuilder, + embeddable::Embeddable, embedding::{Embedding, EmbeddingError}, }, extractor::ExtractorBuilder, diff --git a/rig-lancedb/examples/fixtures/lib.rs b/rig-lancedb/examples/fixtures/lib.rs index bf7f2a33..956ace1c 100644 --- a/rig-lancedb/examples/fixtures/lib.rs +++ b/rig-lancedb/examples/fixtures/lib.rs @@ -2,10 +2,7 @@ use std::sync::Arc; use arrow_array::{types::Float64Type, ArrayRef, FixedSizeListArray, RecordBatch, StringArray}; use lancedb::arrow::arrow_schema::{DataType, Field, Fields, Schema}; -use rig::embeddings::{ - embeddable::{EmbeddingGenerationError, SingleEmbedding}, - embedding::Embedding, -}; +use rig::embeddings::embedding::Embedding; use rig::Embeddable; use serde::Deserialize; diff --git a/rig-lancedb/examples/vector_search_local_ann.rs b/rig-lancedb/examples/vector_search_local_ann.rs index e29a2b77..1042e34c 100644 --- a/rig-lancedb/examples/vector_search_local_ann.rs +++ b/rig-lancedb/examples/vector_search_local_ann.rs @@ -5,7 +5,7 @@ use fixture::{as_record_batch, fake_definition, fake_definitions, schema, FakeDe use lancedb::index::vector::IvfPqIndexBuilder; use rig::vector_store::VectorStoreIndex; use rig::{ - embeddings::{embeddable::EmbeddingsBuilder, embedding::EmbeddingModel}, + embeddings::{builder::EmbeddingsBuilder, embedding::EmbeddingModel}, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, }; use rig_lancedb::{LanceDbVectorStore, SearchParams}; diff --git a/rig-lancedb/examples/vector_search_local_enn.rs b/rig-lancedb/examples/vector_search_local_enn.rs index 7ae8c757..630acc1a 100644 --- a/rig-lancedb/examples/vector_search_local_enn.rs +++ b/rig-lancedb/examples/vector_search_local_enn.rs @@ -3,7 +3,7 @@ use std::{env, sync::Arc}; use arrow_array::RecordBatchIterator; use fixture::{as_record_batch, fake_definitions, schema}; use rig::{ - embeddings::{embeddable::EmbeddingsBuilder, embedding::EmbeddingModel}, + embeddings::{builder::EmbeddingsBuilder, embedding::EmbeddingModel}, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, vector_store::VectorStoreIndexDyn, }; diff --git a/rig-lancedb/examples/vector_search_s3_ann.rs b/rig-lancedb/examples/vector_search_s3_ann.rs index 0381429a..8d65e37a 100644 --- a/rig-lancedb/examples/vector_search_s3_ann.rs +++ b/rig-lancedb/examples/vector_search_s3_ann.rs @@ -4,7 +4,7 @@ use arrow_array::RecordBatchIterator; use fixture::{as_record_batch, fake_definition, fake_definitions, schema, FakeDefinition}; use lancedb::{index::vector::IvfPqIndexBuilder, DistanceType}; use rig::{ - embeddings::{embeddable::EmbeddingsBuilder, embedding::EmbeddingModel}, + 
embeddings::{builder::EmbeddingsBuilder, embedding::EmbeddingModel}, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, vector_store::VectorStoreIndex, }; diff --git a/rig-mongodb/examples/vector_search_mongodb.rs b/rig-mongodb/examples/vector_search_mongodb.rs index dab94eb9..e087c3d6 100644 --- a/rig-mongodb/examples/vector_search_mongodb.rs +++ b/rig-mongodb/examples/vector_search_mongodb.rs @@ -1,12 +1,12 @@ use mongodb::{bson::doc, options::ClientOptions, Client as MongoClient, Collection}; +use rig::embeddings::embeddable::EmbeddableError; use rig::providers::openai::TEXT_EMBEDDING_ADA_002; use serde::{Deserialize, Serialize}; use std::env; use rig::Embeddable; use rig::{ - embeddings::embeddable::{EmbeddingGenerationError, EmbeddingsBuilder, SingleEmbedding}, - providers::openai::Client, + embeddings::builder::EmbeddingsBuilder, providers::openai::Client, vector_store::VectorStoreIndex, }; use rig_mongodb::{MongoDbVectorStore, SearchParams}; @@ -20,6 +20,12 @@ struct FakeDefinition { definition: String, } +#[derive(Clone, Deserialize, Debug, Serialize)] +struct Link { + word: String, + link: String, +} + // Shape of the document to be stored in MongoDB, with embeddings. #[derive(Serialize, Debug)] struct Document { @@ -56,15 +62,15 @@ async fn main() -> Result<(), anyhow::Error> { let fake_definitions = vec![ FakeDefinition { id: "doc0".to_string(), - definition: "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets".to_string() + definition: "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets".to_string(), }, FakeDefinition { id: "doc1".to_string(), - definition: "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string() + definition: "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string(), }, FakeDefinition { id: "doc2".to_string(), - definition: "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.".to_string() + definition: "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.".to_string(), } ]; @@ -75,11 +81,13 @@ async fn main() -> Result<(), anyhow::Error> { let mongo_documents = embeddings .iter() - .map(|(FakeDefinition { id, definition }, embedding)| Document { - id: id.clone(), - definition: definition.clone(), - embedding: embedding.vec.clone(), - }) + .map( + |(FakeDefinition { id, definition, .. 
}, embedding)| Document { + id: id.clone(), + definition: definition.clone(), + embedding: embedding.vec.clone(), + }, + ) .collect::<Vec<_>>(); match collection.insert_many(mongo_documents, None).await { From 8c993dd943ecd583224664a87cf706f55c05d248 Mon Sep 17 00:00:00 2001 From: Garance <garance.mary@gmail.com> Date: Tue, 15 Oct 2024 09:59:04 -0400 Subject: [PATCH 24/91] fix: revert logging change --- rig-core/src/vector_store/in_memory_store.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rig-core/src/vector_store/in_memory_store.rs b/rig-core/src/vector_store/in_memory_store.rs index 40ad018e..ec497ac4 100644 --- a/rig-core/src/vector_store/in_memory_store.rs +++ b/rig-core/src/vector_store/in_memory_store.rs @@ -52,7 +52,7 @@ impl<D: Serialize + Eq> InMemoryVectorStore<D> { tracing::info!(target: "rig", "Selected documents: {}", docs.iter() - .map(|Reverse(RankingItem(distance, id, _, embed_doc))| format!("{} ({}). Specific match: {}", id, distance, embed_doc)) + .map(|Reverse(RankingItem(distance, id, _, _))| format!("{} ({})", id, distance)) .collect::<Vec<String>>() .join(", ") ); From 5a8c3612fbda2af47ffce55a2186ca46d3f396c9 Mon Sep 17 00:00:00 2001 From: Garance <garance.mary@gmail.com> Date: Tue, 15 Oct 2024 14:28:03 -0400 Subject: [PATCH 25/91] feat: handle tools with embeddingsbuilder --- rig-core/examples/calculator_chatbot.rs | 17 +- rig-core/examples/rag_dynamic_tools.rs | 13 +- rig-core/rig-core-derive/src/custom.rs | 2 +- rig-core/rig-core-derive/src/embeddable.rs | 175 ++++++++++-------- rig-core/src/embeddings/builder.rs | 6 +- rig-core/src/embeddings/embeddable.rs | 46 +++-- rig-core/src/embeddings/mod.rs | 1 + rig-core/src/embeddings/tool.rs | 25 +++ rig-core/src/tool.rs | 18 +- rig-core/src/vector_store/in_memory_store.rs | 2 +- rig-mongodb/examples/vector_search_mongodb.rs | 1 - 11 files changed, 174 insertions(+), 132 deletions(-) create mode 100644 rig-core/src/embeddings/tool.rs diff --git a/rig-core/examples/calculator_chatbot.rs b/rig-core/examples/calculator_chatbot.rs index fb168a08..0b994265 100644 --- a/rig-core/examples/calculator_chatbot.rs +++ b/rig-core/examples/calculator_chatbot.rs @@ -2,7 +2,7 @@ use anyhow::Result; use rig::{ cli_chatbot::cli_chatbot, completion::ToolDefinition, - embeddings::{DocumentEmbeddings, EmbeddingsBuilder}, + embeddings::builder::EmbeddingsBuilder, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, tool::{Tool, ToolEmbedding, ToolSet}, vector_store::in_memory_store::InMemoryVectorStore, @@ -25,7 +25,7 @@ struct MathError; #[error("Init error")] struct InitError; -#[derive(Deserialize, Serialize)] +#[derive(Deserialize, Serialize, Clone)] struct Add; impl Tool for Add { const NAME: &'static str = "add"; @@ -77,7 +77,7 @@ impl ToolEmbedding for Add { fn context(&self) -> Self::Context {} } -#[derive(Deserialize, Serialize)] +#[derive(Deserialize, Serialize, Clone)] struct Subtract; impl Tool for Subtract { const NAME: &'static str = "subtract"; @@ -247,7 +247,7 @@ async fn main() -> Result<(), anyhow::Error> { let embedding_model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); let embeddings = EmbeddingsBuilder::new(embedding_model.clone()) - .tools(&toolset)? + .documents(toolset.embedabble_tools()?)? 
.build() .await?; @@ -255,13 +255,8 @@ async fn main() -> Result<(), anyhow::Error> { .add_documents( embeddings .into_iter() - .map( - |DocumentEmbeddings { - id, - document, - embeddings, - }| { (id, document, embeddings) }, - ) + .enumerate() + .map(|(i, (tool, embedding))| (i.to_string(), tool, vec![embedding])) .collect(), )? .index(embedding_model); diff --git a/rig-core/examples/rag_dynamic_tools.rs b/rig-core/examples/rag_dynamic_tools.rs index cdf6b65e..51c56ca8 100644 --- a/rig-core/examples/rag_dynamic_tools.rs +++ b/rig-core/examples/rag_dynamic_tools.rs @@ -1,7 +1,7 @@ use anyhow::Result; use rig::{ completion::{Prompt, ToolDefinition}, - embeddings::{DocumentEmbeddings, EmbeddingsBuilder}, + embeddings::builder::EmbeddingsBuilder, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, tool::{Tool, ToolEmbedding, ToolSet}, vector_store::in_memory_store::InMemoryVectorStore, @@ -156,7 +156,7 @@ async fn main() -> Result<(), anyhow::Error> { .build(); let embeddings = EmbeddingsBuilder::new(embedding_model.clone()) - .tools(&toolset)? + .documents(toolset.embedabble_tools()?)? .build() .await?; @@ -164,13 +164,8 @@ async fn main() -> Result<(), anyhow::Error> { .add_documents( embeddings .into_iter() - .map( - |DocumentEmbeddings { - id, - document, - embeddings, - }| { (id, document, embeddings) }, - ) + .enumerate() + .map(|(i, (tool, embedding))| (i.to_string(), tool, vec![embedding])) .collect(), )? .index(embedding_model); diff --git a/rig-core/rig-core-derive/src/custom.rs b/rig-core/rig-core-derive/src/custom.rs index 8926aebf..194be085 100644 --- a/rig-core/rig-core-derive/src/custom.rs +++ b/rig-core/rig-core-derive/src/custom.rs @@ -6,7 +6,7 @@ use crate::EMBED; const EMBED_WITH: &str = "embed_with"; /// Finds and returns fields with #[embed(embed_with = "...")] attribute tags only. -/// Also returns the attribute in question. +/// Also returns the "..." part of the tag (ie. the custom function). pub(crate) fn custom_embed_fields( data_struct: &syn::DataStruct, ) -> syn::Result<impl Iterator<Item = (syn::Field, syn::ExprPath)>> { diff --git a/rig-core/rig-core-derive/src/embeddable.rs b/rig-core/rig-core-derive/src/embeddable.rs index 914945ec..cd5a201f 100644 --- a/rig-core/rig-core-derive/src/embeddable.rs +++ b/rig-core/rig-core-derive/src/embeddable.rs @@ -10,113 +10,58 @@ const VEC_TYPE: &str = "Vec"; const MANY_EMBEDDING: &str = "ManyEmbedding"; const SINGLE_EMBEDDING: &str = "SingleEmbedding"; -pub fn expand_derive_embedding(input: &mut syn::DeriveInput) -> syn::Result<TokenStream> { - let name = &input.ident; +pub(crate) fn expand_derive_embedding(input: &mut syn::DeriveInput) -> syn::Result<TokenStream> { + let data = &input.data; + let generics = &mut input.generics; - // Handles fields tagged with #[embed] - let embed_targets = match &input.data { - syn::Data::Struct(data_struct) => basic_embed_fields(data_struct) - .map(|field| { - add_struct_bounds(&mut input.generics, &field.ty); - - let field_name = field.ident; + let (target_stream, embed_kind) = match data { + syn::Data::Struct(data_struct) => { + let basic_targets = data_struct.basic(generics); + let custom_targets = data_struct.custom()?; + // Determine whether the Embeddable::Kind should be SinleEmbedding or ManyEmbedding + ( quote! 
{ - self.#field_name - } - }) - .collect::<Vec<_>>(), - _ => { - return Err(syn::Error::new_spanned( - name, - "Embeddable derive macro should only be used on structs", - )) + let mut embed_targets = #basic_targets; + embed_targets.extend(#custom_targets) + }, + embed_kind(data_struct)?, + ) } - }; - - let embed_targets_quote = if !embed_targets.is_empty() { - quote! { - vec![#(#embed_targets.embeddable()),*] - .into_iter() - .collect::<Result<Vec<_>, _>>()? - .into_iter() - .flatten() - .collect::<Vec<_>>() - } - } else { - quote! { - vec![] - } - }; - - // Handles fields tagged with #[embed(embed_with = "...")] - let custom_embed_targets = match &input.data { - syn::Data::Struct(data_struct) => custom_embed_fields(data_struct)? - .map(|(field, custom_func_path)| { - let field_name = field.ident; - - quote! { - #custom_func_path(self.#field_name.clone()) - } - }) - .collect::<Vec<_>>(), _ => { return Err(syn::Error::new_spanned( - name, + input, "Embeddable derive macro should only be used on structs", )) } }; - let custom_embed_targets_quote = if !custom_embed_targets.is_empty() { - quote! { - vec![#(#custom_embed_targets),*] - .into_iter() - .collect::<Result<Vec<_>, _>>()? - .into_iter() - .flatten() - .collect::<Vec<_>>() - } - } else { - quote! { - vec![] - } - }; - // If there are no fields tagged with #[embed] or #[embed(embed_with = "...")], return an empty TokenStream. // ie. do not implement Embeddable trait for the struct. - if embed_targets.is_empty() && custom_embed_targets.is_empty() { + if target_stream.is_empty() { return Ok(TokenStream::new()); } - // Determine whether the Embeddable::Kind should be SinleEmbedding or ManyEmbedding. - let embed_kind = match &input.data { - syn::Data::Struct(data_struct) => embed_kind(data_struct)?, - _ => { - return Err(syn::Error::new_spanned( - name, - "Embeddable derive macro should only be used on structs", - )) - } - }; - let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl(); + let name = &input.ident; + let gen = quote! { // Note: Embeddable trait is imported with the macro. impl #impl_generics Embeddable for #name #ty_generics #where_clause { type Kind = rig::embeddings::embeddable::#embed_kind; - type Error = rig::embeddings::embeddable::EmbeddableError; - fn embeddable(&self) -> Result<Vec<String>, Self::Error> { - let mut embed_targets = #embed_targets_quote; + fn embeddable(&self) -> Result<Vec<String>, rig::embeddings::embeddable::EmbeddableError> { + #target_stream; - let custom_embed_targets = #custom_embed_targets_quote; + let targets = embed_targets.into_iter() + .collect::<Result<Vec<_>, _>>()? + .into_iter() + .flatten() + .collect::<Vec<_>>(); - embed_targets.extend(custom_embed_targets); - - Ok(embed_targets) + Ok(targets) } } }; @@ -152,3 +97,71 @@ fn embed_kind(data_struct: &DataStruct) -> syn::Result<syn::Expr> { parse_str(MANY_EMBEDDING) } } + +trait StructParser { + // Handles fields tagged with #[embed] + fn basic(&self, generics: &mut syn::Generics) -> TokenStream; + + // Handles fields tagged with #[embed(embed_with = "...")] + fn custom(&self) -> syn::Result<TokenStream>; +} + +impl StructParser for DataStruct { + fn basic(&self, generics: &mut syn::Generics) -> TokenStream { + let embed_targets = basic_embed_fields(self) + // Iterate over every field tagged with #[embed] + .map(|field| { + add_struct_bounds(generics, &field.ty); + + let field_name = field.ident; + + quote! { + self.#field_name + } + }) + .collect::<Vec<_>>(); + + if !embed_targets.is_empty() { + quote! 
{ + vec![#(#embed_targets.embeddable()),*] + // .into_iter() + // .collect::<Result<Vec<_>, _>>()? + // .into_iter() + // .flatten() + // .collect::<Vec<_>>() + } + } else { + quote! { + vec![] + } + } + } + + fn custom(&self) -> syn::Result<TokenStream> { + let embed_targets = custom_embed_fields(self)? + // Iterate over every field tagged with #[embed(embed_with = "...")] + .map(|(field, custom_func_path)| { + let field_name = field.ident; + + quote! { + #custom_func_path(self.#field_name.clone()) + } + }) + .collect::<Vec<_>>(); + + Ok(if !embed_targets.is_empty() { + quote! { + vec![#(#embed_targets),*] + // .into_iter() + // .collect::<Result<Vec<_>, _>>()? + // .into_iter() + // .flatten() + // .collect::<Vec<_>>() + } + } else { + quote! { + vec![] + } + }) + } +} diff --git a/rig-core/src/embeddings/builder.rs b/rig-core/src/embeddings/builder.rs index 01a99fd4..1a699833 100644 --- a/rig-core/src/embeddings/builder.rs +++ b/rig-core/src/embeddings/builder.rs @@ -65,7 +65,7 @@ use futures::{stream, StreamExt, TryStreamExt}; use crate::Embeddable; use super::{ - embeddable::{EmbeddingKind, ManyEmbedding, SingleEmbedding}, + embeddable::{EmbeddableError, EmbeddingKind, ManyEmbedding, SingleEmbedding}, embedding::{Embedding, EmbeddingError, EmbeddingModel}, }; @@ -87,7 +87,7 @@ impl<M: EmbeddingModel, D: Embeddable<Kind = K>, K: EmbeddingKind> EmbeddingsBui } /// Add a document that implements `Embeddable` to the builder. - pub fn document(mut self, document: D) -> Result<Self, D::Error> { + pub fn document(mut self, document: D) -> Result<Self, EmbeddableError> { let embed_targets = document.embeddable()?; self.documents.push((document, embed_targets)); @@ -95,7 +95,7 @@ impl<M: EmbeddingModel, D: Embeddable<Kind = K>, K: EmbeddingKind> EmbeddingsBui } /// Add many documents that implement `Embeddable` to the builder. - pub fn documents(mut self, documents: Vec<D>) -> Result<Self, D::Error> { + pub fn documents(mut self, documents: Vec<D>) -> Result<Self, EmbeddableError> { for doc in documents.into_iter() { let embed_targets = doc.embeddable()?; diff --git a/rig-core/src/embeddings/embeddable.rs b/rig-core/src/embeddings/embeddable.rs index 80e5b689..6b1996e3 100644 --- a/rig-core/src/embeddings/embeddable.rs +++ b/rig-core/src/embeddings/embeddable.rs @@ -22,9 +22,8 @@ pub enum EmbeddableError { /// If the type `Kind` is `SingleEmbedding`, the list of strings contains a single item, otherwise, the list can contain many items. 
pub trait Embeddable { type Kind: EmbeddingKind; - type Error: std::error::Error; - fn embeddable(&self) -> Result<Vec<String>, Self::Error>; + fn embeddable(&self) -> Result<Vec<String>, EmbeddableError>; } ////////////////////////////////////////////////////// @@ -32,99 +31,98 @@ pub trait Embeddable { ////////////////////////////////////////////////////// impl Embeddable for String { type Kind = SingleEmbedding; - type Error = EmbeddableError; - fn embeddable(&self) -> Result<Vec<String>, Self::Error> { + fn embeddable(&self) -> Result<Vec<String>, EmbeddableError> { Ok(vec![self.clone()]) } } impl Embeddable for i8 { type Kind = SingleEmbedding; - type Error = EmbeddableError; - fn embeddable(&self) -> Result<Vec<String>, Self::Error> { + fn embeddable(&self) -> Result<Vec<String>, EmbeddableError> { Ok(vec![self.to_string()]) } } impl Embeddable for i16 { type Kind = SingleEmbedding; - type Error = EmbeddableError; - fn embeddable(&self) -> Result<Vec<String>, Self::Error> { + fn embeddable(&self) -> Result<Vec<String>, EmbeddableError> { Ok(vec![self.to_string()]) } } impl Embeddable for i32 { type Kind = SingleEmbedding; - type Error = EmbeddableError; - fn embeddable(&self) -> Result<Vec<String>, Self::Error> { + fn embeddable(&self) -> Result<Vec<String>, EmbeddableError> { Ok(vec![self.to_string()]) } } impl Embeddable for i64 { type Kind = SingleEmbedding; - type Error = EmbeddableError; - fn embeddable(&self) -> Result<Vec<String>, Self::Error> { + fn embeddable(&self) -> Result<Vec<String>, EmbeddableError> { Ok(vec![self.to_string()]) } } impl Embeddable for i128 { type Kind = SingleEmbedding; - type Error = EmbeddableError; - fn embeddable(&self) -> Result<Vec<String>, Self::Error> { + fn embeddable(&self) -> Result<Vec<String>, EmbeddableError> { Ok(vec![self.to_string()]) } } impl Embeddable for f32 { type Kind = SingleEmbedding; - type Error = EmbeddableError; - fn embeddable(&self) -> Result<Vec<String>, Self::Error> { + fn embeddable(&self) -> Result<Vec<String>, EmbeddableError> { Ok(vec![self.to_string()]) } } impl Embeddable for f64 { type Kind = SingleEmbedding; - type Error = EmbeddableError; - fn embeddable(&self) -> Result<Vec<String>, Self::Error> { + fn embeddable(&self) -> Result<Vec<String>, EmbeddableError> { Ok(vec![self.to_string()]) } } impl Embeddable for bool { type Kind = SingleEmbedding; - type Error = EmbeddableError; - fn embeddable(&self) -> Result<Vec<String>, Self::Error> { + fn embeddable(&self) -> Result<Vec<String>, EmbeddableError> { Ok(vec![self.to_string()]) } } impl Embeddable for char { type Kind = SingleEmbedding; - type Error = EmbeddableError; - fn embeddable(&self) -> Result<Vec<String>, Self::Error> { + fn embeddable(&self) -> Result<Vec<String>, EmbeddableError> { Ok(vec![self.to_string()]) } } +impl Embeddable for serde_json::Value { + type Kind = SingleEmbedding; + + fn embeddable(&self) -> Result<Vec<String>, EmbeddableError> { + Ok(vec![ + serde_json::to_string(self).map_err(EmbeddableError::SerdeError)? 
+ ]) + } +} + impl<T: Embeddable> Embeddable for Vec<T> { type Kind = ManyEmbedding; - type Error = T::Error; - fn embeddable(&self) -> Result<Vec<String>, Self::Error> { + fn embeddable(&self) -> Result<Vec<String>, EmbeddableError> { Ok(self .iter() .map(|i| i.embeddable()) diff --git a/rig-core/src/embeddings/mod.rs b/rig-core/src/embeddings/mod.rs index 37e720cb..d590b7d0 100644 --- a/rig-core/src/embeddings/mod.rs +++ b/rig-core/src/embeddings/mod.rs @@ -6,3 +6,4 @@ pub mod builder; pub mod embeddable; pub mod embedding; +pub mod tool; diff --git a/rig-core/src/embeddings/tool.rs b/rig-core/src/embeddings/tool.rs new file mode 100644 index 00000000..c369fd94 --- /dev/null +++ b/rig-core/src/embeddings/tool.rs @@ -0,0 +1,25 @@ +use crate::{self as rig, tool::ToolEmbeddingDyn}; +use rig::embeddings::embeddable::Embeddable; +use rig_derive::Embeddable; +use serde::Serialize; + +use super::embeddable::EmbeddableError; + +/// Used by EmbeddingsBuilder to embed anything that implements ToolEmbedding. +#[derive(Embeddable, Clone, Serialize, Default, Eq, PartialEq)] +pub struct EmbeddableTool { + name: String, + #[embed] + context: serde_json::Value, +} + +impl EmbeddableTool { + /// Convert item that implements ToolEmbedding to an EmbeddableTool. + pub fn try_from(tool: &Box<dyn ToolEmbeddingDyn>) -> Result<Self, EmbeddableError> { + Ok(EmbeddableTool { + name: tool.name(), + context: serde_json::to_value(tool.context().map_err(EmbeddableError::SerdeError)?) + .map_err(EmbeddableError::SerdeError)?, + }) + } +} diff --git a/rig-core/src/tool.rs b/rig-core/src/tool.rs index 98394ecf..181751f5 100644 --- a/rig-core/src/tool.rs +++ b/rig-core/src/tool.rs @@ -3,7 +3,10 @@ use std::{collections::HashMap, pin::Pin}; use futures::Future; use serde::{Deserialize, Serialize}; -use crate::completion::{self, ToolDefinition}; +use crate::{ + completion::{self, ToolDefinition}, + embeddings::{embeddable::EmbeddableError, tool::EmbeddableTool}, +}; #[derive(Debug, thiserror::Error)] pub enum ToolError { @@ -323,6 +326,19 @@ impl ToolSet { } Ok(docs) } + + pub fn embedabble_tools(&self) -> Result<Vec<EmbeddableTool>, EmbeddableError> { + self.tools + .values() + .filter_map(|tool_type| { + if let ToolType::Embedding(tool) = tool_type { + Some(EmbeddableTool::try_from(tool)) + } else { + None + } + }) + .collect::<Result<Vec<_>, _>>() + } } #[derive(Default)] diff --git a/rig-core/src/vector_store/in_memory_store.rs b/rig-core/src/vector_store/in_memory_store.rs index f1e6ad0a..9bbc85f1 100644 --- a/rig-core/src/vector_store/in_memory_store.rs +++ b/rig-core/src/vector_store/in_memory_store.rs @@ -189,7 +189,7 @@ impl<M: EmbeddingModel + std::marker::Sync, D: Serialize + Sync + Send + Eq> Vec mod tests { use std::cmp::Reverse; - use crate::embeddings::Embedding; + use crate::embeddings::embedding::Embedding; use super::{InMemoryVectorStore, RankingItem}; diff --git a/rig-mongodb/examples/vector_search_mongodb.rs b/rig-mongodb/examples/vector_search_mongodb.rs index e087c3d6..a816c9c0 100644 --- a/rig-mongodb/examples/vector_search_mongodb.rs +++ b/rig-mongodb/examples/vector_search_mongodb.rs @@ -1,5 +1,4 @@ use mongodb::{bson::doc, options::ClientOptions, Client as MongoClient, Collection}; -use rig::embeddings::embeddable::EmbeddableError; use rig::providers::openai::TEXT_EMBEDDING_ADA_002; use serde::{Deserialize, Serialize}; use std::env; From dc89e54f9e0bc7c98da3dbdca014691f42b764ca Mon Sep 17 00:00:00 2001 From: Garance <garance.mary@gmail.com> Date: Tue, 15 Oct 2024 15:27:37 -0400 Subject: [PATCH 
26/91] bug(macro): fix error when embed tags missing --- rig-core/rig-core-derive/src/custom.rs | 29 ++++---- rig-core/rig-core-derive/src/embeddable.rs | 82 ++++++++++++---------- rig-core/src/embeddings/tool.rs | 2 +- rig-core/src/tool.rs | 2 +- 4 files changed, 61 insertions(+), 54 deletions(-) diff --git a/rig-core/rig-core-derive/src/custom.rs b/rig-core/rig-core-derive/src/custom.rs index 194be085..a795026c 100644 --- a/rig-core/rig-core-derive/src/custom.rs +++ b/rig-core/rig-core-derive/src/custom.rs @@ -9,29 +9,32 @@ const EMBED_WITH: &str = "embed_with"; /// Also returns the "..." part of the tag (ie. the custom function). pub(crate) fn custom_embed_fields( data_struct: &syn::DataStruct, -) -> syn::Result<impl Iterator<Item = (syn::Field, syn::ExprPath)>> { - Ok(data_struct +) -> syn::Result<Vec<(syn::Field, syn::ExprPath)>> { + data_struct .fields .clone() .into_iter() - .map(|field| { + .filter_map(|field| { field .attrs .clone() .into_iter() - .map(|attribute| { - if attribute.is_custom()? { - Ok::<_, syn::Error>(Some((field.clone(), attribute.expand_tag()?))) - } else { - Ok(None) + .filter_map(|attribute| { + match attribute.is_custom() { + Ok(true) => { + match attribute.expand_tag() { + Ok(path) => Some(Ok((field.clone(), path))), + Err(e) => Some(Err(e)), + } + }, + Ok(false) => None, + Err(e) => Some(Err(e)) } }) - .collect::<Result<Vec<_>, _>>() + .next() }) - .collect::<Result<Vec<_>, _>>()? - .into_iter() - .flatten() - .flatten()) + .collect::<Result<Vec<_>, _>>() + } trait CustomAttributeParser { diff --git a/rig-core/rig-core-derive/src/embeddable.rs b/rig-core/rig-core-derive/src/embeddable.rs index cd5a201f..d44bb7bf 100644 --- a/rig-core/rig-core-derive/src/embeddable.rs +++ b/rig-core/rig-core-derive/src/embeddable.rs @@ -11,15 +11,25 @@ const MANY_EMBEDDING: &str = "ManyEmbedding"; const SINGLE_EMBEDDING: &str = "SingleEmbedding"; pub(crate) fn expand_derive_embedding(input: &mut syn::DeriveInput) -> syn::Result<TokenStream> { + let name = &input.ident; let data = &input.data; let generics = &mut input.generics; let (target_stream, embed_kind) = match data { syn::Data::Struct(data_struct) => { - let basic_targets = data_struct.basic(generics); - let custom_targets = data_struct.custom()?; + let (basic_targets, basic_target_size) = data_struct.basic(generics); + let (custom_targets, custom_target_size) = data_struct.custom()?; + + // If there are no fields tagged with #[embed] or #[embed(embed_with = "...")], return an empty TokenStream. + // ie. do not implement Embeddable trait for the struct. + if basic_target_size + custom_target_size == 0 { + return Err(syn::Error::new_spanned( + name, + "Add at least one field tagged with #[embed] or #[embed(embed_with = \"...\")].", + )) + } - // Determine whether the Embeddable::Kind should be SinleEmbedding or ManyEmbedding + // Determine whether the Embeddable::Kind should be SingleEmbedding or ManyEmbedding ( quote! { let mut embed_targets = #basic_targets; @@ -36,16 +46,8 @@ pub(crate) fn expand_derive_embedding(input: &mut syn::DeriveInput) -> syn::Resu } }; - // If there are no fields tagged with #[embed] or #[embed(embed_with = "...")], return an empty TokenStream. - // ie. do not implement Embeddable trait for the struct. - if target_stream.is_empty() { - return Ok(TokenStream::new()); - } - let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl(); - let name = &input.ident; - let gen = quote! { // Note: Embeddable trait is imported with the macro. 
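Roughly, for a struct with a single #[embed] field, the token streams assembled here expand along the following lines. This is an illustrative sketch only, not part of the patch: the `Note` struct and its `body` field are hypothetical, and the exact output is whatever `quote!` emits from the `basic`/`custom` streams at this point in the series.

#[derive(Embeddable)]
struct Note {
    id: String,
    #[embed]
    body: String,
}

// Approximate expansion for `Note` (assuming one #[embed] field, no
// #[embed(embed_with = "...")] fields):
impl Embeddable for Note {
    type Kind = rig::embeddings::embeddable::SingleEmbedding;

    fn embeddable(&self) -> Result<Vec<String>, rig::embeddings::embeddable::EmbeddableError> {
        // `basic` targets: one entry per #[embed] field; `custom` targets: empty here.
        let mut embed_targets = vec![self.body.embeddable()];
        embed_targets.extend(vec![]);

        let targets = embed_targets
            .into_iter()
            .collect::<Result<Vec<_>, _>>()?
            .into_iter()
            .flatten()
            .collect::<Vec<_>>();

        Ok(targets)
    }
}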
@@ -88,7 +90,7 @@ fn embed_kind(data_struct: &DataStruct) -> syn::Result<syn::Expr> { } } let fields = basic_embed_fields(data_struct) - .chain(custom_embed_fields(data_struct)?.map(|(f, _)| f)) + .chain(custom_embed_fields(data_struct)?.into_iter().map(|(f, _)| f)) .collect::<Vec<_>>(); if fields.len() == 1 { @@ -100,14 +102,14 @@ fn embed_kind(data_struct: &DataStruct) -> syn::Result<syn::Expr> { trait StructParser { // Handles fields tagged with #[embed] - fn basic(&self, generics: &mut syn::Generics) -> TokenStream; + fn basic(&self, generics: &mut syn::Generics) -> (TokenStream, usize); // Handles fields tagged with #[embed(embed_with = "...")] - fn custom(&self) -> syn::Result<TokenStream>; + fn custom(&self) -> syn::Result<(TokenStream, usize)>; } impl StructParser for DataStruct { - fn basic(&self, generics: &mut syn::Generics) -> TokenStream { + fn basic(&self, generics: &mut syn::Generics) -> (TokenStream, usize) { let embed_targets = basic_embed_fields(self) // Iterate over every field tagged with #[embed] .map(|field| { @@ -122,25 +124,26 @@ impl StructParser for DataStruct { .collect::<Vec<_>>(); if !embed_targets.is_empty() { - quote! { - vec![#(#embed_targets.embeddable()),*] - // .into_iter() - // .collect::<Result<Vec<_>, _>>()? - // .into_iter() - // .flatten() - // .collect::<Vec<_>>() - } + ( + quote! { + vec![#(#embed_targets.embeddable()),*] + }, + embed_targets.len() + ) } else { - quote! { - vec![] - } + ( + quote! { + vec![] + }, + 0 + ) } } - fn custom(&self) -> syn::Result<TokenStream> { + fn custom(&self) -> syn::Result<(TokenStream, usize)> { let embed_targets = custom_embed_fields(self)? // Iterate over every field tagged with #[embed(embed_with = "...")] - .map(|(field, custom_func_path)| { + .into_iter().map(|(field, custom_func_path)| { let field_name = field.ident; quote! { @@ -150,18 +153,19 @@ impl StructParser for DataStruct { .collect::<Vec<_>>(); Ok(if !embed_targets.is_empty() { - quote! { - vec![#(#embed_targets),*] - // .into_iter() - // .collect::<Result<Vec<_>, _>>()? - // .into_iter() - // .flatten() - // .collect::<Vec<_>>() - } + ( + quote! { + vec![#(#embed_targets),*] + }, + embed_targets.len() + ) } else { - quote! { - vec![] - } + ( + quote! { + vec![] + }, + 0 + ) }) } } diff --git a/rig-core/src/embeddings/tool.rs b/rig-core/src/embeddings/tool.rs index c369fd94..5e8ecd35 100644 --- a/rig-core/src/embeddings/tool.rs +++ b/rig-core/src/embeddings/tool.rs @@ -15,7 +15,7 @@ pub struct EmbeddableTool { impl EmbeddableTool { /// Convert item that implements ToolEmbedding to an EmbeddableTool. - pub fn try_from(tool: &Box<dyn ToolEmbeddingDyn>) -> Result<Self, EmbeddableError> { + pub fn try_from(tool: &dyn ToolEmbeddingDyn) -> Result<Self, EmbeddableError> { Ok(EmbeddableTool { name: tool.name(), context: serde_json::to_value(tool.context().map_err(EmbeddableError::SerdeError)?) 
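For context on the `&**tool` change in the next hunk: `ToolSet::embedabble_tools` iterates `self.tools.values()`, so inside the `if let ToolType::Embedding(tool)` arm, `tool` is a `&Box<dyn ToolEmbeddingDyn>`. A minimal sketch of the coercion follows; the `demo` function and its binding name are illustrative only, not part of the patch.

fn demo(tool: &Box<dyn ToolEmbeddingDyn>) {
    // `*tool` is the Box, `**tool` is the unsized `dyn ToolEmbeddingDyn`, and
    // `&**tool` re-borrows it as the `&dyn ToolEmbeddingDyn` that the new
    // `EmbeddableTool::try_from` signature expects.
    let as_trait_object: &dyn ToolEmbeddingDyn = &**tool;
    let _ = EmbeddableTool::try_from(as_trait_object);
}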
diff --git a/rig-core/src/tool.rs b/rig-core/src/tool.rs index 181751f5..d6f05ffc 100644 --- a/rig-core/src/tool.rs +++ b/rig-core/src/tool.rs @@ -332,7 +332,7 @@ impl ToolSet { .values() .filter_map(|tool_type| { if let ToolType::Embedding(tool) = tool_type { - Some(EmbeddableTool::try_from(tool)) + Some(EmbeddableTool::try_from(&**tool)) } else { None } From ae66d082487b5b7b2a0bb7a456ae86f3433666e9 Mon Sep 17 00:00:00 2001 From: Garance <garance.mary@gmail.com> Date: Tue, 15 Oct 2024 15:28:35 -0400 Subject: [PATCH 27/91] style: cargo fmt --- rig-core/rig-core-derive/src/custom.rs | 19 +++++++------------ rig-core/rig-core-derive/src/embeddable.rs | 19 ++++++++++++------- 2 files changed, 19 insertions(+), 19 deletions(-) diff --git a/rig-core/rig-core-derive/src/custom.rs b/rig-core/rig-core-derive/src/custom.rs index a795026c..de754372 100644 --- a/rig-core/rig-core-derive/src/custom.rs +++ b/rig-core/rig-core-derive/src/custom.rs @@ -19,22 +19,17 @@ pub(crate) fn custom_embed_fields( .attrs .clone() .into_iter() - .filter_map(|attribute| { - match attribute.is_custom() { - Ok(true) => { - match attribute.expand_tag() { - Ok(path) => Some(Ok((field.clone(), path))), - Err(e) => Some(Err(e)), - } - }, - Ok(false) => None, - Err(e) => Some(Err(e)) - } + .filter_map(|attribute| match attribute.is_custom() { + Ok(true) => match attribute.expand_tag() { + Ok(path) => Some(Ok((field.clone(), path))), + Err(e) => Some(Err(e)), + }, + Ok(false) => None, + Err(e) => Some(Err(e)), }) .next() }) .collect::<Result<Vec<_>, _>>() - } trait CustomAttributeParser { diff --git a/rig-core/rig-core-derive/src/embeddable.rs b/rig-core/rig-core-derive/src/embeddable.rs index d44bb7bf..88503360 100644 --- a/rig-core/rig-core-derive/src/embeddable.rs +++ b/rig-core/rig-core-derive/src/embeddable.rs @@ -26,7 +26,7 @@ pub(crate) fn expand_derive_embedding(input: &mut syn::DeriveInput) -> syn::Resu return Err(syn::Error::new_spanned( name, "Add at least one field tagged with #[embed] or #[embed(embed_with = \"...\")].", - )) + )); } // Determine whether the Embeddable::Kind should be SingleEmbedding or ManyEmbedding @@ -90,7 +90,11 @@ fn embed_kind(data_struct: &DataStruct) -> syn::Result<syn::Expr> { } } let fields = basic_embed_fields(data_struct) - .chain(custom_embed_fields(data_struct)?.into_iter().map(|(f, _)| f)) + .chain( + custom_embed_fields(data_struct)? + .into_iter() + .map(|(f, _)| f), + ) .collect::<Vec<_>>(); if fields.len() == 1 { @@ -128,14 +132,14 @@ impl StructParser for DataStruct { quote! { vec![#(#embed_targets.embeddable()),*] }, - embed_targets.len() + embed_targets.len(), ) } else { ( quote! { vec![] }, - 0 + 0, ) } } @@ -143,7 +147,8 @@ impl StructParser for DataStruct { fn custom(&self) -> syn::Result<(TokenStream, usize)> { let embed_targets = custom_embed_fields(self)? // Iterate over every field tagged with #[embed(embed_with = "...")] - .into_iter().map(|(field, custom_func_path)| { + .into_iter() + .map(|(field, custom_func_path)| { let field_name = field.ident; quote! { @@ -157,14 +162,14 @@ impl StructParser for DataStruct { quote! { vec![#(#embed_targets),*] }, - embed_targets.len() + embed_targets.len(), ) } else { ( quote! 
{ vec![] }, - 0 + 0, ) }) } From 4305952beeff3b62039b8e9374ff2855d9142c7e Mon Sep 17 00:00:00 2001 From: Garance <garance.mary@gmail.com> Date: Tue, 15 Oct 2024 15:34:11 -0400 Subject: [PATCH 28/91] fix(tests): clippy --- rig-core/src/embeddings/embeddable.rs | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/rig-core/src/embeddings/embeddable.rs b/rig-core/src/embeddings/embeddable.rs index 6b1996e3..5faa8b2d 100644 --- a/rig-core/src/embeddings/embeddable.rs +++ b/rig-core/src/embeddings/embeddable.rs @@ -173,6 +173,11 @@ mod tests { }, }; + println!( + "FakeDefinition: {}, {}", + fake_definition.id, fake_definition.word + ); + assert_eq!( fake_definition.embeddable().unwrap(), vec!["{\"word\":\"a building in which people live; residence for human beings.\",\"link\":\"https://www.dictionary.com/browse/house\",\"speech\":\"noun\"}".to_string()] @@ -188,13 +193,18 @@ mod tests { } #[test] - fn test_simple_embed() { + fn test_single_embed() { let fake_definition = FakeDefinition2 { id: "doc1".to_string(), word: "house".to_string(), definition: "a building in which people live; residence for human beings.".to_string(), }; + println!( + "FakeDefinition2: {}, {}", + fake_definition.id, fake_definition.word + ); + assert_eq!( fake_definition.embeddable().unwrap(), vec!["a building in which people live; residence for human beings.".to_string()] @@ -219,6 +229,8 @@ mod tests { employee_ages: vec![25, 30, 35, 40], }; + println!("Company: {}, {}", company.id, company.company); + assert_eq!( company.embeddable().unwrap(), vec![ @@ -247,6 +259,8 @@ mod tests { employee_ages: vec![25, 30, 35, 40], }; + println!("Company2: {}", company.id); + assert_eq!( company.embeddable().unwrap(), vec![ From 24e3b9867382f53aab7cb6c87d2ace46a05c23ce Mon Sep 17 00:00:00 2001 From: Garance <garance.mary@gmail.com> Date: Tue, 15 Oct 2024 16:04:24 -0400 Subject: [PATCH 29/91] docs&revert: revert embeddable trait error type, add docstrings --- rig-core/rig-core-derive/src/embeddable.rs | 3 +- rig-core/src/embeddings/builder.rs | 85 +++++++++++----------- rig-core/src/embeddings/embeddable.rs | 60 +++++++++++---- rig-core/src/tool.rs | 3 + 4 files changed, 96 insertions(+), 55 deletions(-) diff --git a/rig-core/rig-core-derive/src/embeddable.rs b/rig-core/rig-core-derive/src/embeddable.rs index 88503360..e3fe7f94 100644 --- a/rig-core/rig-core-derive/src/embeddable.rs +++ b/rig-core/rig-core-derive/src/embeddable.rs @@ -53,8 +53,9 @@ pub(crate) fn expand_derive_embedding(input: &mut syn::DeriveInput) -> syn::Resu impl #impl_generics Embeddable for #name #ty_generics #where_clause { type Kind = rig::embeddings::embeddable::#embed_kind; + type Error = rig::embeddings::embeddable::EmbeddableError; - fn embeddable(&self) -> Result<Vec<String>, rig::embeddings::embeddable::EmbeddableError> { + fn embeddable(&self) -> Result<Vec<String>, Self::Error> { #target_stream; let targets = embed_targets.into_iter() diff --git a/rig-core/src/embeddings/builder.rs b/rig-core/src/embeddings/builder.rs index 1a699833..435ecc2f 100644 --- a/rig-core/src/embeddings/builder.rs +++ b/rig-core/src/embeddings/builder.rs @@ -6,18 +6,23 @@ //! use std::env; //! //! use rig::{ -//! embeddings::EmbeddingsBuilder, -//! providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, +//! embeddings::builder::EmbeddingsBuilder, +//! providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, +//! vector_store::{in_memory_store::InMemoryVectorStore, VectorStoreIndex}, +//! Embeddable, //! }; -//! use rig_derive::Embed; +//! 
use serde::{Deserialize, Serialize}; //! -//! #[derive(Embed)] +//! // Shape of data that needs to be RAG'ed. +//! // The definition field will be used to generate embeddings. +//! #[derive(Embeddable, Clone, Deserialize, Debug, Serialize, Eq, PartialEq, Default)] //! struct FakeDefinition { -//! id: String, -//! word: String, -//! #[embed] -//! definitions: Vec<String>, +//! id: String, +//! word: String, +//! #[embed] +//! definitions: Vec<String>, //! } +//! //! // Create OpenAI client //! let openai_api_key = env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set"); //! let openai_client = Client::new(&openai_api_key); @@ -25,34 +30,34 @@ //! let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); //! //! let embeddings = EmbeddingsBuilder::new(model.clone()) -//! .documents(vec![ -//! FakeDefinition { -//! id: "doc0".to_string(), -//! word: "flurbo".to_string(), -//! definitions: vec![ -//! "A green alien that lives on cold planets.".to_string(), -//! "A fictional digital currency that originated in the animated series Rick and Morty.".to_string() -//! ] -//! }, -//! FakeDefinition { -//! id: "doc1".to_string(), -//! word: "glarb-glarb".to_string(), -//! definitions: vec![ -//! "An ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string(), -//! "A fictional creature found in the distant, swampy marshlands of the planet Glibbo in the Andromeda galaxy.".to_string() -//! ] -//! }, -//! FakeDefinition { -//! id: "doc2".to_string(), -//! word: "linglingdong".to_string(), -//! definitions: vec![ -//! "A term used by inhabitants of the sombrero galaxy to describe humans.".to_string(), -//! "A rare, mystical instrument crafted by the ancient monks of the Nebulon Mountain Ranges on the planet Quarm.".to_string() -//! ] -//! }, -//! ]) -//! .build() -//! .await?; +//! .documents(vec![ +//! FakeDefinition { +//! id: "doc0".to_string(), +//! word: "flurbo".to_string(), +//! definitions: vec![ +//! "A green alien that lives on cold planets.".to_string(), +//! "A fictional digital currency that originated in the animated series Rick and Morty.".to_string() +//! ] +//! }, +//! FakeDefinition { +//! id: "doc1".to_string(), +//! word: "glarb-glarb".to_string(), +//! definitions: vec![ +//! "An ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string(), +//! "A fictional creature found in the distant, swampy marshlands of the planet Glibbo in the Andromeda galaxy.".to_string() +//! ] +//! }, +//! FakeDefinition { +//! id: "doc2".to_string(), +//! word: "linglingdong".to_string(), +//! definitions: vec![ +//! "A term used by inhabitants of the sombrero galaxy to describe humans.".to_string(), +//! "A rare, mystical instrument crafted by the ancient monks of the Nebulon Mountain Ranges on the planet Quarm.".to_string() +//! ] +//! }, +//! ])? +//! .build() +//! .await?; //! //! // Use the generated embeddings //! // ... @@ -62,10 +67,8 @@ use std::{cmp::max, collections::HashMap, marker::PhantomData}; use futures::{stream, StreamExt, TryStreamExt}; -use crate::Embeddable; - use super::{ - embeddable::{EmbeddableError, EmbeddingKind, ManyEmbedding, SingleEmbedding}, + embeddable::{Embeddable, EmbeddableError, EmbeddingKind, ManyEmbedding, SingleEmbedding}, embedding::{Embedding, EmbeddingError, EmbeddingModel}, }; @@ -87,7 +90,7 @@ impl<M: EmbeddingModel, D: Embeddable<Kind = K>, K: EmbeddingKind> EmbeddingsBui } /// Add a document that implements `Embeddable` to the builder. 
- pub fn document(mut self, document: D) -> Result<Self, EmbeddableError> { + pub fn document(mut self, document: D) -> Result<Self, D::Error> { let embed_targets = document.embeddable()?; self.documents.push((document, embed_targets)); @@ -95,7 +98,7 @@ impl<M: EmbeddingModel, D: Embeddable<Kind = K>, K: EmbeddingKind> EmbeddingsBui } /// Add many documents that implement `Embeddable` to the builder. - pub fn documents(mut self, documents: Vec<D>) -> Result<Self, EmbeddableError> { + pub fn documents(mut self, documents: Vec<D>) -> Result<Self, D::Error> { for doc in documents.into_iter() { let embed_targets = doc.embeddable()?; diff --git a/rig-core/src/embeddings/embeddable.rs b/rig-core/src/embeddings/embeddable.rs index 5faa8b2d..b46ed28c 100644 --- a/rig-core/src/embeddings/embeddable.rs +++ b/rig-core/src/embeddings/embeddable.rs @@ -1,4 +1,22 @@ //! The module defines the [Embeddable] trait, which must be implemented for types that can be embedded. +//! //! # Example +//! ```rust +//! use std::env; +//! +//! use rig::Embeddable; +//! use serde::{Deserialize, Serialize}; +//! +//! #[derive(Embeddable)] +//! struct FakeDefinition { +//! id: String, +//! word: String, +//! #[embed] +//! definitions: Vec<String>, +//! } +//! +//! // Do something with FakeDefinition +//! // ... +//! ``` /// The associated type `Kind` on the trait `Embeddable` must implement this trait. pub trait EmbeddingKind {} @@ -11,6 +29,8 @@ impl EmbeddingKind for SingleEmbedding {} pub struct ManyEmbedding; impl EmbeddingKind for ManyEmbedding {} +/// Error type used for when the `embeddable` method fails. +/// Used by default implementations of `Embeddable` for common types. #[derive(Debug, thiserror::Error)] pub enum EmbeddableError { #[error("SerdeError: {0}")] @@ -20,10 +40,12 @@ pub enum EmbeddableError { /// Trait for types that can be embedded. /// The `embeddable` method returns a list of strings for which embeddings will be generated by the embeddings builder. /// If the type `Kind` is `SingleEmbedding`, the list of strings contains a single item, otherwise, the list can contain many items. +/// If there is an error generating the list of strings, the method should return an error that implements `std::error::Error`. 
pub trait Embeddable { type Kind: EmbeddingKind; + type Error: std::error::Error; - fn embeddable(&self) -> Result<Vec<String>, EmbeddableError>; + fn embeddable(&self) -> Result<Vec<String>, Self::Error>; } ////////////////////////////////////////////////////// @@ -31,88 +53,99 @@ pub trait Embeddable { ////////////////////////////////////////////////////// impl Embeddable for String { type Kind = SingleEmbedding; + type Error = EmbeddableError; - fn embeddable(&self) -> Result<Vec<String>, EmbeddableError> { + fn embeddable(&self) -> Result<Vec<String>, Self::Error> { Ok(vec![self.clone()]) } } impl Embeddable for i8 { type Kind = SingleEmbedding; + type Error = EmbeddableError; - fn embeddable(&self) -> Result<Vec<String>, EmbeddableError> { + fn embeddable(&self) -> Result<Vec<String>, Self::Error> { Ok(vec![self.to_string()]) } } impl Embeddable for i16 { type Kind = SingleEmbedding; + type Error = EmbeddableError; - fn embeddable(&self) -> Result<Vec<String>, EmbeddableError> { + fn embeddable(&self) -> Result<Vec<String>, Self::Error> { Ok(vec![self.to_string()]) } } impl Embeddable for i32 { type Kind = SingleEmbedding; + type Error = EmbeddableError; - fn embeddable(&self) -> Result<Vec<String>, EmbeddableError> { + fn embeddable(&self) -> Result<Vec<String>, Self::Error> { Ok(vec![self.to_string()]) } } impl Embeddable for i64 { type Kind = SingleEmbedding; + type Error = EmbeddableError; - fn embeddable(&self) -> Result<Vec<String>, EmbeddableError> { + fn embeddable(&self) -> Result<Vec<String>, Self::Error> { Ok(vec![self.to_string()]) } } impl Embeddable for i128 { type Kind = SingleEmbedding; + type Error = EmbeddableError; - fn embeddable(&self) -> Result<Vec<String>, EmbeddableError> { + fn embeddable(&self) -> Result<Vec<String>, Self::Error> { Ok(vec![self.to_string()]) } } impl Embeddable for f32 { type Kind = SingleEmbedding; + type Error = EmbeddableError; - fn embeddable(&self) -> Result<Vec<String>, EmbeddableError> { + fn embeddable(&self) -> Result<Vec<String>, Self::Error> { Ok(vec![self.to_string()]) } } impl Embeddable for f64 { type Kind = SingleEmbedding; + type Error = EmbeddableError; - fn embeddable(&self) -> Result<Vec<String>, EmbeddableError> { + fn embeddable(&self) -> Result<Vec<String>, Self::Error> { Ok(vec![self.to_string()]) } } impl Embeddable for bool { type Kind = SingleEmbedding; + type Error = EmbeddableError; - fn embeddable(&self) -> Result<Vec<String>, EmbeddableError> { + fn embeddable(&self) -> Result<Vec<String>, Self::Error> { Ok(vec![self.to_string()]) } } impl Embeddable for char { type Kind = SingleEmbedding; + type Error = EmbeddableError; - fn embeddable(&self) -> Result<Vec<String>, EmbeddableError> { + fn embeddable(&self) -> Result<Vec<String>, Self::Error> { Ok(vec![self.to_string()]) } } impl Embeddable for serde_json::Value { type Kind = SingleEmbedding; + type Error = EmbeddableError; - fn embeddable(&self) -> Result<Vec<String>, EmbeddableError> { + fn embeddable(&self) -> Result<Vec<String>, Self::Error> { Ok(vec![ serde_json::to_string(self).map_err(EmbeddableError::SerdeError)? 
]) @@ -121,8 +154,9 @@ impl Embeddable for serde_json::Value { impl<T: Embeddable> Embeddable for Vec<T> { type Kind = ManyEmbedding; + type Error = T::Error; - fn embeddable(&self) -> Result<Vec<String>, EmbeddableError> { + fn embeddable(&self) -> Result<Vec<String>, Self::Error> { Ok(self .iter() .map(|i| i.embeddable()) diff --git a/rig-core/src/tool.rs b/rig-core/src/tool.rs index d6f05ffc..e92896b8 100644 --- a/rig-core/src/tool.rs +++ b/rig-core/src/tool.rs @@ -327,6 +327,9 @@ impl ToolSet { Ok(docs) } + /// Convert tools in self to objects of type EmbeddableTool. + /// This is necessary because when adding tools to the EmbeddingBuilder because all + /// documents added to the builder must all be of the same type. pub fn embedabble_tools(&self) -> Result<Vec<EmbeddableTool>, EmbeddableError> { self.tools .values() From a7dbf6cd27be89b100144e261e40f38a9834fbb8 Mon Sep 17 00:00:00 2001 From: Garance <garance.mary@gmail.com> Date: Tue, 15 Oct 2024 16:06:13 -0400 Subject: [PATCH 30/91] style: cargo clippy --- rig-core/rig-core-derive/src/embeddable.rs | 1 - rig-core/src/embeddings/builder.rs | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/rig-core/rig-core-derive/src/embeddable.rs b/rig-core/rig-core-derive/src/embeddable.rs index e3fe7f94..7c3c883e 100644 --- a/rig-core/rig-core-derive/src/embeddable.rs +++ b/rig-core/rig-core-derive/src/embeddable.rs @@ -68,7 +68,6 @@ pub(crate) fn expand_derive_embedding(input: &mut syn::DeriveInput) -> syn::Resu } } }; - eprintln!("Generated code:\n{}", gen); Ok(gen) } diff --git a/rig-core/src/embeddings/builder.rs b/rig-core/src/embeddings/builder.rs index 435ecc2f..ea525e8b 100644 --- a/rig-core/src/embeddings/builder.rs +++ b/rig-core/src/embeddings/builder.rs @@ -68,7 +68,7 @@ use std::{cmp::max, collections::HashMap, marker::PhantomData}; use futures::{stream, StreamExt, TryStreamExt}; use super::{ - embeddable::{Embeddable, EmbeddableError, EmbeddingKind, ManyEmbedding, SingleEmbedding}, + embeddable::{Embeddable, EmbeddingKind, ManyEmbedding, SingleEmbedding}, embedding::{Embedding, EmbeddingError, EmbeddingModel}, }; From 886ebcb6ba9c1b1c2211dad33edc8a669966b615 Mon Sep 17 00:00:00 2001 From: Garance <garance.mary@gmail.com> Date: Tue, 15 Oct 2024 16:11:11 -0400 Subject: [PATCH 31/91] clippy(lancedb): fix unused function error --- rig-lancedb/examples/fixtures/lib.rs | 11 ++--------- rig-lancedb/examples/vector_search_local_ann.rs | 7 +++++-- rig-lancedb/examples/vector_search_s3_ann.rs | 7 +++++-- 3 files changed, 12 insertions(+), 13 deletions(-) diff --git a/rig-lancedb/examples/fixtures/lib.rs b/rig-lancedb/examples/fixtures/lib.rs index 956ace1c..415422f0 100644 --- a/rig-lancedb/examples/fixtures/lib.rs +++ b/rig-lancedb/examples/fixtures/lib.rs @@ -8,9 +8,9 @@ use serde::Deserialize; #[derive(Embeddable, Clone, Deserialize, Debug)] pub struct FakeDefinition { - id: String, + pub id: String, #[embed] - definition: String, + pub definition: String, } pub fn fake_definitions() -> Vec<FakeDefinition> { @@ -30,13 +30,6 @@ pub fn fake_definitions() -> Vec<FakeDefinition> { ] } -pub fn fake_definition(id: String) -> FakeDefinition { - FakeDefinition { - id, - definition: "Definition of *flumbuzzle (noun)*: A sudden, inexplicable urge to rearrange or reorganize small objects, such as desk items or books, for no apparent reason.".to_string() - } -} - // Schema of table in LanceDB. 
pub fn schema(dims: usize) -> Schema { Schema::new(Fields::from(vec![ diff --git a/rig-lancedb/examples/vector_search_local_ann.rs b/rig-lancedb/examples/vector_search_local_ann.rs index 1042e34c..1b7870fb 100644 --- a/rig-lancedb/examples/vector_search_local_ann.rs +++ b/rig-lancedb/examples/vector_search_local_ann.rs @@ -1,7 +1,7 @@ use std::{env, sync::Arc}; use arrow_array::RecordBatchIterator; -use fixture::{as_record_batch, fake_definition, fake_definitions, schema, FakeDefinition}; +use fixture::{as_record_batch, fake_definitions, schema, FakeDefinition}; use lancedb::index::vector::IvfPqIndexBuilder; use rig::vector_store::VectorStoreIndex; use rig::{ @@ -31,7 +31,10 @@ async fn main() -> Result<(), anyhow::Error> { // Note: need at least 256 rows in order to create an index so copy the definition 256 times for testing purposes. .documents( (0..256) - .map(|i| fake_definition(format!("doc{}", i))) + .map(|i| FakeDefinition { + id: format!("doc{}", i), + definition: "Definition of *flumbuzzle (noun)*: A sudden, inexplicable urge to rearrange or reorganize small objects, such as desk items or books, for no apparent reason.".to_string() + }) .collect(), )? .build() diff --git a/rig-lancedb/examples/vector_search_s3_ann.rs b/rig-lancedb/examples/vector_search_s3_ann.rs index 8d65e37a..8c10409b 100644 --- a/rig-lancedb/examples/vector_search_s3_ann.rs +++ b/rig-lancedb/examples/vector_search_s3_ann.rs @@ -1,7 +1,7 @@ use std::{env, sync::Arc}; use arrow_array::RecordBatchIterator; -use fixture::{as_record_batch, fake_definition, fake_definitions, schema, FakeDefinition}; +use fixture::{as_record_batch, fake_definitions, schema, FakeDefinition}; use lancedb::{index::vector::IvfPqIndexBuilder, DistanceType}; use rig::{ embeddings::{builder::EmbeddingsBuilder, embedding::EmbeddingModel}, @@ -37,7 +37,10 @@ async fn main() -> Result<(), anyhow::Error> { // Note: need at least 256 rows in order to create an index so copy the definition 256 times for testing purposes. .documents( (0..256) - .map(|i| fake_definition(format!("doc{}", i))) + .map(|i| FakeDefinition { + id: format!("doc{}", i), + definition: "Definition of *flumbuzzle (noun)*: A sudden, inexplicable urge to rearrange or reorganize small objects, such as desk items or books, for no apparent reason.".to_string() + }) .collect(), )? 
.build() From 79dea4536c1ed96b60f49d8234d96445e9c2e9c9 Mon Sep 17 00:00:00 2001 From: Garance <garance.mary@gmail.com> Date: Tue, 15 Oct 2024 16:13:25 -0400 Subject: [PATCH 32/91] fix(test): remove useless assert false statement --- rig-core/src/embeddings/embeddable.rs | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/rig-core/src/embeddings/embeddable.rs b/rig-core/src/embeddings/embeddable.rs index b46ed28c..c7b7bf0c 100644 --- a/rig-core/src/embeddings/embeddable.rs +++ b/rig-core/src/embeddings/embeddable.rs @@ -242,9 +242,7 @@ mod tests { assert_eq!( fake_definition.embeddable().unwrap(), vec!["a building in which people live; residence for human beings.".to_string()] - ); - - assert!(false) + ) } #[derive(Embeddable)] From 636234458755ca265e324fa1ab252d40e99232ab Mon Sep 17 00:00:00 2001 From: Garance <garance.mary@gmail.com> Date: Wed, 16 Oct 2024 14:32:00 -0400 Subject: [PATCH 33/91] cleanup: split up branch into 2 branches for readability --- Cargo.lock | 2 - rig-core/examples/calculator_chatbot.rs | 17 +- rig-core/examples/rag.rs | 51 +-- rig-core/examples/rag_dynamic_tools.rs | 13 +- rig-core/examples/vector_search.rs | 59 +--- rig-core/examples/vector_search_cohere.rs | 62 +--- rig-core/src/embeddings/builder.rs | 332 ++++++++++-------- rig-core/src/embeddings/mod.rs | 3 +- rig-core/src/embeddings/tool.rs | 25 -- rig-core/src/providers/cohere.rs | 10 +- rig-core/src/providers/openai.rs | 18 +- rig-core/src/tool.rs | 18 +- rig-lancedb/Cargo.toml | 1 - rig-lancedb/examples/fixtures/lib.rs | 56 ++- .../examples/vector_search_local_ann.rs | 33 +- .../examples/vector_search_local_enn.rs | 9 +- rig-lancedb/examples/vector_search_s3_ann.rs | 33 +- rig-mongodb/Cargo.toml | 2 - rig-mongodb/examples/vector_search_mongodb.rs | 82 +---- rig-mongodb/src/lib.rs | 39 +- 20 files changed, 350 insertions(+), 515 deletions(-) delete mode 100644 rig-core/src/embeddings/tool.rs diff --git a/Cargo.lock b/Cargo.lock index c47f71b8..75f67709 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4026,7 +4026,6 @@ dependencies = [ "futures", "lancedb", "rig-core", - "rig-derive", "serde", "serde_json", "tokio", @@ -4040,7 +4039,6 @@ dependencies = [ "futures", "mongodb", "rig-core", - "rig-derive", "serde", "serde_json", "tokio", diff --git a/rig-core/examples/calculator_chatbot.rs b/rig-core/examples/calculator_chatbot.rs index 0b994265..949073b3 100644 --- a/rig-core/examples/calculator_chatbot.rs +++ b/rig-core/examples/calculator_chatbot.rs @@ -2,7 +2,7 @@ use anyhow::Result; use rig::{ cli_chatbot::cli_chatbot, completion::ToolDefinition, - embeddings::builder::EmbeddingsBuilder, + embeddings::{builder::DocumentEmbeddings, builder::EmbeddingsBuilder}, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, tool::{Tool, ToolEmbedding, ToolSet}, vector_store::in_memory_store::InMemoryVectorStore, @@ -25,7 +25,7 @@ struct MathError; #[error("Init error")] struct InitError; -#[derive(Deserialize, Serialize, Clone)] +#[derive(Deserialize, Serialize)] struct Add; impl Tool for Add { const NAME: &'static str = "add"; @@ -77,7 +77,7 @@ impl ToolEmbedding for Add { fn context(&self) -> Self::Context {} } -#[derive(Deserialize, Serialize, Clone)] +#[derive(Deserialize, Serialize)] struct Subtract; impl Tool for Subtract { const NAME: &'static str = "subtract"; @@ -247,7 +247,7 @@ async fn main() -> Result<(), anyhow::Error> { let embedding_model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); let embeddings = EmbeddingsBuilder::new(embedding_model.clone()) - 
.documents(toolset.embedabble_tools()?)? + .tools(&toolset)? .build() .await?; @@ -255,8 +255,13 @@ async fn main() -> Result<(), anyhow::Error> { .add_documents( embeddings .into_iter() - .enumerate() - .map(|(i, (tool, embedding))| (i.to_string(), tool, vec![embedding])) + .map( + |DocumentEmbeddings { + id, + document, + embeddings, + }| { (id, document, embeddings) }, + ) .collect(), )? .index(embedding_model); diff --git a/rig-core/examples/rag.rs b/rig-core/examples/rag.rs index 43270a7b..936a3d05 100644 --- a/rig-core/examples/rag.rs +++ b/rig-core/examples/rag.rs @@ -1,22 +1,11 @@ -use std::{env, vec}; +use std::env; use rig::{ completion::Prompt, - embeddings::builder::EmbeddingsBuilder, + embeddings::{builder::DocumentEmbeddings, builder::EmbeddingsBuilder}, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, vector_store::in_memory_store::InMemoryVectorStore, - Embeddable, }; -use serde::Serialize; - -// Shape of data that needs to be RAG'ed. -// The definition field will be used to generate embeddings. -#[derive(Embeddable, Clone, Debug, Serialize, Eq, PartialEq, Default)] -struct FakeDefinition { - id: String, - #[embed] - definitions: Vec<String>, -} #[tokio::main] async fn main() -> Result<(), anyhow::Error> { @@ -27,29 +16,9 @@ async fn main() -> Result<(), anyhow::Error> { let embedding_model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); let embeddings = EmbeddingsBuilder::new(embedding_model.clone()) - .documents(vec![ - FakeDefinition { - id: "doc0".to_string(), - definitions: vec![ - "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets.".to_string(), - "Definition of a *flurbo*: A fictional digital currency that originated in the animated series Rick and Morty.".to_string() - ] - }, - FakeDefinition { - id: "doc1".to_string(), - definitions: vec![ - "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string(), - "Definition of a *glarb-glarb*: A fictional creature found in the distant, swampy marshlands of the planet Glibbo in the Andromeda galaxy.".to_string() - ] - }, - FakeDefinition { - id: "doc2".to_string(), - definitions: vec![ - "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.".to_string(), - "Definition of a *linglingdong*: A rare, mystical instrument crafted by the ancient monks of the Nebulon Mountain Ranges on the planet Quarm.".to_string() - ] - }, - ])? + .simple_document("doc0", "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets") + .simple_document("doc1", "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.") + .simple_document("doc2", "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.") .build() .await?; @@ -57,9 +26,13 @@ async fn main() -> Result<(), anyhow::Error> { .add_documents( embeddings .into_iter() - .map(|(fake_definition, embedding_vec)| { - (fake_definition.id.clone(), fake_definition, embedding_vec) - }) + .map( + |DocumentEmbeddings { + id, + document, + embeddings, + }| { (id, document, embeddings) }, + ) .collect(), )? 
.index(embedding_model); diff --git a/rig-core/examples/rag_dynamic_tools.rs b/rig-core/examples/rag_dynamic_tools.rs index 51c56ca8..f00543a2 100644 --- a/rig-core/examples/rag_dynamic_tools.rs +++ b/rig-core/examples/rag_dynamic_tools.rs @@ -1,7 +1,7 @@ use anyhow::Result; use rig::{ completion::{Prompt, ToolDefinition}, - embeddings::builder::EmbeddingsBuilder, + embeddings::{builder::DocumentEmbeddings, builder::EmbeddingsBuilder}, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, tool::{Tool, ToolEmbedding, ToolSet}, vector_store::in_memory_store::InMemoryVectorStore, @@ -156,7 +156,7 @@ async fn main() -> Result<(), anyhow::Error> { .build(); let embeddings = EmbeddingsBuilder::new(embedding_model.clone()) - .documents(toolset.embedabble_tools()?)? + .tools(&toolset)? .build() .await?; @@ -164,8 +164,13 @@ async fn main() -> Result<(), anyhow::Error> { .add_documents( embeddings .into_iter() - .enumerate() - .map(|(i, (tool, embedding))| (i.to_string(), tool, vec![embedding])) + .map( + |DocumentEmbeddings { + id, + document, + embeddings, + }| { (id, document, embeddings) }, + ) .collect(), )? .index(embedding_model); diff --git a/rig-core/examples/vector_search.rs b/rig-core/examples/vector_search.rs index a97ef8b0..26692706 100644 --- a/rig-core/examples/vector_search.rs +++ b/rig-core/examples/vector_search.rs @@ -1,22 +1,10 @@ use std::env; use rig::{ - embeddings::builder::EmbeddingsBuilder, + embeddings::{builder::DocumentEmbeddings, builder::EmbeddingsBuilder}, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, vector_store::{in_memory_store::InMemoryVectorStore, VectorStoreIndex}, - Embeddable, }; -use serde::{Deserialize, Serialize}; - -// Shape of data that needs to be RAG'ed. -// The definition field will be used to generate embeddings. -#[derive(Embeddable, Clone, Deserialize, Debug, Serialize, Eq, PartialEq, Default)] -struct FakeDefinition { - id: String, - word: String, - #[embed] - definitions: Vec<String>, -} #[tokio::main] async fn main() -> Result<(), anyhow::Error> { @@ -27,32 +15,9 @@ async fn main() -> Result<(), anyhow::Error> { let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); let embeddings = EmbeddingsBuilder::new(model.clone()) - .documents(vec![ - FakeDefinition { - id: "doc0".to_string(), - word: "flurbo".to_string(), - definitions: vec![ - "A green alien that lives on cold planets.".to_string(), - "A fictional digital currency that originated in the animated series Rick and Morty.".to_string() - ] - }, - FakeDefinition { - id: "doc1".to_string(), - word: "glarb-glarb".to_string(), - definitions: vec![ - "An ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string(), - "A fictional creature found in the distant, swampy marshlands of the planet Glibbo in the Andromeda galaxy.".to_string() - ] - }, - FakeDefinition { - id: "doc2".to_string(), - word: "linglingdong".to_string(), - definitions: vec![ - "A term used by inhabitants of the sombrero galaxy to describe humans.".to_string(), - "A rare, mystical instrument crafted by the ancient monks of the Nebulon Mountain Ranges on the planet Quarm.".to_string() - ] - }, - ])? 
+ .simple_document("doc0", "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets") + .simple_document("doc1", "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.") + .simple_document("doc2", "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.") .build() .await?; @@ -60,24 +25,28 @@ async fn main() -> Result<(), anyhow::Error> { .add_documents( embeddings .into_iter() - .map(|(fake_definition, embedding_vec)| { - (fake_definition.id.clone(), fake_definition, embedding_vec) - }) + .map( + |DocumentEmbeddings { + id, + document, + embeddings, + }| { (id, document, embeddings) }, + ) .collect(), )? .index(model); let results = index - .top_n::<FakeDefinition>("I need to buy something in a fictional universe. What type of money can I use for this?", 1) + .top_n::<String>("What is a linglingdong?", 1) .await? .into_iter() - .map(|(score, id, doc)| (score, id, doc.word)) + .map(|(score, id, doc)| (score, id, doc)) .collect::<Vec<_>>(); println!("Results: {:?}", results); let id_results = index - .top_n_ids("I need to buy something in a fictional universe. What type of money can I use for this?", 1) + .top_n_ids("What is a linglingdong?", 1) .await? .into_iter() .map(|(score, id)| (score, id)) diff --git a/rig-core/examples/vector_search_cohere.rs b/rig-core/examples/vector_search_cohere.rs index 16ddb775..c14fe0ce 100644 --- a/rig-core/examples/vector_search_cohere.rs +++ b/rig-core/examples/vector_search_cohere.rs @@ -1,22 +1,10 @@ use std::env; use rig::{ - embeddings::builder::EmbeddingsBuilder, + embeddings::{builder::DocumentEmbeddings, builder::EmbeddingsBuilder}, providers::cohere::{Client, EMBED_ENGLISH_V3}, vector_store::{in_memory_store::InMemoryVectorStore, VectorStoreIndex}, - Embeddable, }; -use serde::{Deserialize, Serialize}; - -// Shape of data that needs to be RAG'ed. -// The definition field will be used to generate embeddings. -#[derive(Embeddable, Clone, Deserialize, Debug, Serialize, Eq, PartialEq, Default)] -struct FakeDefinition { - id: String, - word: String, - #[embed] - definitions: Vec<String>, -} #[tokio::main] async fn main() -> Result<(), anyhow::Error> { @@ -27,33 +15,10 @@ async fn main() -> Result<(), anyhow::Error> { let document_model = cohere_client.embedding_model(EMBED_ENGLISH_V3, "search_document"); let search_model = cohere_client.embedding_model(EMBED_ENGLISH_V3, "search_query"); - let embeddings = EmbeddingsBuilder::new(document_model.clone()) - .documents(vec![ - FakeDefinition { - id: "doc0".to_string(), - word: "flurbo".to_string(), - definitions: vec![ - "A green alien that lives on cold planets.".to_string(), - "A fictional digital currency that originated in the animated series Rick and Morty.".to_string() - ] - }, - FakeDefinition { - id: "doc1".to_string(), - word: "glarb-glarb".to_string(), - definitions: vec![ - "An ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string(), - "A fictional creature found in the distant, swampy marshlands of the planet Glibbo in the Andromeda galaxy.".to_string() - ] - }, - FakeDefinition { - id: "doc2".to_string(), - word: "linglingdong".to_string(), - definitions: vec![ - "A term used by inhabitants of the sombrero galaxy to describe humans.".to_string(), - "A rare, mystical instrument crafted by the ancient monks of the Nebulon Mountain Ranges on the planet Quarm.".to_string() - ] - }, - ])? 
+ let embeddings = EmbeddingsBuilder::new(document_model) + .simple_document("doc0", "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets") + .simple_document("doc1", "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.") + .simple_document("doc2", "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.") .build() .await?; @@ -61,21 +26,22 @@ async fn main() -> Result<(), anyhow::Error> { .add_documents( embeddings .into_iter() - .map(|(fake_definition, embedding_vec)| { - (fake_definition.id.clone(), fake_definition, embedding_vec) - }) + .map( + |DocumentEmbeddings { + id, + document, + embeddings, + }| { (id, document, embeddings) }, + ) .collect(), )? .index(search_model); let results = index - .top_n::<FakeDefinition>( - "Which instrument is found in the Nebulon Mountain Ranges?", - 1, - ) + .top_n::<String>("What is a linglingdong?", 1) .await? .into_iter() - .map(|(score, id, doc)| (score, id, doc.word)) + .map(|(score, id, doc)| (score, id, doc)) .collect::<Vec<_>>(); println!("Results: {:?}", results); diff --git a/rig-core/src/embeddings/builder.rs b/rig-core/src/embeddings/builder.rs index ea525e8b..50e56537 100644 --- a/rig-core/src/embeddings/builder.rs +++ b/rig-core/src/embeddings/builder.rs @@ -3,207 +3,233 @@ //! //! # Example //! ```rust -//! use std::env; +//! use rig::providers::openai::{Client, self}; +//! use rig::embeddings::{EmbeddingModel, EmbeddingsBuilder}; //! -//! use rig::{ -//! embeddings::builder::EmbeddingsBuilder, -//! providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, -//! vector_store::{in_memory_store::InMemoryVectorStore, VectorStoreIndex}, -//! Embeddable, -//! }; -//! use serde::{Deserialize, Serialize}; +//! // Initialize the OpenAI client +//! let openai = Client::new("your-openai-api-key"); //! -//! // Shape of data that needs to be RAG'ed. -//! // The definition field will be used to generate embeddings. -//! #[derive(Embeddable, Clone, Deserialize, Debug, Serialize, Eq, PartialEq, Default)] -//! struct FakeDefinition { -//! id: String, -//! word: String, -//! #[embed] -//! definitions: Vec<String>, -//! } +//! // Create an instance of the `text-embedding-ada-002` model +//! let embedding_model = openai.embedding_model(openai::TEXT_EMBEDDING_ADA_002); //! -//! // Create OpenAI client -//! let openai_api_key = env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set"); -//! let openai_client = Client::new(&openai_api_key); -//! -//! let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); -//! -//! let embeddings = EmbeddingsBuilder::new(model.clone()) -//! .documents(vec![ -//! FakeDefinition { -//! id: "doc0".to_string(), -//! word: "flurbo".to_string(), -//! definitions: vec![ -//! "A green alien that lives on cold planets.".to_string(), -//! "A fictional digital currency that originated in the animated series Rick and Morty.".to_string() -//! ] -//! }, -//! FakeDefinition { -//! id: "doc1".to_string(), -//! word: "glarb-glarb".to_string(), -//! definitions: vec![ -//! "An ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string(), -//! "A fictional creature found in the distant, swampy marshlands of the planet Glibbo in the Andromeda galaxy.".to_string() -//! ] -//! }, -//! FakeDefinition { -//! id: "doc2".to_string(), -//! word: "linglingdong".to_string(), -//! definitions: vec![ -//! 
"A term used by inhabitants of the sombrero galaxy to describe humans.".to_string(), -//! "A rare, mystical instrument crafted by the ancient monks of the Nebulon Mountain Ranges on the planet Quarm.".to_string() -//! ] -//! }, -//! ])? +//! // Create an embeddings builder and add documents +//! let embeddings = EmbeddingsBuilder::new(embedding_model) +//! .simple_document("doc1", "This is the first document.") +//! .simple_document("doc2", "This is the second document.") //! .build() -//! .await?; +//! .await +//! .expect("Failed to build embeddings."); //! //! // Use the generated embeddings //! // ... //! ``` -use std::{cmp::max, collections::HashMap, marker::PhantomData}; +use std::{cmp::max, collections::HashMap}; use futures::{stream, StreamExt, TryStreamExt}; +use serde::{Deserialize, Serialize}; + +use crate::tool::{ToolEmbedding, ToolSet, ToolType}; + +use super::embedding::{ Embedding, EmbeddingError, EmbeddingModel}; + +/// Struct that holds a document and its embeddings. +/// +/// The struct is designed to model any kind of documents that can be serialized to JSON +/// (including a simple string). +/// +/// Moreover, it can hold multiple embeddings for the same document, thus allowing a +/// large document to be retrieved from a query that matches multiple smaller and +/// distinct text documents. For example, if the document is a textbook, a summary of +/// each chapter could serve as the book's embeddings. +#[derive(Clone, Eq, PartialEq, Serialize, Deserialize)] +pub struct DocumentEmbeddings { + #[serde(rename = "_id")] + pub id: String, + pub document: serde_json::Value, + pub embeddings: Vec<Embedding>, +} -use super::{ - embeddable::{Embeddable, EmbeddingKind, ManyEmbedding, SingleEmbedding}, - embedding::{Embedding, EmbeddingError, EmbeddingModel}, -}; +type Embeddings = Vec<DocumentEmbeddings>; -/// Builder for creating a collection of embeddings. -pub struct EmbeddingsBuilder<M: EmbeddingModel, D: Embeddable, K: EmbeddingKind> { - kind: PhantomData<K>, +/// Builder for creating a collection of embeddings +pub struct EmbeddingsBuilder<M: EmbeddingModel> { model: M, - documents: Vec<(D, Vec<String>)>, + documents: Vec<(String, serde_json::Value, Vec<String>)>, } -impl<M: EmbeddingModel, D: Embeddable<Kind = K>, K: EmbeddingKind> EmbeddingsBuilder<M, D, K> { +impl<M: EmbeddingModel> EmbeddingsBuilder<M> { /// Create a new embedding builder with the given embedding model pub fn new(model: M) -> Self { Self { - kind: PhantomData, model, documents: vec![], } } - /// Add a document that implements `Embeddable` to the builder. - pub fn document(mut self, document: D) -> Result<Self, D::Error> { - let embed_targets = document.embeddable()?; + /// Add a simple document to the embedding collection. + /// The provided document string will be used for the embedding. + pub fn simple_document(mut self, id: &str, document: &str) -> Self { + self.documents.push(( + id.to_string(), + serde_json::Value::String(document.to_string()), + vec![document.to_string()], + )); + self + } - self.documents.push((document, embed_targets)); - Ok(self) + /// Add multiple documents to the embedding collection. + /// Each element of the vector is a tuple of the form (id, document). + pub fn simple_documents(mut self, documents: Vec<(String, String)>) -> Self { + self.documents + .extend(documents.into_iter().map(|(id, document)| { + ( + id, + serde_json::Value::String(document.clone()), + vec![document], + ) + })); + self } - /// Add many documents that implement `Embeddable` to the builder. 
- pub fn documents(mut self, documents: Vec<D>) -> Result<Self, D::Error> { - for doc in documents.into_iter() { - let embed_targets = doc.embeddable()?; + /// Add a tool to the embedding collection. + /// The `tool.context()` corresponds to the document being stored while + /// `tool.embedding_docs()` corresponds to the documents that will be used to generate the embeddings. + pub fn tool(mut self, tool: impl ToolEmbedding + 'static) -> Result<Self, EmbeddingError> { + self.documents.push(( + tool.name(), + serde_json::to_value(tool.context())?, + tool.embedding_docs(), + )); + Ok(self) + } - self.documents.push((doc, embed_targets)); + /// Add the tools from the given toolset to the embedding collection. + pub fn tools(mut self, toolset: &ToolSet) -> Result<Self, EmbeddingError> { + for (name, tool) in toolset.tools.iter() { + if let ToolType::Embedding(tool) = tool { + self.documents.push(( + name.clone(), + tool.context().map_err(|e| { + EmbeddingError::DocumentError(format!( + "Failed to generate context for tool {}: {}", + name, e + )) + })?, + tool.embedding_docs(), + )); + } } - Ok(self) } -} -impl<M: EmbeddingModel, D: Embeddable + Send + Sync + Clone> - EmbeddingsBuilder<M, D, ManyEmbedding> -{ - /// Generate embeddings for all documents in the builder. - /// The method only applies when documents in the builder each contain multiple embedding targets. - /// Returns a vector of tuples, where the first element is the document and the second element is the vector of embeddings. - pub async fn build(&self) -> Result<Vec<(D, Vec<Embedding>)>, EmbeddingError> { - // Use this for reference later to merge a document back with its embeddings. + /// Add a document to the embedding collection. + /// `embed_documents` are the documents that will be used to generate the embeddings + /// for `document`. + pub fn document<T: Serialize>( + mut self, + id: &str, + document: T, + embed_documents: Vec<String>, + ) -> Self { + self.documents.push(( + id.to_string(), + serde_json::to_value(document).expect("Document should serialize"), + embed_documents, + )); + self + } + + /// Add multiple documents to the embedding collection. + /// Each element of the vector is a tuple of the form (id, document, embed_documents). + pub fn documents<T: Serialize>(mut self, documents: Vec<(String, T, Vec<String>)>) -> Self { + self.documents.extend( + documents + .into_iter() + .map(|(id, document, embed_documents)| { + ( + id, + serde_json::to_value(document).expect("Document should serialize"), + embed_documents, + ) + }), + ); + self + } + + /// Add a json document to the embedding collection. + pub fn json_document( + mut self, + id: &str, + document: serde_json::Value, + embed_documents: Vec<String>, + ) -> Self { + self.documents + .push((id.to_string(), document, embed_documents)); + self + } + + /// Add multiple json documents to the embedding collection. 
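Since the `document` method above takes any `Serialize` payload plus the strings to embed, a small sketch of how a custom struct could be attached; the `WordDefinition` type is made up for illustration and is not part of this patch.

```rust
use serde::Serialize;
use rig::embeddings::{builder::EmbeddingsBuilder, embedding::EmbeddingModel};

// Illustrative payload type, not part of the patch.
#[derive(Serialize)]
struct WordDefinition {
    word: String,
    definition: String,
}

// Store the whole struct as the JSON document, but embed only the definition
// text; works with any embedding model (OpenAI, Cohere, ...).
fn add_word_definition<M: EmbeddingModel>(builder: EmbeddingsBuilder<M>) -> EmbeddingsBuilder<M> {
    builder.document(
        "doc0",
        WordDefinition {
            word: "flurbo".to_string(),
            definition: "a green alien that lives on cold planets".to_string(),
        },
        vec!["a green alien that lives on cold planets".to_string()],
    )
}
```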
+ pub fn json_documents( + mut self, + documents: Vec<(String, serde_json::Value, Vec<String>)>, + ) -> Self { + self.documents.extend(documents); + self + } + + /// Generate the embeddings for the given documents + pub async fn build(self) -> Result<Embeddings, EmbeddingError> { + // Create a temporary store for the documents let documents_map = self .documents - .clone() .into_iter() - .enumerate() - .map(|(id, (document, _))| (id, document)) + .map(|(id, document, docs)| (id, (document, docs))) .collect::<HashMap<_, _>>(); - let embeddings = stream::iter(self.documents.iter().enumerate()) - // Merge the embedding targets of each document into a single list of embedding targets. - .flat_map(|(i, (_, embed_targets))| { - stream::iter(embed_targets.iter().map(move |target| (i, target.clone()))) + let embeddings = stream::iter(documents_map.iter()) + // Flatten the documents + .flat_map(|(id, (_, docs))| { + stream::iter(docs.iter().map(|doc| (id.clone(), doc.clone()))) }) - // Chunk them into N (the emebdding API limit per request). + // Chunk them into N (the emebdding API limit per request) .chunks(M::MAX_DOCUMENTS) - // Generate the embeddings for a chunk at a time. + // Generate the embeddings .map(|docs| async { - let (document_indices, embed_targets): (Vec<_>, Vec<_>) = docs.into_iter().unzip(); - + let (ids, docs): (Vec<_>, Vec<_>) = docs.into_iter().unzip(); Ok::<_, EmbeddingError>( - document_indices - .into_iter() - .zip(self.model.embed_documents(embed_targets).await?.into_iter()) + ids.into_iter() + .zip(self.model.embed_documents(docs).await?.into_iter()) .collect::<Vec<_>>(), ) }) .boxed() // Parallelize the embeddings generation over 10 concurrent requests .buffer_unordered(max(1, 1024 / M::MAX_DOCUMENTS)) - .try_fold( - HashMap::new(), - |mut acc: HashMap<_, Vec<_>>, embeddings| async move { - embeddings.into_iter().for_each(|(i, embedding)| { - acc.entry(i).or_default().push(embedding); + .try_fold(vec![], |mut acc, mut embeddings| async move { + Ok({ + acc.append(&mut embeddings); + acc + }) + }) + .await?; + + // Assemble the DocumentEmbeddings + let mut document_embeddings: HashMap<String, DocumentEmbeddings> = HashMap::new(); + embeddings.into_iter().for_each(|(id, embedding)| { + let (document, _) = documents_map.get(&id).expect("Document not found"); + let document_embedding = + document_embeddings + .entry(id.clone()) + .or_insert_with(|| DocumentEmbeddings { + id: id.clone(), + document: document.clone(), + embeddings: vec![], }); - Ok(acc) - }, - ) - .await? - .iter() - .fold(vec![], |mut acc, (i, embeddings_vec)| { - acc.push(( - documents_map.get(i).cloned().unwrap(), - embeddings_vec.clone(), - )); - acc - }); + document_embedding.embeddings.push(embedding); + }); - Ok(embeddings) + Ok(document_embeddings.into_values().collect()) } -} - -impl<M: EmbeddingModel, D: Embeddable + Send + Sync + Clone> - EmbeddingsBuilder<M, D, SingleEmbedding> -{ - /// Generate embeddings for all documents in the builder. - /// The method only applies when documents in the builder each contain a single embedding target. - /// Returns a vector of tuples, where the first element is the document and the second element is the embedding. 
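As a usage sketch of the `build` output: each `DocumentEmbeddings` keeps all of a document's embeddings together, so producing flat `(id, text, vector)` rows takes a `flat_map`, much like the LanceDB fixture further down in this series does. The helper below is hypothetical, not part of the patch.

```rust
use rig::embeddings::builder::DocumentEmbeddings;

// Flatten builder output into one row per embedded string.
fn flatten(docs: Vec<DocumentEmbeddings>) -> Vec<(String, String, Vec<f64>)> {
    docs.into_iter()
        .flat_map(|DocumentEmbeddings { id, embeddings, .. }| {
            embeddings
                .into_iter()
                .map(move |e| (id.clone(), e.document, e.vec))
        })
        .collect()
}
```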
- pub async fn build(&self) -> Result<Vec<(D, Embedding)>, EmbeddingError> { - let embeddings = stream::iter( - self.documents - .clone() - .into_iter() - .map(|(document, embed_target)| (document, embed_target.first().cloned().unwrap())), - ) - // Chunk them into N (the emebdding API limit per request) - .chunks(M::MAX_DOCUMENTS) - // Generate the embeddings - .map(|docs| async { - let (documents, embed_targets): (Vec<_>, Vec<_>) = docs.into_iter().unzip(); - Ok::<_, EmbeddingError>( - documents - .into_iter() - .zip(self.model.embed_documents(embed_targets).await?.into_iter()) - .collect::<Vec<_>>(), - ) - }) - .boxed() - // Parallelize the embeddings generation over 10 concurrent requests - .buffer_unordered(max(1, 1024 / M::MAX_DOCUMENTS)) - .try_fold(vec![], |mut acc, embeddings| async move { - acc.extend(embeddings); - Ok(acc) - }) - .await?; - - Ok(embeddings) - } -} +} \ No newline at end of file diff --git a/rig-core/src/embeddings/mod.rs b/rig-core/src/embeddings/mod.rs index d590b7d0..061066c1 100644 --- a/rig-core/src/embeddings/mod.rs +++ b/rig-core/src/embeddings/mod.rs @@ -5,5 +5,4 @@ pub mod builder; pub mod embeddable; -pub mod embedding; -pub mod tool; +pub mod embedding; \ No newline at end of file diff --git a/rig-core/src/embeddings/tool.rs b/rig-core/src/embeddings/tool.rs deleted file mode 100644 index 5e8ecd35..00000000 --- a/rig-core/src/embeddings/tool.rs +++ /dev/null @@ -1,25 +0,0 @@ -use crate::{self as rig, tool::ToolEmbeddingDyn}; -use rig::embeddings::embeddable::Embeddable; -use rig_derive::Embeddable; -use serde::Serialize; - -use super::embeddable::EmbeddableError; - -/// Used by EmbeddingsBuilder to embed anything that implements ToolEmbedding. -#[derive(Embeddable, Clone, Serialize, Default, Eq, PartialEq)] -pub struct EmbeddableTool { - name: String, - #[embed] - context: serde_json::Value, -} - -impl EmbeddableTool { - /// Convert item that implements ToolEmbedding to an EmbeddableTool. - pub fn try_from(tool: &dyn ToolEmbeddingDyn) -> Result<Self, EmbeddableError> { - Ok(EmbeddableTool { - name: tool.name(), - context: serde_json::to_value(tool.context().map_err(EmbeddableError::SerdeError)?) 
- .map_err(EmbeddableError::SerdeError)?, - }) - } -} diff --git a/rig-core/src/providers/cohere.rs b/rig-core/src/providers/cohere.rs index fa798205..eba4ce3d 100644 --- a/rig-core/src/providers/cohere.rs +++ b/rig-core/src/providers/cohere.rs @@ -13,9 +13,7 @@ use std::collections::HashMap; use crate::{ agent::AgentBuilder, completion::{self, CompletionError}, - embeddings::{ - self, builder::EmbeddingsBuilder, embeddable::Embeddable, embedding::EmbeddingError, - }, + embeddings::{self, embedding::EmbeddingError, builder::EmbeddingsBuilder}, extractor::ExtractorBuilder, json_utils, }; @@ -87,11 +85,7 @@ impl Client { EmbeddingModel::new(self.clone(), model, input_type, ndims) } - pub fn embeddings<D: Embeddable>( - &self, - model: &str, - input_type: &str, - ) -> EmbeddingsBuilder<EmbeddingModel, D, D::Kind> { + pub fn embeddings(&self, model: &str, input_type: &str) -> EmbeddingsBuilder<EmbeddingModel> { EmbeddingsBuilder::new(self.embedding_model(model, input_type)) } diff --git a/rig-core/src/providers/openai.rs b/rig-core/src/providers/openai.rs index d21bf8fb..eb373414 100644 --- a/rig-core/src/providers/openai.rs +++ b/rig-core/src/providers/openai.rs @@ -11,12 +11,7 @@ use crate::{ agent::AgentBuilder, completion::{self, CompletionError, CompletionRequest}, - embeddings::{ - self, - builder::EmbeddingsBuilder, - embeddable::Embeddable, - embedding::{Embedding, EmbeddingError}, - }, + embeddings::{self, embedding::EmbeddingError}, extractor::ExtractorBuilder, json_utils, }; @@ -126,11 +121,8 @@ impl Client { /// .await /// .expect("Failed to embed documents"); /// ``` - pub fn embeddings<T: Embeddable>( - &self, - model: &str, - ) -> EmbeddingsBuilder<EmbeddingModel, T, T::Kind> { - EmbeddingsBuilder::new(self.embedding_model(model)) + pub fn embeddings(&self, model: &str) -> embeddings::builder::EmbeddingsBuilder<EmbeddingModel> { + embeddings::builder::EmbeddingsBuilder::new(self.embedding_model(model)) } /// Create a completion model with the given name. @@ -250,7 +242,7 @@ impl embeddings::embedding::EmbeddingModel for EmbeddingModel { async fn embed_documents( &self, documents: Vec<String>, - ) -> Result<Vec<Embedding>, EmbeddingError> { + ) -> Result<Vec<embeddings::embedding::Embedding>, EmbeddingError> { let response = self .client .post("/v1/embeddings") @@ -276,7 +268,7 @@ impl embeddings::embedding::EmbeddingModel for EmbeddingModel { .data .into_iter() .zip(documents.into_iter()) - .map(|(embedding, document)| Embedding { + .map(|(embedding, document)| embeddings::embedding::Embedding { document, vec: embedding.embedding, }) diff --git a/rig-core/src/tool.rs b/rig-core/src/tool.rs index e92896b8..4e5e7675 100644 --- a/rig-core/src/tool.rs +++ b/rig-core/src/tool.rs @@ -5,7 +5,7 @@ use serde::{Deserialize, Serialize}; use crate::{ completion::{self, ToolDefinition}, - embeddings::{embeddable::EmbeddableError, tool::EmbeddableTool}, + embeddings::{embeddable::EmbeddableError}, }; #[derive(Debug, thiserror::Error)] @@ -326,22 +326,6 @@ impl ToolSet { } Ok(docs) } - - /// Convert tools in self to objects of type EmbeddableTool. - /// This is necessary because when adding tools to the EmbeddingBuilder because all - /// documents added to the builder must all be of the same type. 
- pub fn embedabble_tools(&self) -> Result<Vec<EmbeddableTool>, EmbeddableError> { - self.tools - .values() - .filter_map(|tool_type| { - if let ToolType::Embedding(tool) = tool_type { - Some(EmbeddableTool::try_from(&**tool)) - } else { - None - } - }) - .collect::<Result<Vec<_>, _>>() - } } #[derive(Default)] diff --git a/rig-lancedb/Cargo.toml b/rig-lancedb/Cargo.toml index 6ee41d8d..031df2d3 100644 --- a/rig-lancedb/Cargo.toml +++ b/rig-lancedb/Cargo.toml @@ -6,7 +6,6 @@ edition = "2021" [dependencies] lancedb = "0.10.0" rig-core = { path = "../rig-core", version = "0.2.1" } -rig-derive = { path = "../rig-core/rig-core-derive" } arrow-array = "52.2.0" serde_json = "1.0.128" serde = "1.0.210" diff --git a/rig-lancedb/examples/fixtures/lib.rs b/rig-lancedb/examples/fixtures/lib.rs index 415422f0..d95a42e4 100644 --- a/rig-lancedb/examples/fixtures/lib.rs +++ b/rig-lancedb/examples/fixtures/lib.rs @@ -2,39 +2,13 @@ use std::sync::Arc; use arrow_array::{types::Float64Type, ArrayRef, FixedSizeListArray, RecordBatch, StringArray}; use lancedb::arrow::arrow_schema::{DataType, Field, Fields, Schema}; -use rig::embeddings::embedding::Embedding; -use rig::Embeddable; -use serde::Deserialize; - -#[derive(Embeddable, Clone, Deserialize, Debug)] -pub struct FakeDefinition { - pub id: String, - #[embed] - pub definition: String, -} - -pub fn fake_definitions() -> Vec<FakeDefinition> { - vec![ - FakeDefinition { - id: "doc0".to_string(), - definition: "Definition of *flumbrel (noun)*: a small, seemingly insignificant item that you constantly lose or misplace, such as a pen, hair tie, or remote control.".to_string() - }, - FakeDefinition { - id: "doc1".to_string(), - definition: "Definition of *zindle (verb)*: to pretend to be working on something important while actually doing something completely unrelated or unproductive.".to_string() - }, - FakeDefinition { - id: "doc2".to_string(), - definition: "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.".to_string() - } - ] -} +use rig::embeddings::builder::DocumentEmbeddings; // Schema of table in LanceDB. pub fn schema(dims: usize) -> Schema { Schema::new(Fields::from(vec![ Field::new("id", DataType::Utf8, false), - Field::new("definition", DataType::Utf8, false), + Field::new("content", DataType::Utf8, false), Field::new( "embedding", DataType::FixedSizeList( @@ -46,36 +20,48 @@ pub fn schema(dims: usize) -> Schema { ])) } -// Convert FakeDefinition objects and their embedding to a RecordBatch. +// Convert DocumentEmbeddings objects to a RecordBatch. pub fn as_record_batch( - records: Vec<(FakeDefinition, Embedding)>, + records: Vec<DocumentEmbeddings>, dims: usize, ) -> Result<RecordBatch, lancedb::arrow::arrow_schema::ArrowError> { let id = StringArray::from_iter_values( records .iter() - .map(|(FakeDefinition { id, .. }, _)| id) + .flat_map(|record| (0..record.embeddings.len()).map(|i| format!("{}-{i}", record.id))) .collect::<Vec<_>>(), ); - let definition = StringArray::from_iter_values( + let content = StringArray::from_iter_values( records .iter() - .map(|(FakeDefinition { definition, .. }, _)| definition) + .flat_map(|record| { + record + .embeddings + .iter() + .map(|embedding| embedding.document.clone()) + }) .collect::<Vec<_>>(), ); let embedding = FixedSizeListArray::from_iter_primitive::<Float64Type, _, _>( records .into_iter() - .map(|(_, Embedding { vec, .. 
})| Some(vec.into_iter().map(Some).collect::<Vec<_>>())) + .flat_map(|record| { + record + .embeddings + .into_iter() + .map(|embedding| embedding.vec.into_iter().map(Some).collect::<Vec<_>>()) + .map(Some) + .collect::<Vec<_>>() + }) .collect::<Vec<_>>(), dims as i32, ); RecordBatch::try_from_iter(vec![ ("id", Arc::new(id) as ArrayRef), - ("definition", Arc::new(definition) as ArrayRef), + ("content", Arc::new(content) as ArrayRef), ("embedding", Arc::new(embedding) as ArrayRef), ]) } diff --git a/rig-lancedb/examples/vector_search_local_ann.rs b/rig-lancedb/examples/vector_search_local_ann.rs index 1b7870fb..a9744210 100644 --- a/rig-lancedb/examples/vector_search_local_ann.rs +++ b/rig-lancedb/examples/vector_search_local_ann.rs @@ -1,18 +1,25 @@ use std::{env, sync::Arc}; use arrow_array::RecordBatchIterator; -use fixture::{as_record_batch, fake_definitions, schema, FakeDefinition}; +use fixture::{as_record_batch, schema}; use lancedb::index::vector::IvfPqIndexBuilder; use rig::vector_store::VectorStoreIndex; use rig::{ - embeddings::{builder::EmbeddingsBuilder, embedding::EmbeddingModel}, + embeddings::{embedding::EmbeddingModel, builder::EmbeddingsBuilder}, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, }; use rig_lancedb::{LanceDbVectorStore, SearchParams}; +use serde::Deserialize; #[path = "./fixtures/lib.rs"] mod fixture; +#[derive(Deserialize, Debug)] +pub struct VectorSearchResult { + pub id: String, + pub content: String, +} + #[tokio::main] async fn main() -> Result<(), anyhow::Error> { // Initialize OpenAI client. Use this to generate embeddings (and generate test data for RAG demo). @@ -25,18 +32,18 @@ async fn main() -> Result<(), anyhow::Error> { // Initialize LanceDB locally. let db = lancedb::connect("data/lancedb-store").execute().await?; + // Set up test data for RAG demo + let definition = "Definition of *flumbuzzle (verb)*: to bewilder or confuse someone completely, often by using nonsensical or overly complex explanations or instructions.".to_string(); + + // Note: need at least 256 rows in order to create an index so copy the definition 256 times for testing purposes. + let definitions = vec![definition; 256]; + // Generate embeddings for the test data. let embeddings = EmbeddingsBuilder::new(model.clone()) - .documents(fake_definitions())? - // Note: need at least 256 rows in order to create an index so copy the definition 256 times for testing purposes. - .documents( - (0..256) - .map(|i| FakeDefinition { - id: format!("doc{}", i), - definition: "Definition of *flumbuzzle (noun)*: A sudden, inexplicable urge to rearrange or reorganize small objects, such as desk items or books, for no apparent reason.".to_string() - }) - .collect(), - )? 
+ .simple_document("doc0", "Definition of *flumbrel (noun)*: a small, seemingly insignificant item that you constantly lose or misplace, such as a pen, hair tie, or remote control.") + .simple_document("doc1", "Definition of *zindle (verb)*: to pretend to be working on something important while actually doing something completely unrelated or unproductive") + .simple_document("doc2", "Definition of *glimber (adjective)*: describing a state of excitement mixed with nervousness, often experienced before an important event or decision.") + .simple_documents(definitions.clone().into_iter().enumerate().map(|(i, def)| (format!("doc{}", i+3), def)).collect()) .build() .await?; @@ -65,7 +72,7 @@ async fn main() -> Result<(), anyhow::Error> { // Query the index let results = vector_store - .top_n::<FakeDefinition>("My boss says I zindle too much, what does that mean?", 1) + .top_n::<VectorSearchResult>("My boss says I zindle too much, what does that mean?", 1) .await?; println!("Results: {:?}", results); diff --git a/rig-lancedb/examples/vector_search_local_enn.rs b/rig-lancedb/examples/vector_search_local_enn.rs index 630acc1a..69a8599b 100644 --- a/rig-lancedb/examples/vector_search_local_enn.rs +++ b/rig-lancedb/examples/vector_search_local_enn.rs @@ -1,9 +1,9 @@ use std::{env, sync::Arc}; use arrow_array::RecordBatchIterator; -use fixture::{as_record_batch, fake_definitions, schema}; +use fixture::{as_record_batch, schema}; use rig::{ - embeddings::{builder::EmbeddingsBuilder, embedding::EmbeddingModel}, + embeddings::{embedding::EmbeddingModel, builder::EmbeddingsBuilder}, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, vector_store::VectorStoreIndexDyn, }; @@ -21,9 +21,10 @@ async fn main() -> Result<(), anyhow::Error> { // Select the embedding model and generate our embeddings let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); - // Generate embeddings for the test data. let embeddings = EmbeddingsBuilder::new(model.clone()) - .documents(fake_definitions())? 
+ .simple_document("doc0", "Definition of *flumbrel (noun)*: a small, seemingly insignificant item that you constantly lose or misplace, such as a pen, hair tie, or remote control.") + .simple_document("doc1", "Definition of *zindle (verb)*: to pretend to be working on something important while actually doing something completely unrelated or unproductive") + .simple_document("doc2", "Definition of *glimber (adjective)*: describing a state of excitement mixed with nervousness, often experienced before an important event or decision.") .build() .await?; diff --git a/rig-lancedb/examples/vector_search_s3_ann.rs b/rig-lancedb/examples/vector_search_s3_ann.rs index 8c10409b..9197fc42 100644 --- a/rig-lancedb/examples/vector_search_s3_ann.rs +++ b/rig-lancedb/examples/vector_search_s3_ann.rs @@ -1,18 +1,25 @@ use std::{env, sync::Arc}; use arrow_array::RecordBatchIterator; -use fixture::{as_record_batch, fake_definitions, schema, FakeDefinition}; +use fixture::{as_record_batch, schema}; use lancedb::{index::vector::IvfPqIndexBuilder, DistanceType}; use rig::{ - embeddings::{builder::EmbeddingsBuilder, embedding::EmbeddingModel}, + embeddings::{embedding::EmbeddingModel, builder::EmbeddingsBuilder}, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, vector_store::VectorStoreIndex, }; use rig_lancedb::{LanceDbVectorStore, SearchParams}; +use serde::Deserialize; #[path = "./fixtures/lib.rs"] mod fixture; +#[derive(Deserialize, Debug)] +pub struct VectorSearchResult { + pub id: String, + pub content: String, +} + // Note: see docs to deploy LanceDB on other cloud providers such as google and azure. // https://lancedb.github.io/lancedb/guides/storage/ #[tokio::main] @@ -31,18 +38,18 @@ async fn main() -> Result<(), anyhow::Error> { .execute() .await?; + // Set up test data for RAG demo + let definition = "Definition of *flumbuzzle (verb)*: to bewilder or confuse someone completely, often by using nonsensical or overly complex explanations or instructions.".to_string(); + + // Note: need at least 256 rows in order to create an index so copy the definition 256 times for testing purposes. + let definitions = vec![definition; 256]; + // Generate embeddings for the test data. let embeddings = EmbeddingsBuilder::new(model.clone()) - .documents(fake_definitions())? - // Note: need at least 256 rows in order to create an index so copy the definition 256 times for testing purposes. - .documents( - (0..256) - .map(|i| FakeDefinition { - id: format!("doc{}", i), - definition: "Definition of *flumbuzzle (noun)*: A sudden, inexplicable urge to rearrange or reorganize small objects, such as desk items or books, for no apparent reason.".to_string() - }) - .collect(), - )? 
+ .simple_document("doc0", "Definition of *flumbrel (noun)*: a small, seemingly insignificant item that you constantly lose or misplace, such as a pen, hair tie, or remote control.") + .simple_document("doc1", "Definition of *zindle (verb)*: to pretend to be working on something important while actually doing something completely unrelated or unproductive") + .simple_document("doc2", "Definition of *glimber (adjective)*: describing a state of excitement mixed with nervousness, often experienced before an important event or decision.") + .simple_documents(definitions.clone().into_iter().enumerate().map(|(i, def)| (format!("doc{}", i+3), def)).collect()) .build() .await?; @@ -77,7 +84,7 @@ async fn main() -> Result<(), anyhow::Error> { // Query the index let results = vector_store - .top_n::<FakeDefinition>("I'm always looking for my phone, I always seem to forget it in the most counterintuitive places. What's the word for this feeling?", 1) + .top_n::<VectorSearchResult>("I'm always looking for my phone, I always seem to forget it in the most counterintuitive places. What's the word for this feeling?", 1) .await?; println!("Results: {:?}", results); diff --git a/rig-mongodb/Cargo.toml b/rig-mongodb/Cargo.toml index 78b48892..6f313838 100644 --- a/rig-mongodb/Cargo.toml +++ b/rig-mongodb/Cargo.toml @@ -13,8 +13,6 @@ repository = "https://github.com/0xPlaygrounds/rig" futures = "0.3.30" mongodb = "2.8.2" rig-core = { path = "../rig-core", version = "0.2.1" } -rig-derive = { path = "../rig-core/rig-core-derive" } - serde = { version = "1.0.203", features = ["derive"] } serde_json = "1.0.117" tracing = "0.1.40" diff --git a/rig-mongodb/examples/vector_search_mongodb.rs b/rig-mongodb/examples/vector_search_mongodb.rs index a816c9c0..cca7a6d4 100644 --- a/rig-mongodb/examples/vector_search_mongodb.rs +++ b/rig-mongodb/examples/vector_search_mongodb.rs @@ -1,39 +1,13 @@ -use mongodb::{bson::doc, options::ClientOptions, Client as MongoClient, Collection}; -use rig::providers::openai::TEXT_EMBEDDING_ADA_002; -use serde::{Deserialize, Serialize}; +use mongodb::{options::ClientOptions, Client as MongoClient, Collection}; use std::env; -use rig::Embeddable; use rig::{ - embeddings::builder::EmbeddingsBuilder, providers::openai::Client, + embeddings::{builder::DocumentEmbeddings, builder::EmbeddingsBuilder}, + providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, vector_store::VectorStoreIndex, }; use rig_mongodb::{MongoDbVectorStore, SearchParams}; -// Shape of data that needs to be RAG'ed. -// The definition field will be used to generate embeddings. -#[derive(Embeddable, Clone, Deserialize, Debug)] -struct FakeDefinition { - id: String, - #[embed] - definition: String, -} - -#[derive(Clone, Deserialize, Debug, Serialize)] -struct Link { - word: String, - link: String, -} - -// Shape of the document to be stored in MongoDB, with embeddings. 
-#[derive(Serialize, Debug)] -struct Document { - #[serde(rename = "_id")] - id: String, - definition: String, - embedding: Vec<f64>, -} - #[tokio::main] async fn main() -> Result<(), anyhow::Error> { // Initialize OpenAI client @@ -51,61 +25,37 @@ async fn main() -> Result<(), anyhow::Error> { MongoClient::with_options(options).expect("MongoDB client options should be valid"); // Initialize MongoDB vector store - let collection: Collection<Document> = mongodb_client + let collection: Collection<DocumentEmbeddings> = mongodb_client .database("knowledgebase") .collection("context"); // Select the embedding model and generate our embeddings let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); - let fake_definitions = vec![ - FakeDefinition { - id: "doc0".to_string(), - definition: "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets".to_string(), - }, - FakeDefinition { - id: "doc1".to_string(), - definition: "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string(), - }, - FakeDefinition { - id: "doc2".to_string(), - definition: "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.".to_string(), - } - ]; - let embeddings = EmbeddingsBuilder::new(model.clone()) - .documents(fake_definitions)? + .simple_document("doc0", "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets") + .simple_document("doc1", "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.") + .simple_document("doc2", "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.") .build() .await?; - let mongo_documents = embeddings - .iter() - .map( - |(FakeDefinition { id, definition, .. }, embedding)| Document { - id: id.clone(), - definition: definition.clone(), - embedding: embedding.vec.clone(), - }, - ) - .collect::<Vec<_>>(); - - match collection.insert_many(mongo_documents, None).await { + match collection.insert_many(embeddings, None).await { Ok(_) => println!("Documents added successfully"), Err(e) => println!("Error adding documents: {:?}", e), - }; + } // Create a vector index on our vector store // IMPORTANT: Reuse the same model that was used to generate the embeddings - let index = MongoDbVectorStore::new(collection).index( - model, - "definitions_vector_index", - SearchParams::new("embedding"), - ); + let index = + MongoDbVectorStore::new(collection).index(model, "vector_index", SearchParams::default()); // Query the index let results = index - .top_n::<FakeDefinition>("What is a linglingdong?", 1) - .await?; + .top_n::<DocumentEmbeddings>("What is a linglingdong?", 1) + .await? + .into_iter() + .map(|(score, id, doc)| (score, id, doc.document)) + .collect::<Vec<_>>(); println!("Results: {:?}", results); diff --git a/rig-mongodb/src/lib.rs b/rig-mongodb/src/lib.rs index 4778e454..17dda463 100644 --- a/rig-mongodb/src/lib.rs +++ b/rig-mongodb/src/lib.rs @@ -2,23 +2,23 @@ use futures::StreamExt; use mongodb::bson::{self, doc}; use rig::{ - embeddings::embedding::{Embedding, EmbeddingModel}, + embeddings::{builder::DocumentEmbeddings, embedding::Embedding, embedding::EmbeddingModel}, vector_store::{VectorStoreError, VectorStoreIndex}, }; use serde::Deserialize; /// A MongoDB vector store. 
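For context on the collection type change below: once serialized into MongoDB, a `DocumentEmbeddings` row looks roughly like the hand-written sketch here (placeholder values only). `id` is renamed to `_id`, and each entry under `embeddings` keeps both the embedded text and its vector, which is what the `"embeddings.vec"` search path further down points at.

```rust
// Rough shape of one stored row; numbers are placeholders, not model output.
let _row = serde_json::json!({
    "_id": "doc0",
    "document": "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets",
    "embeddings": [
        {
            "document": "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets",
            "vec": [0.12, -0.08, 0.55]
        }
    ]
});
```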
-pub struct MongoDbVectorStore<C> { - collection: mongodb::Collection<C>, +pub struct MongoDbVectorStore { + collection: mongodb::Collection<DocumentEmbeddings>, } fn mongodb_to_rig_error(e: mongodb::error::Error) -> VectorStoreError { VectorStoreError::DatastoreError(Box::new(e)) } -impl<C> MongoDbVectorStore<C> { +impl MongoDbVectorStore { /// Create a new `MongoDbVectorStore` from a MongoDB collection. - pub fn new(collection: mongodb::Collection<C>) -> Self { + pub fn new(collection: mongodb::Collection<DocumentEmbeddings>) -> Self { Self { collection } } @@ -31,20 +31,20 @@ impl<C> MongoDbVectorStore<C> { model: M, index_name: &str, search_params: SearchParams, - ) -> MongoDbVectorIndex<M, C> { + ) -> MongoDbVectorIndex<M> { MongoDbVectorIndex::new(self.collection.clone(), model, index_name, search_params) } } /// A vector index for a MongoDB collection. -pub struct MongoDbVectorIndex<M: EmbeddingModel, C> { - collection: mongodb::Collection<C>, +pub struct MongoDbVectorIndex<M: EmbeddingModel> { + collection: mongodb::Collection<DocumentEmbeddings>, model: M, index_name: String, search_params: SearchParams, } -impl<M: EmbeddingModel, C> MongoDbVectorIndex<M, C> { +impl<M: EmbeddingModel> MongoDbVectorIndex<M> { /// Vector search stage of aggregation pipeline of mongoDB collection. /// To be used by implementations of top_n and top_n_ids methods on VectorStoreIndex trait for MongoDbVectorIndex. fn pipeline_search_stage(&self, prompt_embedding: &Embedding, n: usize) -> bson::Document { @@ -52,13 +52,12 @@ impl<M: EmbeddingModel, C> MongoDbVectorIndex<M, C> { filter, exact, num_candidates, - path, } = &self.search_params; doc! { "$vectorSearch": { "index": &self.index_name, - "path": path, + "path": "embeddings.vec", "queryVector": &prompt_embedding.vec, "numCandidates": num_candidates.unwrap_or((n * 10) as u32), "limit": n as u32, @@ -79,9 +78,9 @@ impl<M: EmbeddingModel, C> MongoDbVectorIndex<M, C> { } } -impl<M: EmbeddingModel, C> MongoDbVectorIndex<M, C> { +impl<M: EmbeddingModel> MongoDbVectorIndex<M> { pub fn new( - collection: mongodb::Collection<C>, + collection: mongodb::Collection<DocumentEmbeddings>, model: M, index_name: &str, search_params: SearchParams, @@ -99,19 +98,17 @@ impl<M: EmbeddingModel, C> MongoDbVectorIndex<M, C> { /// on each of the fields pub struct SearchParams { filter: mongodb::bson::Document, - path: String, exact: Option<bool>, num_candidates: Option<u32>, } impl SearchParams { /// Initializes a new `SearchParams` with default values. - pub fn new(path: &str) -> Self { + pub fn new() -> Self { Self { filter: doc! 
{}, exact: None, num_candidates: None, - path: path.to_string(), } } @@ -141,9 +138,13 @@ impl SearchParams { } } -impl<M: EmbeddingModel + std::marker::Sync + Send, C: std::marker::Sync + Send> VectorStoreIndex - for MongoDbVectorIndex<M, C> -{ +impl Default for SearchParams { + fn default() -> Self { + Self::new() + } +} + +impl<M: EmbeddingModel + std::marker::Sync + Send> VectorStoreIndex for MongoDbVectorIndex<M> { async fn top_n<T: for<'a> Deserialize<'a> + std::marker::Send>( &self, query: &str, From b5e1bf3a505d0ecde421ef9bef72bc499e800a35 Mon Sep 17 00:00:00 2001 From: Garance <garance.mary@gmail.com> Date: Wed, 16 Oct 2024 14:32:21 -0400 Subject: [PATCH 34/91] cleanup: revert certain changes during branch split --- rig-core/src/embeddings/builder.rs | 4 ++-- rig-core/src/embeddings/mod.rs | 2 +- rig-core/src/providers/cohere.rs | 2 +- rig-core/src/providers/openai.rs | 5 ++++- rig-core/src/tool.rs | 5 +---- rig-lancedb/examples/vector_search_local_ann.rs | 2 +- rig-lancedb/examples/vector_search_local_enn.rs | 2 +- rig-lancedb/examples/vector_search_s3_ann.rs | 2 +- 8 files changed, 12 insertions(+), 12 deletions(-) diff --git a/rig-core/src/embeddings/builder.rs b/rig-core/src/embeddings/builder.rs index 50e56537..afbe2056 100644 --- a/rig-core/src/embeddings/builder.rs +++ b/rig-core/src/embeddings/builder.rs @@ -31,7 +31,7 @@ use serde::{Deserialize, Serialize}; use crate::tool::{ToolEmbedding, ToolSet, ToolType}; -use super::embedding::{ Embedding, EmbeddingError, EmbeddingModel}; +use super::embedding::{Embedding, EmbeddingError, EmbeddingModel}; /// Struct that holds a document and its embeddings. /// @@ -232,4 +232,4 @@ impl<M: EmbeddingModel> EmbeddingsBuilder<M> { Ok(document_embeddings.into_values().collect()) } -} \ No newline at end of file +} diff --git a/rig-core/src/embeddings/mod.rs b/rig-core/src/embeddings/mod.rs index 061066c1..37e720cb 100644 --- a/rig-core/src/embeddings/mod.rs +++ b/rig-core/src/embeddings/mod.rs @@ -5,4 +5,4 @@ pub mod builder; pub mod embeddable; -pub mod embedding; \ No newline at end of file +pub mod embedding; diff --git a/rig-core/src/providers/cohere.rs b/rig-core/src/providers/cohere.rs index eba4ce3d..87f2334d 100644 --- a/rig-core/src/providers/cohere.rs +++ b/rig-core/src/providers/cohere.rs @@ -13,7 +13,7 @@ use std::collections::HashMap; use crate::{ agent::AgentBuilder, completion::{self, CompletionError}, - embeddings::{self, embedding::EmbeddingError, builder::EmbeddingsBuilder}, + embeddings::{self, builder::EmbeddingsBuilder, embedding::EmbeddingError}, extractor::ExtractorBuilder, json_utils, }; diff --git a/rig-core/src/providers/openai.rs b/rig-core/src/providers/openai.rs index eb373414..9adf6680 100644 --- a/rig-core/src/providers/openai.rs +++ b/rig-core/src/providers/openai.rs @@ -121,7 +121,10 @@ impl Client { /// .await /// .expect("Failed to embed documents"); /// ``` - pub fn embeddings(&self, model: &str) -> embeddings::builder::EmbeddingsBuilder<EmbeddingModel> { + pub fn embeddings( + &self, + model: &str, + ) -> embeddings::builder::EmbeddingsBuilder<EmbeddingModel> { embeddings::builder::EmbeddingsBuilder::new(self.embedding_model(model)) } diff --git a/rig-core/src/tool.rs b/rig-core/src/tool.rs index 4e5e7675..98394ecf 100644 --- a/rig-core/src/tool.rs +++ b/rig-core/src/tool.rs @@ -3,10 +3,7 @@ use std::{collections::HashMap, pin::Pin}; use futures::Future; use serde::{Deserialize, Serialize}; -use crate::{ - completion::{self, ToolDefinition}, - embeddings::{embeddable::EmbeddableError}, -}; +use 
crate::completion::{self, ToolDefinition}; #[derive(Debug, thiserror::Error)] pub enum ToolError { diff --git a/rig-lancedb/examples/vector_search_local_ann.rs b/rig-lancedb/examples/vector_search_local_ann.rs index a9744210..52ee213a 100644 --- a/rig-lancedb/examples/vector_search_local_ann.rs +++ b/rig-lancedb/examples/vector_search_local_ann.rs @@ -5,7 +5,7 @@ use fixture::{as_record_batch, schema}; use lancedb::index::vector::IvfPqIndexBuilder; use rig::vector_store::VectorStoreIndex; use rig::{ - embeddings::{embedding::EmbeddingModel, builder::EmbeddingsBuilder}, + embeddings::{builder::EmbeddingsBuilder, embedding::EmbeddingModel}, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, }; use rig_lancedb::{LanceDbVectorStore, SearchParams}; diff --git a/rig-lancedb/examples/vector_search_local_enn.rs b/rig-lancedb/examples/vector_search_local_enn.rs index 69a8599b..9f0ec934 100644 --- a/rig-lancedb/examples/vector_search_local_enn.rs +++ b/rig-lancedb/examples/vector_search_local_enn.rs @@ -3,7 +3,7 @@ use std::{env, sync::Arc}; use arrow_array::RecordBatchIterator; use fixture::{as_record_batch, schema}; use rig::{ - embeddings::{embedding::EmbeddingModel, builder::EmbeddingsBuilder}, + embeddings::{builder::EmbeddingsBuilder, embedding::EmbeddingModel}, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, vector_store::VectorStoreIndexDyn, }; diff --git a/rig-lancedb/examples/vector_search_s3_ann.rs b/rig-lancedb/examples/vector_search_s3_ann.rs index 9197fc42..7ce61309 100644 --- a/rig-lancedb/examples/vector_search_s3_ann.rs +++ b/rig-lancedb/examples/vector_search_s3_ann.rs @@ -4,7 +4,7 @@ use arrow_array::RecordBatchIterator; use fixture::{as_record_batch, schema}; use lancedb::{index::vector::IvfPqIndexBuilder, DistanceType}; use rig::{ - embeddings::{embedding::EmbeddingModel, builder::EmbeddingsBuilder}, + embeddings::{builder::EmbeddingsBuilder, embedding::EmbeddingModel}, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, vector_store::VectorStoreIndex, }; From 7caf134db193a9890c612f6bf61d08c6b6512cf5 Mon Sep 17 00:00:00 2001 From: Garance <garance.mary@gmail.com> Date: Wed, 16 Oct 2024 14:37:44 -0400 Subject: [PATCH 35/91] docs: revert doc string --- rig-core/src/embeddings/builder.rs | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/rig-core/src/embeddings/builder.rs b/rig-core/src/embeddings/builder.rs index afbe2056..2458b98d 100644 --- a/rig-core/src/embeddings/builder.rs +++ b/rig-core/src/embeddings/builder.rs @@ -1,5 +1,6 @@ -//! The module defines the [EmbeddingsBuilder] struct which accumulates objects to be embedded and generates the embeddings for each object when built. -//! Only types that implement the [Embeddable] trait can be added to the [EmbeddingsBuilder]. +//! The module provides an implementation of the [EmbeddingsBuilder] +//! struct, which allows users to build collections of document embeddings using different embedding +//! models and document sources. //! //! # Example //! 
```rust From 87396929a67dedd776977d45f38de5dac549559f Mon Sep 17 00:00:00 2001 From: Garance <garance.mary@gmail.com> Date: Wed, 16 Oct 2024 15:36:03 -0400 Subject: [PATCH 36/91] fix: add embedding_docs to embeddable tool --- rig-core/examples/calculator_chatbot.rs | 2 +- rig-core/examples/rag_dynamic_tools.rs | 2 +- rig-core/src/embeddings/builder.rs | 2 +- rig-core/src/embeddings/tool.rs | 6 ++++-- 4 files changed, 7 insertions(+), 5 deletions(-) diff --git a/rig-core/examples/calculator_chatbot.rs b/rig-core/examples/calculator_chatbot.rs index 0b994265..813c48eb 100644 --- a/rig-core/examples/calculator_chatbot.rs +++ b/rig-core/examples/calculator_chatbot.rs @@ -256,7 +256,7 @@ async fn main() -> Result<(), anyhow::Error> { embeddings .into_iter() .enumerate() - .map(|(i, (tool, embedding))| (i.to_string(), tool, vec![embedding])) + .map(|(i, (tool, embedding))| (i.to_string(), tool, embedding)) .collect(), )? .index(embedding_model); diff --git a/rig-core/examples/rag_dynamic_tools.rs b/rig-core/examples/rag_dynamic_tools.rs index 51c56ca8..cc04c209 100644 --- a/rig-core/examples/rag_dynamic_tools.rs +++ b/rig-core/examples/rag_dynamic_tools.rs @@ -165,7 +165,7 @@ async fn main() -> Result<(), anyhow::Error> { embeddings .into_iter() .enumerate() - .map(|(i, (tool, embedding))| (i.to_string(), tool, vec![embedding])) + .map(|(i, (tool, embedding))| (i.to_string(), tool, embedding)) .collect(), )? .index(embedding_model); diff --git a/rig-core/src/embeddings/builder.rs b/rig-core/src/embeddings/builder.rs index ea525e8b..ed094bed 100644 --- a/rig-core/src/embeddings/builder.rs +++ b/rig-core/src/embeddings/builder.rs @@ -130,7 +130,7 @@ impl<M: EmbeddingModel, D: Embeddable + Send + Sync + Clone> .flat_map(|(i, (_, embed_targets))| { stream::iter(embed_targets.iter().map(move |target| (i, target.clone()))) }) - // Chunk them into N (the emebdding API limit per request). + // Chunk them into N (the embedding API limit per request). .chunks(M::MAX_DOCUMENTS) // Generate the embeddings for a chunk at a time. .map(|docs| async { diff --git a/rig-core/src/embeddings/tool.rs b/rig-core/src/embeddings/tool.rs index 5e8ecd35..278efb20 100644 --- a/rig-core/src/embeddings/tool.rs +++ b/rig-core/src/embeddings/tool.rs @@ -9,8 +9,9 @@ use super::embeddable::EmbeddableError; #[derive(Embeddable, Clone, Serialize, Default, Eq, PartialEq)] pub struct EmbeddableTool { name: String, - #[embed] context: serde_json::Value, + #[embed] + embedding_docs: Vec<String> } impl EmbeddableTool { @@ -20,6 +21,7 @@ impl EmbeddableTool { name: tool.name(), context: serde_json::to_value(tool.context().map_err(EmbeddableError::SerdeError)?) 
.map_err(EmbeddableError::SerdeError)?, + embedding_docs: tool.embedding_docs(), }) } -} +} \ No newline at end of file From fb979eccbe751fa82b94fa5b486f0677010bc1e3 Mon Sep 17 00:00:00 2001 From: Garance <garance.mary@gmail.com> Date: Wed, 16 Oct 2024 17:59:08 -0400 Subject: [PATCH 37/91] refactor: use OneOrMany in Embbedable trait, make derive macro crate feature flag --- rig-core/Cargo.toml | 11 +- rig-core/rig-core-derive/src/embeddable.rs | 65 +---- rig-core/src/embeddings/embeddable.rs | 281 +++++++-------------- rig-core/src/lib.rs | 2 + rig-core/tests/embeddable_macro.rs | 141 +++++++++++ 5 files changed, 261 insertions(+), 239 deletions(-) create mode 100644 rig-core/tests/embeddable_macro.rs diff --git a/rig-core/Cargo.toml b/rig-core/Cargo.toml index 658faf4b..d2ab06c6 100644 --- a/rig-core/Cargo.toml +++ b/rig-core/Cargo.toml @@ -23,9 +23,16 @@ futures = "0.3.29" ordered-float = "4.2.0" schemars = "0.8.16" thiserror = "1.0.61" -rig-derive = { path = "./rig-core-derive" } +rig-derive = { path = "./rig-core-derive", optional = true } [dev-dependencies] anyhow = "1.0.75" tokio = { version = "1.34.0", features = ["full"] } -tracing-subscriber = "0.3.18" \ No newline at end of file +tracing-subscriber = "0.3.18" + +[features] +rig_derive = ["dep:rig-derive"] + +[[test]] +name = "embeddable_macro" +required-features = ["rig_derive"] \ No newline at end of file diff --git a/rig-core/rig-core-derive/src/embeddable.rs b/rig-core/rig-core-derive/src/embeddable.rs index 7c3c883e..e36a7eea 100644 --- a/rig-core/rig-core-derive/src/embeddable.rs +++ b/rig-core/rig-core-derive/src/embeddable.rs @@ -1,21 +1,18 @@ use proc_macro2::TokenStream; use quote::quote; -use syn::{parse_str, DataStruct}; +use syn::DataStruct; use crate::{ basic::{add_struct_bounds, basic_embed_fields}, custom::custom_embed_fields, }; -const VEC_TYPE: &str = "Vec"; -const MANY_EMBEDDING: &str = "ManyEmbedding"; -const SINGLE_EMBEDDING: &str = "SingleEmbedding"; pub(crate) fn expand_derive_embedding(input: &mut syn::DeriveInput) -> syn::Result<TokenStream> { let name = &input.ident; let data = &input.data; let generics = &mut input.generics; - let (target_stream, embed_kind) = match data { + let target_stream = match data { syn::Data::Struct(data_struct) => { let (basic_targets, basic_target_size) = data_struct.basic(generics); let (custom_targets, custom_target_size) = data_struct.custom()?; @@ -30,13 +27,10 @@ pub(crate) fn expand_derive_embedding(input: &mut syn::DeriveInput) -> syn::Resu } // Determine whether the Embeddable::Kind should be SingleEmbedding or ManyEmbedding - ( - quote! { - let mut embed_targets = #basic_targets; - embed_targets.extend(#custom_targets) - }, - embed_kind(data_struct)?, - ) + quote! { + let mut embed_targets = #basic_targets; + embed_targets.extend(#custom_targets) + } } _ => { return Err(syn::Error::new_spanned( @@ -52,19 +46,16 @@ pub(crate) fn expand_derive_embedding(input: &mut syn::DeriveInput) -> syn::Resu // Note: Embeddable trait is imported with the macro. impl #impl_generics Embeddable for #name #ty_generics #where_clause { - type Kind = rig::embeddings::embeddable::#embed_kind; type Error = rig::embeddings::embeddable::EmbeddableError; - fn embeddable(&self) -> Result<Vec<String>, Self::Error> { + fn embeddable(&self) -> Result<rig::embeddings::embeddable::OneOrMany<String>, Self::Error> { #target_stream; - let targets = embed_targets.into_iter() - .collect::<Result<Vec<_>, _>>()? 
- .into_iter() - .flatten() - .collect::<Vec<_>>(); - - Ok(targets) + rig::embeddings::embeddable::OneOrMany::try_from( + embed_targets.into_iter() + .collect::<Result<Vec<_>, _>>() + .map_err(|e| rig::embeddings::embeddable::EmbeddableError::Error(e.to_string()))? + ) } } }; @@ -72,38 +63,6 @@ pub(crate) fn expand_derive_embedding(input: &mut syn::DeriveInput) -> syn::Resu Ok(gen) } -/// If the total number of fields tagged with #[embed] or #[embed(embed_with = "...")] is 1, -/// returns the kind of embedding that field should be. -/// If the total number of fields tagged with #[embed] or #[embed(embed_with = "...")] is greater than 1, -/// return ManyEmbedding. -fn embed_kind(data_struct: &DataStruct) -> syn::Result<syn::Expr> { - fn embed_kind(field: &syn::Field) -> syn::Result<syn::Expr> { - match &field.ty { - syn::Type::Path(path) => { - if path.path.segments.first().unwrap().ident == VEC_TYPE { - parse_str(MANY_EMBEDDING) - } else { - parse_str(SINGLE_EMBEDDING) - } - } - _ => parse_str(SINGLE_EMBEDDING), - } - } - let fields = basic_embed_fields(data_struct) - .chain( - custom_embed_fields(data_struct)? - .into_iter() - .map(|(f, _)| f), - ) - .collect::<Vec<_>>(); - - if fields.len() == 1 { - fields.iter().map(embed_kind).next().unwrap() - } else { - parse_str(MANY_EMBEDDING) - } -} - trait StructParser { // Handles fields tagged with #[embed] fn basic(&self, generics: &mut syn::Generics) -> (TokenStream, usize); diff --git a/rig-core/src/embeddings/embeddable.rs b/rig-core/src/embeddings/embeddable.rs index c7b7bf0c..c4cf90da 100644 --- a/rig-core/src/embeddings/embeddable.rs +++ b/rig-core/src/embeddings/embeddable.rs @@ -35,6 +35,8 @@ impl EmbeddingKind for ManyEmbedding {} pub enum EmbeddableError { #[error("SerdeError: {0}")] SerdeError(#[from] serde_json::Error), + #[error("Error: {0}")] + Error(String), } /// Trait for types that can be embedded. @@ -42,266 +44,177 @@ pub enum EmbeddableError { /// If the type `Kind` is `SingleEmbedding`, the list of strings contains a single item, otherwise, the list can contain many items. /// If there is an error generating the list of strings, the method should return an error that implements `std::error::Error`. 
pub trait Embeddable { - type Kind: EmbeddingKind; type Error: std::error::Error; - fn embeddable(&self) -> Result<Vec<String>, Self::Error>; + fn embeddable(&self) -> Result<OneOrMany<String>, Self::Error>; +} + +#[derive(PartialEq, Eq, Debug)] +pub struct OneOrMany<T> { + first: T, + rest: Vec<T>, +} + +impl<T: Clone> OneOrMany<T> { + pub fn first(&self) -> T { + self.first.clone() + } + + pub fn rest(&self) -> Vec<T> { + self.rest.clone() + } + + pub fn all(&self) -> Vec<T> { + let mut all = vec![self.first.clone()]; + all.extend(self.rest.clone().into_iter()); + all + } +} + +impl<T> From<T> for OneOrMany<T> { + fn from(item: T) -> Self { + OneOrMany { + first: item, + rest: vec![], + } + } +} + +impl<T> TryFrom<Vec<T>> for OneOrMany<T> { + type Error = EmbeddableError; + + fn try_from(items: Vec<T>) -> Result<Self, Self::Error> { + let mut iter = items.into_iter(); + Ok(OneOrMany { + first: match iter.next() { + Some(item) => item, + None => { + return Err(EmbeddableError::Error(format!( + "Cannot convert empty Vec to OneOrMany" + ))) + } + }, + rest: iter.collect(), + }) + } +} + +impl<T: Clone> TryFrom<Vec<OneOrMany<T>>> for OneOrMany<T> { + type Error = EmbeddableError; + + fn try_from(value: Vec<OneOrMany<T>>) -> Result<Self, Self::Error> { + let items = value + .into_iter() + .flat_map(|one_or_many| one_or_many.all()) + .collect::<Vec<_>>(); + + OneOrMany::try_from(items) + } } ////////////////////////////////////////////////////// /// Implementations of Embeddable for common types /// ////////////////////////////////////////////////////// impl Embeddable for String { - type Kind = SingleEmbedding; type Error = EmbeddableError; - fn embeddable(&self) -> Result<Vec<String>, Self::Error> { - Ok(vec![self.clone()]) + fn embeddable(&self) -> Result<OneOrMany<String>, Self::Error> { + Ok(OneOrMany::from(self.clone())) } } impl Embeddable for i8 { - type Kind = SingleEmbedding; type Error = EmbeddableError; - fn embeddable(&self) -> Result<Vec<String>, Self::Error> { - Ok(vec![self.to_string()]) + fn embeddable(&self) -> Result<OneOrMany<String>, Self::Error> { + Ok(OneOrMany::from(self.to_string())) } } impl Embeddable for i16 { - type Kind = SingleEmbedding; type Error = EmbeddableError; - fn embeddable(&self) -> Result<Vec<String>, Self::Error> { - Ok(vec![self.to_string()]) + fn embeddable(&self) -> Result<OneOrMany<String>, Self::Error> { + Ok(OneOrMany::from(self.to_string())) } } impl Embeddable for i32 { - type Kind = SingleEmbedding; type Error = EmbeddableError; - fn embeddable(&self) -> Result<Vec<String>, Self::Error> { - Ok(vec![self.to_string()]) + fn embeddable(&self) -> Result<OneOrMany<String>, Self::Error> { + Ok(OneOrMany::from(self.to_string())) } } impl Embeddable for i64 { - type Kind = SingleEmbedding; type Error = EmbeddableError; - fn embeddable(&self) -> Result<Vec<String>, Self::Error> { - Ok(vec![self.to_string()]) + fn embeddable(&self) -> Result<OneOrMany<String>, Self::Error> { + Ok(OneOrMany::from(self.to_string())) } } impl Embeddable for i128 { - type Kind = SingleEmbedding; type Error = EmbeddableError; - fn embeddable(&self) -> Result<Vec<String>, Self::Error> { - Ok(vec![self.to_string()]) + fn embeddable(&self) -> Result<OneOrMany<String>, Self::Error> { + Ok(OneOrMany::from(self.to_string())) } } impl Embeddable for f32 { - type Kind = SingleEmbedding; type Error = EmbeddableError; - fn embeddable(&self) -> Result<Vec<String>, Self::Error> { - Ok(vec![self.to_string()]) + fn embeddable(&self) -> Result<OneOrMany<String>, Self::Error> { + 
Ok(OneOrMany::from(self.to_string())) } } impl Embeddable for f64 { - type Kind = SingleEmbedding; type Error = EmbeddableError; - fn embeddable(&self) -> Result<Vec<String>, Self::Error> { - Ok(vec![self.to_string()]) + fn embeddable(&self) -> Result<OneOrMany<String>, Self::Error> { + Ok(OneOrMany::from(self.to_string())) } } impl Embeddable for bool { - type Kind = SingleEmbedding; type Error = EmbeddableError; - fn embeddable(&self) -> Result<Vec<String>, Self::Error> { - Ok(vec![self.to_string()]) + fn embeddable(&self) -> Result<OneOrMany<String>, Self::Error> { + Ok(OneOrMany::from(self.to_string())) } } impl Embeddable for char { - type Kind = SingleEmbedding; type Error = EmbeddableError; - fn embeddable(&self) -> Result<Vec<String>, Self::Error> { - Ok(vec![self.to_string()]) + fn embeddable(&self) -> Result<OneOrMany<String>, Self::Error> { + Ok(OneOrMany::from(self.to_string())) } } impl Embeddable for serde_json::Value { - type Kind = SingleEmbedding; type Error = EmbeddableError; - fn embeddable(&self) -> Result<Vec<String>, Self::Error> { - Ok(vec![ - serde_json::to_string(self).map_err(EmbeddableError::SerdeError)? - ]) + fn embeddable(&self) -> Result<OneOrMany<String>, Self::Error> { + Ok(OneOrMany::from( + serde_json::to_string(self).map_err(EmbeddableError::SerdeError)?, + )) } } impl<T: Embeddable> Embeddable for Vec<T> { - type Kind = ManyEmbedding; - type Error = T::Error; + type Error = EmbeddableError; - fn embeddable(&self) -> Result<Vec<String>, Self::Error> { - Ok(self + fn embeddable(&self) -> Result<OneOrMany<String>, Self::Error> { + let items = self .iter() - .map(|i| i.embeddable()) - .collect::<Result<Vec<_>, _>>()? - .into_iter() - .flatten() - .collect()) - } -} - -#[cfg(test)] -mod tests { - use crate as rig; - use rig::embeddings::embeddable::{Embeddable, EmbeddableError}; - use rig_derive::Embeddable; - use serde::Serialize; - - fn serialize(definition: Definition) -> Result<Vec<String>, EmbeddableError> { - Ok(vec![ - serde_json::to_string(&definition).map_err(EmbeddableError::SerdeError)? 
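A short usage sketch of the `OneOrMany` API defined above (paths match the new test file later in this patch): `first()` always yields a value, `rest()` may be empty, and converting from an empty `Vec` fails rather than producing an empty container.

```rust
use rig::embeddings::embeddable::OneOrMany;

let single = OneOrMany::from("alpha".to_string());
assert_eq!(single.first(), "alpha".to_string());
assert!(single.rest().is_empty());

let many = OneOrMany::try_from(vec!["a".to_string(), "b".to_string()]).unwrap();
assert_eq!(many.all(), vec!["a".to_string(), "b".to_string()]);

// Empty input is an error, not an empty OneOrMany.
assert!(OneOrMany::<String>::try_from(Vec::<String>::new()).is_err());
```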
- ]) - } - - #[derive(Embeddable)] - struct FakeDefinition { - id: String, - word: String, - #[embed(embed_with = "serialize")] - definition: Definition, - } - - #[derive(Serialize, Clone)] - struct Definition { - word: String, - link: String, - speech: String, - } - - #[test] - fn test_custom_embed() { - let fake_definition = FakeDefinition { - id: "doc1".to_string(), - word: "house".to_string(), - definition: Definition { - speech: "noun".to_string(), - word: "a building in which people live; residence for human beings.".to_string(), - link: "https://www.dictionary.com/browse/house".to_string(), - }, - }; - - println!( - "FakeDefinition: {}, {}", - fake_definition.id, fake_definition.word - ); - - assert_eq!( - fake_definition.embeddable().unwrap(), - vec!["{\"word\":\"a building in which people live; residence for human beings.\",\"link\":\"https://www.dictionary.com/browse/house\",\"speech\":\"noun\"}".to_string()] - ) - } - - #[derive(Embeddable)] - struct FakeDefinition2 { - id: String, - word: String, - #[embed] - definition: String, - } - - #[test] - fn test_single_embed() { - let fake_definition = FakeDefinition2 { - id: "doc1".to_string(), - word: "house".to_string(), - definition: "a building in which people live; residence for human beings.".to_string(), - }; - - println!( - "FakeDefinition2: {}, {}", - fake_definition.id, fake_definition.word - ); - - assert_eq!( - fake_definition.embeddable().unwrap(), - vec!["a building in which people live; residence for human beings.".to_string()] - ) - } - - #[derive(Embeddable)] - struct Company { - id: String, - company: String, - #[embed] - employee_ages: Vec<i32>, - } - - #[test] - fn test_multiple_embed() { - let company = Company { - id: "doc1".to_string(), - company: "Google".to_string(), - employee_ages: vec![25, 30, 35, 40], - }; - - println!("Company: {}, {}", company.id, company.company); - - assert_eq!( - company.embeddable().unwrap(), - vec![ - "25".to_string(), - "30".to_string(), - "35".to_string(), - "40".to_string() - ] - ); - } - - #[derive(Embeddable)] - struct Company2 { - id: String, - #[embed] - company: String, - #[embed] - employee_ages: Vec<i32>, - } + .map(|item| item.embeddable()) + .collect::<Result<Vec<_>, _>>() + .map_err(|e| EmbeddableError::Error(e.to_string()))?; - #[test] - fn test_many_embed() { - let company = Company2 { - id: "doc1".to_string(), - company: "Google".to_string(), - employee_ages: vec![25, 30, 35, 40], - }; - - println!("Company2: {}", company.id); - - assert_eq!( - company.embeddable().unwrap(), - vec![ - "Google".to_string(), - "25".to_string(), - "30".to_string(), - "35".to_string(), - "40".to_string() - ] - ); + OneOrMany::try_from(items) } } diff --git a/rig-core/src/lib.rs b/rig-core/src/lib.rs index b7f0615e..cc17efa3 100644 --- a/rig-core/src/lib.rs +++ b/rig-core/src/lib.rs @@ -78,4 +78,6 @@ pub mod vector_store; // Export Embeddable trait and Embeddable together. 
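With the derive macro now optional behind the `rig_derive` feature introduced in this patch, a hand-rolled impl of the reworked trait is straightforward; a minimal sketch with a made-up type:

```rust
use rig::embeddings::embeddable::{Embeddable, EmbeddableError, OneOrMany};

// Illustrative type, not part of the patch.
struct WordDefinition {
    definitions: Vec<String>,
}

impl Embeddable for WordDefinition {
    type Error = EmbeddableError;

    fn embeddable(&self) -> Result<OneOrMany<String>, Self::Error> {
        // Embed every definition; fails on an empty list, matching
        // OneOrMany's "at least one" guarantee.
        OneOrMany::try_from(self.definitions.clone())
    }
}
```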
pub use embeddings::embeddable::Embeddable; + +#[cfg(feature = "rig_derive")] pub use rig_derive::Embeddable; diff --git a/rig-core/tests/embeddable_macro.rs b/rig-core/tests/embeddable_macro.rs new file mode 100644 index 00000000..130da109 --- /dev/null +++ b/rig-core/tests/embeddable_macro.rs @@ -0,0 +1,141 @@ +use rig::embeddings::embeddable::{EmbeddableError, OneOrMany}; +use rig::Embeddable; +use serde::Serialize; + +fn serialize(definition: Definition) -> Result<OneOrMany<String>, EmbeddableError> { + Ok(OneOrMany::from( + serde_json::to_string(&definition).map_err(EmbeddableError::SerdeError)?, + )) +} + +#[derive(Embeddable)] +struct FakeDefinition { + id: String, + word: String, + #[embed(embed_with = "serialize")] + definition: Definition, +} + +#[derive(Serialize, Clone)] +struct Definition { + word: String, + link: String, + speech: String, +} + +#[test] +fn test_custom_embed() { + let fake_definition = FakeDefinition { + id: "doc1".to_string(), + word: "house".to_string(), + definition: Definition { + speech: "noun".to_string(), + word: "a building in which people live; residence for human beings.".to_string(), + link: "https://www.dictionary.com/browse/house".to_string(), + }, + }; + + println!( + "FakeDefinition: {}, {}", + fake_definition.id, fake_definition.word + ); + + assert_eq!( + fake_definition.embeddable().unwrap(), + OneOrMany::from( + "{\"word\":\"a building in which people live; residence for human beings.\",\"link\":\"https://www.dictionary.com/browse/house\",\"speech\":\"noun\"}".to_string() + ) + + ) +} + +#[derive(Embeddable)] +struct FakeDefinition2 { + id: String, + word: String, + #[embed] + definition: String, +} + +#[test] +fn test_single_embed() { + let definition = "a building in which people live; residence for human beings.".to_string(); + + let fake_definition = FakeDefinition2 { + id: "doc1".to_string(), + word: "house".to_string(), + definition: definition.clone(), + }; + + println!( + "FakeDefinition2: {}, {}", + fake_definition.id, fake_definition.word + ); + + assert_eq!( + fake_definition.embeddable().unwrap(), + OneOrMany::from(definition) + ) +} + +#[derive(Embeddable)] +struct Company { + id: String, + company: String, + #[embed] + employee_ages: Vec<i32>, +} + +#[test] +fn test_multiple_embed() { + let company = Company { + id: "doc1".to_string(), + company: "Google".to_string(), + employee_ages: vec![25, 30, 35, 40], + }; + + println!("Company: {}, {}", company.id, company.company); + + assert_eq!( + company.embeddable().unwrap(), + OneOrMany::try_from(vec![ + "25".to_string(), + "30".to_string(), + "35".to_string(), + "40".to_string() + ]) + .unwrap() + ); +} + +#[derive(Embeddable)] +struct Company2 { + id: String, + #[embed] + company: String, + #[embed] + employee_ages: Vec<i32>, +} + +#[test] +fn test_many_embed() { + let company = Company2 { + id: "doc1".to_string(), + company: "Google".to_string(), + employee_ages: vec![25, 30, 35, 40], + }; + + println!("Company2: {}", company.id); + + assert_eq!( + company.embeddable().unwrap(), + OneOrMany::try_from(vec![ + "Google".to_string(), + "25".to_string(), + "30".to_string(), + "35".to_string(), + "40".to_string() + ]) + .unwrap() + ); +} From 690027c7e524653f50273ec6154c0eb5929df352 Mon Sep 17 00:00:00 2001 From: Garance <garance.mary@gmail.com> Date: Wed, 16 Oct 2024 18:10:50 -0400 Subject: [PATCH 38/91] tests: add some more tests --- rig-core/tests/embeddable_macro.rs | 56 +++++++++++++++++++++++++++--- 1 file changed, 51 insertions(+), 5 deletions(-) diff --git 
a/rig-core/tests/embeddable_macro.rs b/rig-core/tests/embeddable_macro.rs index 130da109..bc076fed 100644 --- a/rig-core/tests/embeddable_macro.rs +++ b/rig-core/tests/embeddable_macro.rs @@ -51,6 +51,43 @@ fn test_custom_embed() { #[derive(Embeddable)] struct FakeDefinition2 { + id: String, + #[embed] + word: String, + #[embed(embed_with = "serialize")] + definition: Definition, +} + +#[test] +fn test_custom_and_basic_embed() { + let fake_definition = FakeDefinition2 { + id: "doc1".to_string(), + word: "house".to_string(), + definition: Definition { + speech: "noun".to_string(), + word: "a building in which people live; residence for human beings.".to_string(), + link: "https://www.dictionary.com/browse/house".to_string(), + }, + }; + + println!( + "FakeDefinition: {}, {}", + fake_definition.id, fake_definition.word + ); + + assert_eq!( + fake_definition.embeddable().unwrap().first(), + "house".to_string() + ); + + assert_eq!( + fake_definition.embeddable().unwrap().rest(), + vec!["{\"word\":\"a building in which people live; residence for human beings.\",\"link\":\"https://www.dictionary.com/browse/house\",\"speech\":\"noun\"}".to_string()] + ) +} + +#[derive(Embeddable)] +struct FakeDefinition3 { id: String, word: String, #[embed] @@ -61,14 +98,14 @@ struct FakeDefinition2 { fn test_single_embed() { let definition = "a building in which people live; residence for human beings.".to_string(); - let fake_definition = FakeDefinition2 { + let fake_definition = FakeDefinition3 { id: "doc1".to_string(), word: "house".to_string(), definition: definition.clone(), }; println!( - "FakeDefinition2: {}, {}", + "FakeDefinition3: {}, {}", fake_definition.id, fake_definition.word ); @@ -87,7 +124,7 @@ struct Company { } #[test] -fn test_multiple_embed() { +fn test_multiple_embed_strings() { let company = Company { id: "doc1".to_string(), company: "Google".to_string(), @@ -96,8 +133,10 @@ fn test_multiple_embed() { println!("Company: {}, {}", company.id, company.company); + let result = company.embeddable().unwrap(); + assert_eq!( - company.embeddable().unwrap(), + result, OneOrMany::try_from(vec![ "25".to_string(), "30".to_string(), @@ -106,6 +145,13 @@ fn test_multiple_embed() { ]) .unwrap() ); + + assert_eq!(result.first(), "25".to_string()); + + assert_eq!( + result.rest(), + vec!["30".to_string(), "35".to_string(), "40".to_string()] + ) } #[derive(Embeddable)] @@ -118,7 +164,7 @@ struct Company2 { } #[test] -fn test_many_embed() { +fn test_multiple_embed_tags() { let company = Company2 { id: "doc1".to_string(), company: "Google".to_string(), From cca60593b903e68c8843e42397dd233b51f9e2c7 Mon Sep 17 00:00:00 2001 From: Garance <garance.mary@gmail.com> Date: Thu, 17 Oct 2024 09:09:17 -0400 Subject: [PATCH 39/91] clippy: cargo clippy --- rig-core/src/embeddings/embeddable.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/rig-core/src/embeddings/embeddable.rs b/rig-core/src/embeddings/embeddable.rs index c4cf90da..6084e07e 100644 --- a/rig-core/src/embeddings/embeddable.rs +++ b/rig-core/src/embeddings/embeddable.rs @@ -66,7 +66,7 @@ impl<T: Clone> OneOrMany<T> { pub fn all(&self) -> Vec<T> { let mut all = vec![self.first.clone()]; - all.extend(self.rest.clone().into_iter()); + all.extend(self.rest.clone()); all } } @@ -89,9 +89,9 @@ impl<T> TryFrom<Vec<T>> for OneOrMany<T> { first: match iter.next() { Some(item) => item, None => { - return Err(EmbeddableError::Error(format!( - "Cannot convert empty Vec to OneOrMany" - ))) + return Err(EmbeddableError::Error( + "Cannot convert 
empty Vec to OneOrMany".to_string(), + )) } }, rest: iter.collect(), From f785b8c3aa64c2fb240899a39c3ae374e72aa394 Mon Sep 17 00:00:00 2001 From: Garance <garance.mary@gmail.com> Date: Thu, 17 Oct 2024 09:56:50 -0400 Subject: [PATCH 40/91] docs: add docstring to oneormany --- rig-core/src/embeddings/embeddable.rs | 80 +++++++++++++-------------- 1 file changed, 40 insertions(+), 40 deletions(-) diff --git a/rig-core/src/embeddings/embeddable.rs b/rig-core/src/embeddings/embeddable.rs index 6084e07e..9e751b00 100644 --- a/rig-core/src/embeddings/embeddable.rs +++ b/rig-core/src/embeddings/embeddable.rs @@ -3,45 +3,41 @@ //! ```rust //! use std::env; //! -//! use rig::Embeddable; //! use serde::{Deserialize, Serialize}; //! -//! #[derive(Embeddable)] //! struct FakeDefinition { //! id: String, //! word: String, -//! #[embed] //! definitions: Vec<String>, //! } //! -//! // Do something with FakeDefinition -//! // ... +//! let fake_definition = FakeDefinition { +//! id: "doc1".to_string(), +//! word: "hello".to_string(), +//! definition: "used as a greeting or to begin a conversation".to_string() +//! }; +//! +//! impl Embeddable for FakeDefinition { +//! type Error = anyhow::Error; +//! +//! fn embeddable(&self) -> Result<OneOrMany<String>, Self::Error> { +//! // Embeddigns only need to be generated for `definition` field. +//! // Select it from te struct and return it as a single item. +//! Ok(OneOrMany::from(self.definition.clone())) +//! } +//! } //! ``` -/// The associated type `Kind` on the trait `Embeddable` must implement this trait. -pub trait EmbeddingKind {} - -/// Used for structs that contain a single embedding target. -pub struct SingleEmbedding; -impl EmbeddingKind for SingleEmbedding {} - -/// Used for structs that contain many embedding targets. -pub struct ManyEmbedding; -impl EmbeddingKind for ManyEmbedding {} - /// Error type used for when the `embeddable` method fails. /// Used by default implementations of `Embeddable` for common types. #[derive(Debug, thiserror::Error)] pub enum EmbeddableError { #[error("SerdeError: {0}")] SerdeError(#[from] serde_json::Error), - #[error("Error: {0}")] - Error(String), } /// Trait for types that can be embedded. -/// The `embeddable` method returns a list of strings for which embeddings will be generated by the embeddings builder. -/// If the type `Kind` is `SingleEmbedding`, the list of strings contains a single item, otherwise, the list can contain many items. +/// The `embeddable` method returns a OneOrMany<String> which contains strings for which embeddings will be generated by the embeddings builder. /// If there is an error generating the list of strings, the method should return an error that implements `std::error::Error`. pub trait Embeddable { type Error: std::error::Error; @@ -49,21 +45,31 @@ pub trait Embeddable { fn embeddable(&self) -> Result<OneOrMany<String>, Self::Error>; } +/// Struct containing either a single item or a list of items of type T. +/// If a single item is present, `first` will contain it and `rest` will be empty. +/// If multiple items are present, `first` will contain the first item and `rest` will contain the rest. +/// IMPORTANT: this struct cannot be created with an empty vector. +/// OneOrMany objects can only be create using OneOrMany::from() or OneOrMany::try_from(). #[derive(PartialEq, Eq, Debug)] pub struct OneOrMany<T> { + /// First item in the list. first: T, + /// Rest of the items in the list. rest: Vec<T>, } impl<T: Clone> OneOrMany<T> { + /// Get the first item in the list. 
pub fn first(&self) -> T { self.first.clone() } + /// Get the rest of the items in the list (excluding the first one). pub fn rest(&self) -> Vec<T> { self.rest.clone() } + /// Get all items in the list (joins the first with the rest). pub fn all(&self) -> Vec<T> { let mut all = vec![self.first.clone()]; all.extend(self.rest.clone()); @@ -71,6 +77,7 @@ impl<T: Clone> OneOrMany<T> { } } +/// Create a OneOrMany object with a single item. impl<T> From<T> for OneOrMany<T> { fn from(item: T) -> Self { OneOrMany { @@ -80,35 +87,29 @@ impl<T> From<T> for OneOrMany<T> { } } -impl<T> TryFrom<Vec<T>> for OneOrMany<T> { - type Error = EmbeddableError; - - fn try_from(items: Vec<T>) -> Result<Self, Self::Error> { +/// Create a OneOrMany object with a list of items. +impl<T> From<Vec<T>> for OneOrMany<T> { + fn from(items: Vec<T>) -> Self { let mut iter = items.into_iter(); - Ok(OneOrMany { + OneOrMany { first: match iter.next() { Some(item) => item, - None => { - return Err(EmbeddableError::Error( - "Cannot convert empty Vec to OneOrMany".to_string(), - )) - } + None => panic!("Cannot create OneOrMany with an empty vector."), }, rest: iter.collect(), - }) + } } } -impl<T: Clone> TryFrom<Vec<OneOrMany<T>>> for OneOrMany<T> { - type Error = EmbeddableError; - - fn try_from(value: Vec<OneOrMany<T>>) -> Result<Self, Self::Error> { +/// Merge a list of OneOrMany items into a single OneOrMany item. +impl<T: Clone> From<Vec<OneOrMany<T>>> for OneOrMany<T> { + fn from(value: Vec<OneOrMany<T>>) -> Self { let items = value .into_iter() .flat_map(|one_or_many| one_or_many.all()) .collect::<Vec<_>>(); - OneOrMany::try_from(items) + OneOrMany::from(items) } } @@ -206,15 +207,14 @@ impl Embeddable for serde_json::Value { } impl<T: Embeddable> Embeddable for Vec<T> { - type Error = EmbeddableError; + type Error = T::Error; fn embeddable(&self) -> Result<OneOrMany<String>, Self::Error> { let items = self .iter() .map(|item| item.embeddable()) - .collect::<Result<Vec<_>, _>>() - .map_err(|e| EmbeddableError::Error(e.to_string()))?; + .collect::<Result<Vec<_>, _>>()?; - OneOrMany::try_from(items) + Ok(OneOrMany::from(items)) } } From 0e2ade9b672baa174efc95cb509fe1c1288e74d3 Mon Sep 17 00:00:00 2001 From: Garance <garance.mary@gmail.com> Date: Thu, 17 Oct 2024 10:06:02 -0400 Subject: [PATCH 41/91] fix(macro): update error handling --- rig-core/rig-core-derive/src/embeddable.rs | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/rig-core/rig-core-derive/src/embeddable.rs b/rig-core/rig-core-derive/src/embeddable.rs index e36a7eea..563b8c3c 100644 --- a/rig-core/rig-core-derive/src/embeddable.rs +++ b/rig-core/rig-core-derive/src/embeddable.rs @@ -51,11 +51,10 @@ pub(crate) fn expand_derive_embedding(input: &mut syn::DeriveInput) -> syn::Resu fn embeddable(&self) -> Result<rig::embeddings::embeddable::OneOrMany<String>, Self::Error> { #target_stream; - rig::embeddings::embeddable::OneOrMany::try_from( + Ok(rig::embeddings::embeddable::OneOrMany::from( embed_targets.into_iter() - .collect::<Result<Vec<_>, _>>() - .map_err(|e| rig::embeddings::embeddable::EmbeddableError::Error(e.to_string()))? - ) + .collect::<Result<Vec<_>, _>>()? 
+ )) } } }; From a98769c2228fa7101f7e4ce8296fa7bec132b01a Mon Sep 17 00:00:00 2001 From: Garance <garance.mary@gmail.com> Date: Thu, 17 Oct 2024 10:40:27 -0400 Subject: [PATCH 42/91] refactor: reexport EmbeddingsBuilder in rig and update imports --- rig-core/examples/calculator_chatbot.rs | 3 ++- rig-core/examples/rag.rs | 3 ++- rig-core/examples/rag_dynamic_tools.rs | 3 ++- rig-core/examples/vector_search.rs | 3 ++- rig-core/examples/vector_search_cohere.rs | 3 ++- rig-core/src/embeddings/embeddable.rs | 2 +- rig-core/src/lib.rs | 3 ++- rig-core/src/providers/cohere.rs | 4 ++-- rig-core/src/providers/openai.rs | 9 +++------ rig-lancedb/examples/vector_search_local_ann.rs | 3 ++- rig-lancedb/examples/vector_search_local_enn.rs | 3 ++- rig-lancedb/examples/vector_search_s3_ann.rs | 3 ++- rig-mongodb/examples/vector_search_mongodb.rs | 3 ++- 13 files changed, 26 insertions(+), 19 deletions(-) diff --git a/rig-core/examples/calculator_chatbot.rs b/rig-core/examples/calculator_chatbot.rs index 949073b3..02a029e9 100644 --- a/rig-core/examples/calculator_chatbot.rs +++ b/rig-core/examples/calculator_chatbot.rs @@ -2,10 +2,11 @@ use anyhow::Result; use rig::{ cli_chatbot::cli_chatbot, completion::ToolDefinition, - embeddings::{builder::DocumentEmbeddings, builder::EmbeddingsBuilder}, + embeddings::builder::DocumentEmbeddings, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, tool::{Tool, ToolEmbedding, ToolSet}, vector_store::in_memory_store::InMemoryVectorStore, + EmbeddingsBuilder, }; use serde::{Deserialize, Serialize}; use serde_json::json; diff --git a/rig-core/examples/rag.rs b/rig-core/examples/rag.rs index 936a3d05..ec9bebd8 100644 --- a/rig-core/examples/rag.rs +++ b/rig-core/examples/rag.rs @@ -2,9 +2,10 @@ use std::env; use rig::{ completion::Prompt, - embeddings::{builder::DocumentEmbeddings, builder::EmbeddingsBuilder}, + embeddings::builder::DocumentEmbeddings, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, vector_store::in_memory_store::InMemoryVectorStore, + EmbeddingsBuilder, }; #[tokio::main] diff --git a/rig-core/examples/rag_dynamic_tools.rs b/rig-core/examples/rag_dynamic_tools.rs index f00543a2..dce27693 100644 --- a/rig-core/examples/rag_dynamic_tools.rs +++ b/rig-core/examples/rag_dynamic_tools.rs @@ -1,10 +1,11 @@ use anyhow::Result; use rig::{ completion::{Prompt, ToolDefinition}, - embeddings::{builder::DocumentEmbeddings, builder::EmbeddingsBuilder}, + embeddings::builder::DocumentEmbeddings, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, tool::{Tool, ToolEmbedding, ToolSet}, vector_store::in_memory_store::InMemoryVectorStore, + EmbeddingsBuilder, }; use serde::{Deserialize, Serialize}; use serde_json::json; diff --git a/rig-core/examples/vector_search.rs b/rig-core/examples/vector_search.rs index 26692706..67d04424 100644 --- a/rig-core/examples/vector_search.rs +++ b/rig-core/examples/vector_search.rs @@ -1,9 +1,10 @@ use std::env; use rig::{ - embeddings::{builder::DocumentEmbeddings, builder::EmbeddingsBuilder}, + embeddings::builder::DocumentEmbeddings, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, vector_store::{in_memory_store::InMemoryVectorStore, VectorStoreIndex}, + EmbeddingsBuilder, }; #[tokio::main] diff --git a/rig-core/examples/vector_search_cohere.rs b/rig-core/examples/vector_search_cohere.rs index c14fe0ce..1be5de8d 100644 --- a/rig-core/examples/vector_search_cohere.rs +++ b/rig-core/examples/vector_search_cohere.rs @@ -1,9 +1,10 @@ use std::env; use rig::{ - embeddings::{builder::DocumentEmbeddings, 
builder::EmbeddingsBuilder}, + embeddings::builder::DocumentEmbeddings, providers::cohere::{Client, EMBED_ENGLISH_V3}, vector_store::{in_memory_store::InMemoryVectorStore, VectorStoreIndex}, + EmbeddingsBuilder, }; #[tokio::main] diff --git a/rig-core/src/embeddings/embeddable.rs b/rig-core/src/embeddings/embeddable.rs index 9e751b00..aafe2a5c 100644 --- a/rig-core/src/embeddings/embeddable.rs +++ b/rig-core/src/embeddings/embeddable.rs @@ -49,7 +49,7 @@ pub trait Embeddable { /// If a single item is present, `first` will contain it and `rest` will be empty. /// If multiple items are present, `first` will contain the first item and `rest` will contain the rest. /// IMPORTANT: this struct cannot be created with an empty vector. -/// OneOrMany objects can only be create using OneOrMany::from() or OneOrMany::try_from(). +/// OneOrMany objects can only be created using OneOrMany::from() or OneOrMany::try_from(). #[derive(PartialEq, Eq, Debug)] pub struct OneOrMany<T> { /// First item in the list. diff --git a/rig-core/src/lib.rs b/rig-core/src/lib.rs index cc17efa3..5c498c81 100644 --- a/rig-core/src/lib.rs +++ b/rig-core/src/lib.rs @@ -76,7 +76,8 @@ pub mod providers; pub mod tool; pub mod vector_store; -// Export Embeddable trait and Embeddable together. +// Re-export commonly used types and traits +pub use embeddings::builder::EmbeddingsBuilder; pub use embeddings::embeddable::Embeddable; #[cfg(feature = "rig_derive")] diff --git a/rig-core/src/providers/cohere.rs b/rig-core/src/providers/cohere.rs index 87f2334d..844849d7 100644 --- a/rig-core/src/providers/cohere.rs +++ b/rig-core/src/providers/cohere.rs @@ -13,9 +13,9 @@ use std::collections::HashMap; use crate::{ agent::AgentBuilder, completion::{self, CompletionError}, - embeddings::{self, builder::EmbeddingsBuilder, embedding::EmbeddingError}, + embeddings::{self, embedding::EmbeddingError}, extractor::ExtractorBuilder, - json_utils, + json_utils, EmbeddingsBuilder, }; use schemars::JsonSchema; diff --git a/rig-core/src/providers/openai.rs b/rig-core/src/providers/openai.rs index 9adf6680..6b95a56a 100644 --- a/rig-core/src/providers/openai.rs +++ b/rig-core/src/providers/openai.rs @@ -13,7 +13,7 @@ use crate::{ completion::{self, CompletionError, CompletionRequest}, embeddings::{self, embedding::EmbeddingError}, extractor::ExtractorBuilder, - json_utils, + json_utils, EmbeddingsBuilder, }; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; @@ -121,11 +121,8 @@ impl Client { /// .await /// .expect("Failed to embed documents"); /// ``` - pub fn embeddings( - &self, - model: &str, - ) -> embeddings::builder::EmbeddingsBuilder<EmbeddingModel> { - embeddings::builder::EmbeddingsBuilder::new(self.embedding_model(model)) + pub fn embeddings(&self, model: &str) -> EmbeddingsBuilder<EmbeddingModel> { + EmbeddingsBuilder::new(self.embedding_model(model)) } /// Create a completion model with the given name. 
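With the re-export in place, downstream code can bring `EmbeddingsBuilder` in from the crate root instead of the full `rig::embeddings::builder` path. A minimal sketch of the shortened import follows; it is illustrative only and not part of the patch, and the `Client::new` call taking an API key is assumed from the crate's examples rather than shown in this diff.

use rig::{
    providers::openai::{Client, TEXT_EMBEDDING_ADA_002},
    EmbeddingsBuilder,
};

fn embeddings_builder(openai_api_key: &str) {
    // Assumed constructor taking the API key, as used in the crate's examples.
    let openai_client = Client::new(openai_api_key);

    // `Client::embeddings` returns the re-exported `EmbeddingsBuilder` for the chosen model.
    let _builder: EmbeddingsBuilder<_> = openai_client.embeddings(TEXT_EMBEDDING_ADA_002);
}
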
diff --git a/rig-lancedb/examples/vector_search_local_ann.rs b/rig-lancedb/examples/vector_search_local_ann.rs index 52ee213a..8f3ceee6 100644 --- a/rig-lancedb/examples/vector_search_local_ann.rs +++ b/rig-lancedb/examples/vector_search_local_ann.rs @@ -5,8 +5,9 @@ use fixture::{as_record_batch, schema}; use lancedb::index::vector::IvfPqIndexBuilder; use rig::vector_store::VectorStoreIndex; use rig::{ - embeddings::{builder::EmbeddingsBuilder, embedding::EmbeddingModel}, + embeddings::embedding::EmbeddingModel, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, + EmbeddingsBuilder, }; use rig_lancedb::{LanceDbVectorStore, SearchParams}; use serde::Deserialize; diff --git a/rig-lancedb/examples/vector_search_local_enn.rs b/rig-lancedb/examples/vector_search_local_enn.rs index 9f0ec934..90a14fe0 100644 --- a/rig-lancedb/examples/vector_search_local_enn.rs +++ b/rig-lancedb/examples/vector_search_local_enn.rs @@ -3,9 +3,10 @@ use std::{env, sync::Arc}; use arrow_array::RecordBatchIterator; use fixture::{as_record_batch, schema}; use rig::{ - embeddings::{builder::EmbeddingsBuilder, embedding::EmbeddingModel}, + embeddings::embedding::EmbeddingModel, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, vector_store::VectorStoreIndexDyn, + EmbeddingsBuilder, }; use rig_lancedb::{LanceDbVectorStore, SearchParams}; diff --git a/rig-lancedb/examples/vector_search_s3_ann.rs b/rig-lancedb/examples/vector_search_s3_ann.rs index 7ce61309..2d97ed80 100644 --- a/rig-lancedb/examples/vector_search_s3_ann.rs +++ b/rig-lancedb/examples/vector_search_s3_ann.rs @@ -4,9 +4,10 @@ use arrow_array::RecordBatchIterator; use fixture::{as_record_batch, schema}; use lancedb::{index::vector::IvfPqIndexBuilder, DistanceType}; use rig::{ - embeddings::{builder::EmbeddingsBuilder, embedding::EmbeddingModel}, + embeddings::embedding::EmbeddingModel, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, vector_store::VectorStoreIndex, + EmbeddingsBuilder, }; use rig_lancedb::{LanceDbVectorStore, SearchParams}; use serde::Deserialize; diff --git a/rig-mongodb/examples/vector_search_mongodb.rs b/rig-mongodb/examples/vector_search_mongodb.rs index cca7a6d4..517bfea6 100644 --- a/rig-mongodb/examples/vector_search_mongodb.rs +++ b/rig-mongodb/examples/vector_search_mongodb.rs @@ -2,9 +2,10 @@ use mongodb::{options::ClientOptions, Client as MongoClient, Collection}; use std::env; use rig::{ - embeddings::{builder::DocumentEmbeddings, builder::EmbeddingsBuilder}, + embeddings::builder::DocumentEmbeddings, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, vector_store::VectorStoreIndex, + EmbeddingsBuilder, }; use rig_mongodb::{MongoDbVectorStore, SearchParams}; From 067894cc27eab07be72445baed03d8928f0eeb74 Mon Sep 17 00:00:00 2001 From: Garance <garance.mary@gmail.com> Date: Thu, 17 Oct 2024 12:31:28 -0400 Subject: [PATCH 43/91] feat: implement IntoIterator and Iterator for OneOrMany --- rig-core/src/embeddings/embeddable.rs | 199 ++++++++++++++++++++++---- 1 file changed, 169 insertions(+), 30 deletions(-) diff --git a/rig-core/src/embeddings/embeddable.rs b/rig-core/src/embeddings/embeddable.rs index aafe2a5c..f5f09260 100644 --- a/rig-core/src/embeddings/embeddable.rs +++ b/rig-core/src/embeddings/embeddable.rs @@ -49,8 +49,8 @@ pub trait Embeddable { /// If a single item is present, `first` will contain it and `rest` will be empty. /// If multiple items are present, `first` will contain the first item and `rest` will contain the rest. /// IMPORTANT: this struct cannot be created with an empty vector. 
-/// OneOrMany objects can only be created using OneOrMany::from() or OneOrMany::try_from(). -#[derive(PartialEq, Eq, Debug)] +/// OneOrMany objects can only be created using OneOrMany::from_single() or OneOrMany::from_many(). +#[derive(PartialEq, Eq, Debug, Clone)] pub struct OneOrMany<T> { /// First item in the list. first: T, @@ -69,27 +69,16 @@ impl<T: Clone> OneOrMany<T> { self.rest.clone() } - /// Get all items in the list (joins the first with the rest). - pub fn all(&self) -> Vec<T> { - let mut all = vec![self.first.clone()]; - all.extend(self.rest.clone()); - all - } -} - -/// Create a OneOrMany object with a single item. -impl<T> From<T> for OneOrMany<T> { - fn from(item: T) -> Self { + /// Create a OneOrMany object with a single item of any type. + pub fn from_single(item: T) -> Self { OneOrMany { first: item, rest: vec![], } } -} -/// Create a OneOrMany object with a list of items. -impl<T> From<Vec<T>> for OneOrMany<T> { - fn from(items: Vec<T>) -> Self { + /// Create a OneOrMany object with a single item of any type. + pub fn from_many(items: Vec<T>) -> Self { let mut iter = items.into_iter(); OneOrMany { first: match iter.next() { @@ -99,6 +88,74 @@ impl<T> From<Vec<T>> for OneOrMany<T> { rest: iter.collect(), } } + + /// Use the Iterator trait on OneOrMany + pub fn iter(&self) -> OneOrManyIterator<T> { + OneOrManyIterator { + one_or_many: self, + index: 0, + } + } +} + +/// Implement Iterator for OneOrMany. +/// Iterates over all items in both `first` and `rest`. +/// Borrows the OneOrMany object that is being iterator over. +pub struct OneOrManyIterator<'a, T> { + one_or_many: &'a OneOrMany<T>, + index: usize, +} + +impl<'a, T> Iterator for OneOrManyIterator<'a, T> { + type Item = &'a T; + + fn next(&mut self) -> Option<Self::Item> { + let mut item = None; + if self.index == 0 { + item = Some(&self.one_or_many.first) + } else if self.index - 1 < self.one_or_many.rest.len() { + item = Some(&self.one_or_many.rest[self.index - 1]); + }; + + self.index += 1; + item + } +} + +/// Implement IntoIterator for OneOrMany. +/// Iterates over all items in both `first` and `rest`. +/// Takes ownership the OneOrMany object that is being iterator over. +pub struct OneOrManyIntoIterator<T> { + one_or_many: OneOrMany<T>, + index: usize, +} + +impl<T: Clone> IntoIterator for OneOrMany<T> { + type Item = T; + type IntoIter = OneOrManyIntoIterator<T>; + + fn into_iter(self) -> OneOrManyIntoIterator<T> { + OneOrManyIntoIterator { + one_or_many: self, + index: 0, + } + } +} + +impl<T: Clone> Iterator for OneOrManyIntoIterator<T> { + type Item = T; + + fn next(&mut self) -> Option<Self::Item> { + let mut item = None; + if self.index == 0 { + item = Some(self.one_or_many.first()) + } else if self.index - 1 < self.one_or_many.rest.len() { + item = Some(self.one_or_many.rest[self.index - 1].clone()); + }; + + self.index += 1; + item + } } /// Merge a list of OneOrMany items into a single OneOrMany item. 
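The new `iter` and `into_iter` let callers walk every embedding target without caring whether there is one item or many. A minimal usage sketch, relying only on the `OneOrMany` API introduced in this hunk (the snippet is illustrative and not part of the diff):

use rig::embeddings::embeddable::OneOrMany;

fn print_targets() {
    let targets = OneOrMany::from_many(vec!["first".to_string(), "second".to_string()]);

    // Borrowing iteration: yields &String, `first` followed by everything in `rest`.
    for target in targets.iter() {
        println!("{target}");
    }

    // Owning iteration: consumes the OneOrMany and yields String values in the same order.
    for target in targets.into_iter() {
        println!("{target}");
    }
}
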
@@ -106,10 +163,10 @@ impl<T: Clone> From<Vec<OneOrMany<T>>> for OneOrMany<T> { fn from(value: Vec<OneOrMany<T>>) -> Self { let items = value .into_iter() - .flat_map(|one_or_many| one_or_many.all()) + .flat_map(|one_or_many| one_or_many.into_iter()) .collect::<Vec<_>>(); - OneOrMany::from(items) + OneOrMany::from_many(items) } } @@ -120,7 +177,7 @@ impl Embeddable for String { type Error = EmbeddableError; fn embeddable(&self) -> Result<OneOrMany<String>, Self::Error> { - Ok(OneOrMany::from(self.clone())) + Ok(OneOrMany::from_single(self.clone())) } } @@ -128,7 +185,7 @@ impl Embeddable for i8 { type Error = EmbeddableError; fn embeddable(&self) -> Result<OneOrMany<String>, Self::Error> { - Ok(OneOrMany::from(self.to_string())) + Ok(OneOrMany::from_single(self.to_string())) } } @@ -136,7 +193,7 @@ impl Embeddable for i16 { type Error = EmbeddableError; fn embeddable(&self) -> Result<OneOrMany<String>, Self::Error> { - Ok(OneOrMany::from(self.to_string())) + Ok(OneOrMany::from_single(self.to_string())) } } @@ -144,7 +201,7 @@ impl Embeddable for i32 { type Error = EmbeddableError; fn embeddable(&self) -> Result<OneOrMany<String>, Self::Error> { - Ok(OneOrMany::from(self.to_string())) + Ok(OneOrMany::from_single(self.to_string())) } } @@ -152,7 +209,7 @@ impl Embeddable for i64 { type Error = EmbeddableError; fn embeddable(&self) -> Result<OneOrMany<String>, Self::Error> { - Ok(OneOrMany::from(self.to_string())) + Ok(OneOrMany::from_single(self.to_string())) } } @@ -160,7 +217,7 @@ impl Embeddable for i128 { type Error = EmbeddableError; fn embeddable(&self) -> Result<OneOrMany<String>, Self::Error> { - Ok(OneOrMany::from(self.to_string())) + Ok(OneOrMany::from_single(self.to_string())) } } @@ -168,7 +225,7 @@ impl Embeddable for f32 { type Error = EmbeddableError; fn embeddable(&self) -> Result<OneOrMany<String>, Self::Error> { - Ok(OneOrMany::from(self.to_string())) + Ok(OneOrMany::from_single(self.to_string())) } } @@ -176,7 +233,7 @@ impl Embeddable for f64 { type Error = EmbeddableError; fn embeddable(&self) -> Result<OneOrMany<String>, Self::Error> { - Ok(OneOrMany::from(self.to_string())) + Ok(OneOrMany::from_single(self.to_string())) } } @@ -184,7 +241,7 @@ impl Embeddable for bool { type Error = EmbeddableError; fn embeddable(&self) -> Result<OneOrMany<String>, Self::Error> { - Ok(OneOrMany::from(self.to_string())) + Ok(OneOrMany::from_single(self.to_string())) } } @@ -192,7 +249,7 @@ impl Embeddable for char { type Error = EmbeddableError; fn embeddable(&self) -> Result<OneOrMany<String>, Self::Error> { - Ok(OneOrMany::from(self.to_string())) + Ok(OneOrMany::from_single(self.to_string())) } } @@ -200,7 +257,7 @@ impl Embeddable for serde_json::Value { type Error = EmbeddableError; fn embeddable(&self) -> Result<OneOrMany<String>, Self::Error> { - Ok(OneOrMany::from( + Ok(OneOrMany::from_single( serde_json::to_string(self).map_err(EmbeddableError::SerdeError)?, )) } @@ -218,3 +275,85 @@ impl<T: Embeddable> Embeddable for Vec<T> { Ok(OneOrMany::from(items)) } } + +#[cfg(test)] +mod test { + use super::OneOrMany; + + #[test] + fn test_one_or_many_iter_single() { + let one_or_many = OneOrMany::from_single("hello".to_string()); + + assert_eq!(one_or_many.iter().count(), 1); + + one_or_many.iter().for_each(|i| { + assert_eq!(i, "hello"); + }); + } + + #[test] + fn test_one_or_many_iter() { + let one_or_many = OneOrMany::from_many(vec!["hello".to_string(), "word".to_string()]); + + assert_eq!(one_or_many.iter().count(), 2); + + one_or_many.iter().enumerate().for_each(|(i, item)| { + if i 
== 0 { + assert_eq!(item, "hello"); + } + if i == 1 { + assert_eq!(item, "word"); + } + }); + } + + #[test] + fn test_one_or_many_into_iter_single() { + let one_or_many = OneOrMany::from_single("hello".to_string()); + + assert_eq!(one_or_many.clone().into_iter().count(), 1); + + one_or_many.into_iter().for_each(|i| { + assert_eq!(i, "hello".to_string()); + }); + } + + #[test] + fn test_one_or_many_into_iter() { + let one_or_many = OneOrMany::from_many(vec!["hello".to_string(), "word".to_string()]); + + assert_eq!(one_or_many.clone().into_iter().count(), 2); + + one_or_many.into_iter().enumerate().for_each(|(i, item)| { + if i == 0 { + assert_eq!(item, "hello".to_string()); + } + if i == 1 { + assert_eq!(item, "word".to_string()); + } + }); + } + + #[test] + fn test_one_or_many_merge() { + let one_or_many_1 = OneOrMany::from_many(vec!["hello".to_string(), "word".to_string()]); + + let one_or_many_2 = OneOrMany::from_single("sup".to_string()); + + let merged = OneOrMany::from(vec![one_or_many_1, one_or_many_2]); + + assert_eq!(merged.iter().count(), 3); + + merged.iter().enumerate().for_each(|(i, item)| { + if i == 0 { + assert_eq!(item, "hello"); + } + if i == 1 { + assert_eq!(item, "word"); + } + if i == 2 { + assert_eq!(item, "sup"); + } + }); + } +} From 32bcc61fa7af785e6ea4c5233a2734c0973ce6fd Mon Sep 17 00:00:00 2001 From: Garance <garance.mary@gmail.com> Date: Thu, 17 Oct 2024 12:35:37 -0400 Subject: [PATCH 44/91] refactor: rename from methods --- rig-core/src/embeddings/embeddable.rs | 48 +++++++++++++-------------- 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/rig-core/src/embeddings/embeddable.rs b/rig-core/src/embeddings/embeddable.rs index f5f09260..a4bb85c5 100644 --- a/rig-core/src/embeddings/embeddable.rs +++ b/rig-core/src/embeddings/embeddable.rs @@ -22,8 +22,8 @@ //! //! fn embeddable(&self) -> Result<OneOrMany<String>, Self::Error> { //! // Embeddigns only need to be generated for `definition` field. -//! // Select it from te struct and return it as a single item. -//! Ok(OneOrMany::from(self.definition.clone())) +//! // Select it from the struct and return it as a single item. +//! Ok(OneOrMany::one(self.definition.clone())) //! } //! } //! ``` @@ -49,7 +49,7 @@ pub trait Embeddable { /// If a single item is present, `first` will contain it and `rest` will be empty. /// If multiple items are present, `first` will contain the first item and `rest` will contain the rest. /// IMPORTANT: this struct cannot be created with an empty vector. -/// OneOrMany objects can only be created using OneOrMany::from_single() or OneOrMany::from_many(). +/// OneOrMany objects can only be created using OneOrMany::one() or OneOrMany::many(). #[derive(PartialEq, Eq, Debug, Clone)] pub struct OneOrMany<T> { /// First item in the list. @@ -70,15 +70,15 @@ impl<T: Clone> OneOrMany<T> { } /// Create a OneOrMany object with a single item of any type. - pub fn from_single(item: T) -> Self { + pub fn one(item: T) -> Self { OneOrMany { first: item, rest: vec![], } } - /// Create a OneOrMany object with a single item of any type. - pub fn from_many(items: Vec<T>) -> Self { + /// Create a OneOrMany object with a vector of items of any type. 
+ pub fn many(items: Vec<T>) -> Self { let mut iter = items.into_iter(); OneOrMany { first: match iter.next() { @@ -166,7 +166,7 @@ impl<T: Clone> From<Vec<OneOrMany<T>>> for OneOrMany<T> { .flat_map(|one_or_many| one_or_many.into_iter()) .collect::<Vec<_>>(); - OneOrMany::from_many(items) + OneOrMany::many(items) } } @@ -177,7 +177,7 @@ impl Embeddable for String { type Error = EmbeddableError; fn embeddable(&self) -> Result<OneOrMany<String>, Self::Error> { - Ok(OneOrMany::from_single(self.clone())) + Ok(OneOrMany::one(self.clone())) } } @@ -185,7 +185,7 @@ impl Embeddable for i8 { type Error = EmbeddableError; fn embeddable(&self) -> Result<OneOrMany<String>, Self::Error> { - Ok(OneOrMany::from_single(self.to_string())) + Ok(OneOrMany::one(self.to_string())) } } @@ -193,7 +193,7 @@ impl Embeddable for i16 { type Error = EmbeddableError; fn embeddable(&self) -> Result<OneOrMany<String>, Self::Error> { - Ok(OneOrMany::from_single(self.to_string())) + Ok(OneOrMany::one(self.to_string())) } } @@ -201,7 +201,7 @@ impl Embeddable for i32 { type Error = EmbeddableError; fn embeddable(&self) -> Result<OneOrMany<String>, Self::Error> { - Ok(OneOrMany::from_single(self.to_string())) + Ok(OneOrMany::one(self.to_string())) } } @@ -209,7 +209,7 @@ impl Embeddable for i64 { type Error = EmbeddableError; fn embeddable(&self) -> Result<OneOrMany<String>, Self::Error> { - Ok(OneOrMany::from_single(self.to_string())) + Ok(OneOrMany::one(self.to_string())) } } @@ -217,7 +217,7 @@ impl Embeddable for i128 { type Error = EmbeddableError; fn embeddable(&self) -> Result<OneOrMany<String>, Self::Error> { - Ok(OneOrMany::from_single(self.to_string())) + Ok(OneOrMany::one(self.to_string())) } } @@ -225,7 +225,7 @@ impl Embeddable for f32 { type Error = EmbeddableError; fn embeddable(&self) -> Result<OneOrMany<String>, Self::Error> { - Ok(OneOrMany::from_single(self.to_string())) + Ok(OneOrMany::one(self.to_string())) } } @@ -233,7 +233,7 @@ impl Embeddable for f64 { type Error = EmbeddableError; fn embeddable(&self) -> Result<OneOrMany<String>, Self::Error> { - Ok(OneOrMany::from_single(self.to_string())) + Ok(OneOrMany::one(self.to_string())) } } @@ -241,7 +241,7 @@ impl Embeddable for bool { type Error = EmbeddableError; fn embeddable(&self) -> Result<OneOrMany<String>, Self::Error> { - Ok(OneOrMany::from_single(self.to_string())) + Ok(OneOrMany::one(self.to_string())) } } @@ -249,7 +249,7 @@ impl Embeddable for char { type Error = EmbeddableError; fn embeddable(&self) -> Result<OneOrMany<String>, Self::Error> { - Ok(OneOrMany::from_single(self.to_string())) + Ok(OneOrMany::one(self.to_string())) } } @@ -257,7 +257,7 @@ impl Embeddable for serde_json::Value { type Error = EmbeddableError; fn embeddable(&self) -> Result<OneOrMany<String>, Self::Error> { - Ok(OneOrMany::from_single( + Ok(OneOrMany::one( serde_json::to_string(self).map_err(EmbeddableError::SerdeError)?, )) } @@ -282,7 +282,7 @@ mod test { #[test] fn test_one_or_many_iter_single() { - let one_or_many = OneOrMany::from_single("hello".to_string()); + let one_or_many = OneOrMany::one("hello".to_string()); assert_eq!(one_or_many.iter().count(), 1); @@ -293,7 +293,7 @@ mod test { #[test] fn test_one_or_many_iter() { - let one_or_many = OneOrMany::from_many(vec!["hello".to_string(), "word".to_string()]); + let one_or_many = OneOrMany::many(vec!["hello".to_string(), "word".to_string()]); assert_eq!(one_or_many.iter().count(), 2); @@ -309,7 +309,7 @@ mod test { #[test] fn test_one_or_many_into_iter_single() { - let one_or_many = 
OneOrMany::from_single("hello".to_string()); + let one_or_many = OneOrMany::one("hello".to_string()); assert_eq!(one_or_many.clone().into_iter().count(), 1); @@ -320,7 +320,7 @@ mod test { #[test] fn test_one_or_many_into_iter() { - let one_or_many = OneOrMany::from_many(vec!["hello".to_string(), "word".to_string()]); + let one_or_many = OneOrMany::many(vec!["hello".to_string(), "word".to_string()]); assert_eq!(one_or_many.clone().into_iter().count(), 2); @@ -336,9 +336,9 @@ mod test { #[test] fn test_one_or_many_merge() { - let one_or_many_1 = OneOrMany::from_many(vec!["hello".to_string(), "word".to_string()]); + let one_or_many_1 = OneOrMany::many(vec!["hello".to_string(), "word".to_string()]); - let one_or_many_2 = OneOrMany::from_single("sup".to_string()); + let one_or_many_2 = OneOrMany::one("sup".to_string()); let merged = OneOrMany::from(vec![one_or_many_1, one_or_many_2]); From 564bef408a1d8c22d317015fe3b697839869dcd0 Mon Sep 17 00:00:00 2001 From: Garance <garance.mary@gmail.com> Date: Thu, 17 Oct 2024 12:50:47 -0400 Subject: [PATCH 45/91] tests: fix failing tests --- rig-core/tests/embeddable_macro.rs | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/rig-core/tests/embeddable_macro.rs b/rig-core/tests/embeddable_macro.rs index bc076fed..96e42c7f 100644 --- a/rig-core/tests/embeddable_macro.rs +++ b/rig-core/tests/embeddable_macro.rs @@ -3,7 +3,7 @@ use rig::Embeddable; use serde::Serialize; fn serialize(definition: Definition) -> Result<OneOrMany<String>, EmbeddableError> { - Ok(OneOrMany::from( + Ok(OneOrMany::one( serde_json::to_string(&definition).map_err(EmbeddableError::SerdeError)?, )) } @@ -42,7 +42,7 @@ fn test_custom_embed() { assert_eq!( fake_definition.embeddable().unwrap(), - OneOrMany::from( + OneOrMany::one( "{\"word\":\"a building in which people live; residence for human beings.\",\"link\":\"https://www.dictionary.com/browse/house\",\"speech\":\"noun\"}".to_string() ) @@ -111,7 +111,7 @@ fn test_single_embed() { assert_eq!( fake_definition.embeddable().unwrap(), - OneOrMany::from(definition) + OneOrMany::one(definition) ) } @@ -137,13 +137,12 @@ fn test_multiple_embed_strings() { assert_eq!( result, - OneOrMany::try_from(vec![ + OneOrMany::many(vec![ "25".to_string(), "30".to_string(), "35".to_string(), "40".to_string() ]) - .unwrap() ); assert_eq!(result.first(), "25".to_string()); @@ -175,13 +174,12 @@ fn test_multiple_embed_tags() { assert_eq!( company.embeddable().unwrap(), - OneOrMany::try_from(vec![ + OneOrMany::many(vec![ "Google".to_string(), "25".to_string(), "30".to_string(), "35".to_string(), "40".to_string() ]) - .unwrap() ); } From 04f1f3e3f488d7a0c6b7f4445cb16ab223abe981 Mon Sep 17 00:00:00 2001 From: Garance <garance.mary@gmail.com> Date: Thu, 17 Oct 2024 14:39:23 -0400 Subject: [PATCH 46/91] refactor&fix: make PR review changes --- rig-core/Cargo.toml | 4 +- rig-core/examples/calculator_chatbot.rs | 2 +- rig-core/examples/rag.rs | 2 +- rig-core/examples/rag_dynamic_tools.rs | 2 +- rig-core/examples/vector_search.rs | 2 +- rig-core/examples/vector_search_cohere.rs | 2 +- rig-core/src/embeddings/builder.rs | 7 +- rig-core/src/embeddings/embeddable.rs | 238 ++---------------- rig-core/src/embeddings/embedding.rs | 2 +- rig-core/src/embeddings/mod.rs | 4 + rig-core/src/lib.rs | 4 +- rig-core/src/providers/cohere.rs | 10 +- rig-core/src/providers/openai.rs | 10 +- rig-core/src/vec_utils.rs | 191 ++++++++++++++ rig-core/src/vector_store/in_memory_store.rs | 2 +- rig-core/src/vector_store/mod.rs | 2 +- 
.../examples/vector_search_local_ann.rs | 5 +- .../examples/vector_search_local_enn.rs | 3 +- rig-lancedb/examples/vector_search_s3_ann.rs | 3 +- rig-mongodb/examples/vector_search_mongodb.rs | 2 +- rig-mongodb/src/lib.rs | 3 +- 21 files changed, 244 insertions(+), 256 deletions(-) create mode 100644 rig-core/src/vec_utils.rs diff --git a/rig-core/Cargo.toml b/rig-core/Cargo.toml index d2ab06c6..bc65409a 100644 --- a/rig-core/Cargo.toml +++ b/rig-core/Cargo.toml @@ -31,8 +31,8 @@ tokio = { version = "1.34.0", features = ["full"] } tracing-subscriber = "0.3.18" [features] -rig_derive = ["dep:rig-derive"] +derive = ["dep:rig-derive"] [[test]] name = "embeddable_macro" -required-features = ["rig_derive"] \ No newline at end of file +required-features = ["derive"] \ No newline at end of file diff --git a/rig-core/examples/calculator_chatbot.rs b/rig-core/examples/calculator_chatbot.rs index 02a029e9..8b622482 100644 --- a/rig-core/examples/calculator_chatbot.rs +++ b/rig-core/examples/calculator_chatbot.rs @@ -3,10 +3,10 @@ use rig::{ cli_chatbot::cli_chatbot, completion::ToolDefinition, embeddings::builder::DocumentEmbeddings, + embeddings::EmbeddingsBuilder, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, tool::{Tool, ToolEmbedding, ToolSet}, vector_store::in_memory_store::InMemoryVectorStore, - EmbeddingsBuilder, }; use serde::{Deserialize, Serialize}; use serde_json::json; diff --git a/rig-core/examples/rag.rs b/rig-core/examples/rag.rs index ec9bebd8..674d028c 100644 --- a/rig-core/examples/rag.rs +++ b/rig-core/examples/rag.rs @@ -3,9 +3,9 @@ use std::env; use rig::{ completion::Prompt, embeddings::builder::DocumentEmbeddings, + embeddings::EmbeddingsBuilder, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, vector_store::in_memory_store::InMemoryVectorStore, - EmbeddingsBuilder, }; #[tokio::main] diff --git a/rig-core/examples/rag_dynamic_tools.rs b/rig-core/examples/rag_dynamic_tools.rs index dce27693..5476eb80 100644 --- a/rig-core/examples/rag_dynamic_tools.rs +++ b/rig-core/examples/rag_dynamic_tools.rs @@ -2,10 +2,10 @@ use anyhow::Result; use rig::{ completion::{Prompt, ToolDefinition}, embeddings::builder::DocumentEmbeddings, + embeddings::EmbeddingsBuilder, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, tool::{Tool, ToolEmbedding, ToolSet}, vector_store::in_memory_store::InMemoryVectorStore, - EmbeddingsBuilder, }; use serde::{Deserialize, Serialize}; use serde_json::json; diff --git a/rig-core/examples/vector_search.rs b/rig-core/examples/vector_search.rs index 67d04424..feadbfb4 100644 --- a/rig-core/examples/vector_search.rs +++ b/rig-core/examples/vector_search.rs @@ -2,9 +2,9 @@ use std::env; use rig::{ embeddings::builder::DocumentEmbeddings, + embeddings::EmbeddingsBuilder, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, vector_store::{in_memory_store::InMemoryVectorStore, VectorStoreIndex}, - EmbeddingsBuilder, }; #[tokio::main] diff --git a/rig-core/examples/vector_search_cohere.rs b/rig-core/examples/vector_search_cohere.rs index 1be5de8d..6b93bcdd 100644 --- a/rig-core/examples/vector_search_cohere.rs +++ b/rig-core/examples/vector_search_cohere.rs @@ -2,9 +2,9 @@ use std::env; use rig::{ embeddings::builder::DocumentEmbeddings, + embeddings::EmbeddingsBuilder, providers::cohere::{Client, EMBED_ENGLISH_V3}, vector_store::{in_memory_store::InMemoryVectorStore, VectorStoreIndex}, - EmbeddingsBuilder, }; #[tokio::main] diff --git a/rig-core/src/embeddings/builder.rs b/rig-core/src/embeddings/builder.rs index 2458b98d..46eef092 100644 --- 
a/rig-core/src/embeddings/builder.rs +++ b/rig-core/src/embeddings/builder.rs @@ -30,9 +30,10 @@ use std::{cmp::max, collections::HashMap}; use futures::{stream, StreamExt, TryStreamExt}; use serde::{Deserialize, Serialize}; -use crate::tool::{ToolEmbedding, ToolSet, ToolType}; - -use super::embedding::{Embedding, EmbeddingError, EmbeddingModel}; +use crate::{ + embeddings::{Embedding, EmbeddingError, EmbeddingModel}, + tool::{ToolEmbedding, ToolSet, ToolType}, +}; /// Struct that holds a document and its embeddings. /// diff --git a/rig-core/src/embeddings/embeddable.rs b/rig-core/src/embeddings/embeddable.rs index a4bb85c5..e303ecda 100644 --- a/rig-core/src/embeddings/embeddable.rs +++ b/rig-core/src/embeddings/embeddable.rs @@ -1,5 +1,5 @@ //! The module defines the [Embeddable] trait, which must be implemented for types that can be embedded. -//! //! # Example +//! # Example //! ```rust //! use std::env; //! @@ -8,7 +8,7 @@ //! struct FakeDefinition { //! id: String, //! word: String, -//! definitions: Vec<String>, +//! definition: String, //! } //! //! let fake_definition = FakeDefinition { @@ -28,151 +28,26 @@ //! } //! ``` +use crate::vec_utils::OneOrMany; + /// Error type used for when the `embeddable` method fails. /// Used by default implementations of `Embeddable` for common types. #[derive(Debug, thiserror::Error)] -pub enum EmbeddableError { - #[error("SerdeError: {0}")] - SerdeError(#[from] serde_json::Error), -} +#[error("{0}")] +pub struct EmbeddableError(#[from] Box<dyn std::error::Error + Send + Sync>); /// Trait for types that can be embedded. -/// The `embeddable` method returns a OneOrMany<String> which contains strings for which embeddings will be generated by the embeddings builder. +/// The `embeddable` method returns a `OneOrMany<String>` which contains strings for which embeddings will be generated by the embeddings builder. /// If there is an error generating the list of strings, the method should return an error that implements `std::error::Error`. pub trait Embeddable { - type Error: std::error::Error; + type Error: std::error::Error + Sync + Send + 'static; fn embeddable(&self) -> Result<OneOrMany<String>, Self::Error>; } -/// Struct containing either a single item or a list of items of type T. -/// If a single item is present, `first` will contain it and `rest` will be empty. -/// If multiple items are present, `first` will contain the first item and `rest` will contain the rest. -/// IMPORTANT: this struct cannot be created with an empty vector. -/// OneOrMany objects can only be created using OneOrMany::one() or OneOrMany::many(). -#[derive(PartialEq, Eq, Debug, Clone)] -pub struct OneOrMany<T> { - /// First item in the list. - first: T, - /// Rest of the items in the list. - rest: Vec<T>, -} - -impl<T: Clone> OneOrMany<T> { - /// Get the first item in the list. - pub fn first(&self) -> T { - self.first.clone() - } - - /// Get the rest of the items in the list (excluding the first one). - pub fn rest(&self) -> Vec<T> { - self.rest.clone() - } - - /// Create a OneOrMany object with a single item of any type. - pub fn one(item: T) -> Self { - OneOrMany { - first: item, - rest: vec![], - } - } - - /// Create a OneOrMany object with a vector of items of any type. 
- pub fn many(items: Vec<T>) -> Self { - let mut iter = items.into_iter(); - OneOrMany { - first: match iter.next() { - Some(item) => item, - None => panic!("Cannot create OneOrMany with an empty vector."), - }, - rest: iter.collect(), - } - } - - /// Use the Iterator trait on OneOrMany - pub fn iter(&self) -> OneOrManyIterator<T> { - OneOrManyIterator { - one_or_many: self, - index: 0, - } - } -} - -/// Implement Iterator for OneOrMany. -/// Iterates over all items in both `first` and `rest`. -/// Borrows the OneOrMany object that is being iterator over. -pub struct OneOrManyIterator<'a, T> { - one_or_many: &'a OneOrMany<T>, - index: usize, -} - -impl<'a, T> Iterator for OneOrManyIterator<'a, T> { - type Item = &'a T; - - fn next(&mut self) -> Option<Self::Item> { - let mut item = None; - if self.index == 0 { - item = Some(&self.one_or_many.first) - } else if self.index - 1 < self.one_or_many.rest.len() { - item = Some(&self.one_or_many.rest[self.index - 1]); - }; - - self.index += 1; - item - } -} - -/// Implement IntoIterator for OneOrMany. -/// Iterates over all items in both `first` and `rest`. -/// Takes ownership the OneOrMany object that is being iterator over. -pub struct OneOrManyIntoIterator<T> { - one_or_many: OneOrMany<T>, - index: usize, -} - -impl<T: Clone> IntoIterator for OneOrMany<T> { - type Item = T; - type IntoIter = OneOrManyIntoIterator<T>; - - fn into_iter(self) -> OneOrManyIntoIterator<T> { - OneOrManyIntoIterator { - one_or_many: self, - index: 0, - } - } -} - -impl<T: Clone> Iterator for OneOrManyIntoIterator<T> { - type Item = T; - - fn next(&mut self) -> Option<Self::Item> { - let mut item = None; - if self.index == 0 { - item = Some(self.one_or_many.first()) - } else if self.index - 1 < self.one_or_many.rest.len() { - item = Some(self.one_or_many.rest[self.index - 1].clone()); - }; - - self.index += 1; - item - } -} - -/// Merge a list of OneOrMany items into a single OneOrMany item. 
-impl<T: Clone> From<Vec<OneOrMany<T>>> for OneOrMany<T> { - fn from(value: Vec<OneOrMany<T>>) -> Self { - let items = value - .into_iter() - .flat_map(|one_or_many| one_or_many.into_iter()) - .collect::<Vec<_>>(); - - OneOrMany::many(items) - } -} - -////////////////////////////////////////////////////// -/// Implementations of Embeddable for common types /// -////////////////////////////////////////////////////// +// ================================================================ +// Implementations of Embeddable for common types +// ================================================================ impl Embeddable for String { type Error = EmbeddableError; @@ -258,102 +133,21 @@ impl Embeddable for serde_json::Value { fn embeddable(&self) -> Result<OneOrMany<String>, Self::Error> { Ok(OneOrMany::one( - serde_json::to_string(self).map_err(EmbeddableError::SerdeError)?, + serde_json::to_string(self).map_err(|e| EmbeddableError(Box::new(e)))?, )) } } impl<T: Embeddable> Embeddable for Vec<T> { - type Error = T::Error; + type Error = EmbeddableError; fn embeddable(&self) -> Result<OneOrMany<String>, Self::Error> { let items = self .iter() .map(|item| item.embeddable()) - .collect::<Result<Vec<_>, _>>()?; - - Ok(OneOrMany::from(items)) - } -} - -#[cfg(test)] -mod test { - use super::OneOrMany; - - #[test] - fn test_one_or_many_iter_single() { - let one_or_many = OneOrMany::one("hello".to_string()); - - assert_eq!(one_or_many.iter().count(), 1); - - one_or_many.iter().for_each(|i| { - assert_eq!(i, "hello"); - }); - } - - #[test] - fn test_one_or_many_iter() { - let one_or_many = OneOrMany::many(vec!["hello".to_string(), "word".to_string()]); - - assert_eq!(one_or_many.iter().count(), 2); - - one_or_many.iter().enumerate().for_each(|(i, item)| { - if i == 0 { - assert_eq!(item, "hello"); - } - if i == 1 { - assert_eq!(item, "word"); - } - }); - } - - #[test] - fn test_one_or_many_into_iter_single() { - let one_or_many = OneOrMany::one("hello".to_string()); - - assert_eq!(one_or_many.clone().into_iter().count(), 1); - - one_or_many.into_iter().for_each(|i| { - assert_eq!(i, "hello".to_string()); - }); - } - - #[test] - fn test_one_or_many_into_iter() { - let one_or_many = OneOrMany::many(vec!["hello".to_string(), "word".to_string()]); - - assert_eq!(one_or_many.clone().into_iter().count(), 2); - - one_or_many.into_iter().enumerate().for_each(|(i, item)| { - if i == 0 { - assert_eq!(item, "hello".to_string()); - } - if i == 1 { - assert_eq!(item, "word".to_string()); - } - }); - } - - #[test] - fn test_one_or_many_merge() { - let one_or_many_1 = OneOrMany::many(vec!["hello".to_string(), "word".to_string()]); - - let one_or_many_2 = OneOrMany::one("sup".to_string()); - - let merged = OneOrMany::from(vec![one_or_many_1, one_or_many_2]); - - assert_eq!(merged.iter().count(), 3); + .collect::<Result<Vec<_>, _>>() + .map_err(|e| EmbeddableError(Box::new(e)))?; - merged.iter().enumerate().for_each(|(i, item)| { - if i == 0 { - assert_eq!(item, "hello"); - } - if i == 1 { - assert_eq!(item, "word"); - } - if i == 2 { - assert_eq!(item, "sup"); - } - }); + OneOrMany::merge(items).map_err(|e| EmbeddableError(Box::new(e))) } } diff --git a/rig-core/src/embeddings/embedding.rs b/rig-core/src/embeddings/embedding.rs index ff284a05..81a980ef 100644 --- a/rig-core/src/embeddings/embedding.rs +++ b/rig-core/src/embeddings/embedding.rs @@ -1,5 +1,5 @@ //! The module defines the [EmbeddingModel] trait, which represents an embedding model that can -//! generate embeddings for documents. 
It also provides an implementation of the [EmbeddingsBuilder] +//! generate embeddings for documents. It also provides an implementation of the [embeddings::EmbeddingsBuilder] //! struct, which allows users to build collections of document embeddings using different embedding //! models and document sources. //! diff --git a/rig-core/src/embeddings/mod.rs b/rig-core/src/embeddings/mod.rs index 37e720cb..a9eda7c3 100644 --- a/rig-core/src/embeddings/mod.rs +++ b/rig-core/src/embeddings/mod.rs @@ -6,3 +6,7 @@ pub mod builder; pub mod embeddable; pub mod embedding; + +pub use builder::EmbeddingsBuilder; +pub use embeddable::Embeddable; +pub use embedding::{Embedding, EmbeddingError, EmbeddingModel}; diff --git a/rig-core/src/lib.rs b/rig-core/src/lib.rs index 5c498c81..219d9be0 100644 --- a/rig-core/src/lib.rs +++ b/rig-core/src/lib.rs @@ -74,11 +74,11 @@ pub mod extractor; pub mod json_utils; pub mod providers; pub mod tool; +mod vec_utils; pub mod vector_store; // Re-export commonly used types and traits -pub use embeddings::builder::EmbeddingsBuilder; pub use embeddings::embeddable::Embeddable; -#[cfg(feature = "rig_derive")] +#[cfg(feature = "derive")] pub use rig_derive::Embeddable; diff --git a/rig-core/src/providers/cohere.rs b/rig-core/src/providers/cohere.rs index 844849d7..ae874b21 100644 --- a/rig-core/src/providers/cohere.rs +++ b/rig-core/src/providers/cohere.rs @@ -13,9 +13,9 @@ use std::collections::HashMap; use crate::{ agent::AgentBuilder, completion::{self, CompletionError}, - embeddings::{self, embedding::EmbeddingError}, + embeddings::{self, EmbeddingError, EmbeddingsBuilder}, extractor::ExtractorBuilder, - json_utils, EmbeddingsBuilder, + json_utils, }; use schemars::JsonSchema; @@ -183,7 +183,7 @@ pub struct EmbeddingModel { ndims: usize, } -impl embeddings::embedding::EmbeddingModel for EmbeddingModel { +impl embeddings::EmbeddingModel for EmbeddingModel { const MAX_DOCUMENTS: usize = 96; fn ndims(&self) -> usize { @@ -193,7 +193,7 @@ impl embeddings::embedding::EmbeddingModel for EmbeddingModel { async fn embed_documents( &self, documents: Vec<String>, - ) -> Result<Vec<embeddings::embedding::Embedding>, EmbeddingError> { + ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> { let response = self .client .post("/v1/embed") @@ -222,7 +222,7 @@ impl embeddings::embedding::EmbeddingModel for EmbeddingModel { .embeddings .into_iter() .zip(documents.into_iter()) - .map(|(embedding, document)| embeddings::embedding::Embedding { + .map(|(embedding, document)| embeddings::Embedding { document, vec: embedding, }) diff --git a/rig-core/src/providers/openai.rs b/rig-core/src/providers/openai.rs index 6b95a56a..c9ba9afa 100644 --- a/rig-core/src/providers/openai.rs +++ b/rig-core/src/providers/openai.rs @@ -11,9 +11,9 @@ use crate::{ agent::AgentBuilder, completion::{self, CompletionError, CompletionRequest}, - embeddings::{self, embedding::EmbeddingError}, + embeddings::{self, EmbeddingError, EmbeddingsBuilder}, extractor::ExtractorBuilder, - json_utils, EmbeddingsBuilder, + json_utils, }; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; @@ -232,7 +232,7 @@ pub struct EmbeddingModel { ndims: usize, } -impl embeddings::embedding::EmbeddingModel for EmbeddingModel { +impl embeddings::EmbeddingModel for EmbeddingModel { const MAX_DOCUMENTS: usize = 1024; fn ndims(&self) -> usize { @@ -242,7 +242,7 @@ impl embeddings::embedding::EmbeddingModel for EmbeddingModel { async fn embed_documents( &self, documents: Vec<String>, - ) -> 
Result<Vec<embeddings::embedding::Embedding>, EmbeddingError> { + ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> { let response = self .client .post("/v1/embeddings") @@ -268,7 +268,7 @@ impl embeddings::embedding::EmbeddingModel for EmbeddingModel { .data .into_iter() .zip(documents.into_iter()) - .map(|(embedding, document)| embeddings::embedding::Embedding { + .map(|(embedding, document)| embeddings::Embedding { document, vec: embedding.embedding, }) diff --git a/rig-core/src/vec_utils.rs b/rig-core/src/vec_utils.rs new file mode 100644 index 00000000..d10a8608 --- /dev/null +++ b/rig-core/src/vec_utils.rs @@ -0,0 +1,191 @@ +/// Struct containing either a single item or a list of items of type T. +/// If a single item is present, `first` will contain it and `rest` will be empty. +/// If multiple items are present, `first` will contain the first item and `rest` will contain the rest. +/// IMPORTANT: this struct cannot be created with an empty vector. +/// OneOrMany objects can only be created using OneOrMany::from() or OneOrMany::try_from(). +#[derive(PartialEq, Eq, Debug, Clone)] +pub struct OneOrMany<T> { + /// First item in the list. + first: T, + /// Rest of the items in the list. + rest: Vec<T>, +} + +/// Error type for when trying to create a OneOrMany object with an empty vector. +#[derive(Debug)] +pub struct EmptyListError; + +impl std::fmt::Display for EmptyListError { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "Cannot create OneOrMany with an empty vector.") + } +} +impl std::error::Error for EmptyListError {} + +impl<T: Clone> OneOrMany<T> { + /// Get the first item in the list. + pub fn first(&self) -> T { + self.first.clone() + } + + /// Get the rest of the items in the list (excluding the first one). + pub fn rest(&self) -> Vec<T> { + self.rest.clone() + } + + /// Use the Iterator trait on OneOrMany + pub fn iter(&self) -> OneOrManyIterator<T> { + OneOrManyIterator { + one_or_many: self, + index: 0, + } + } + + /// Create a OneOrMany object with a single item of any type. + pub fn one(item: T) -> Self { + OneOrMany { + first: item, + rest: vec![], + } + } + + /// Create a OneOrMany object with a vector of items of any type. + pub fn many(items: Vec<T>) -> Result<Self, EmptyListError> { + let mut iter = items.into_iter(); + Ok(OneOrMany { + first: match iter.next() { + Some(item) => item, + None => return Err(EmptyListError), + }, + rest: iter.collect(), + }) + } + + /// Merge a list of OneOrMany items into a single OneOrMany item. + pub fn merge(one_or_many_items: Vec<OneOrMany<T>>) -> Result<Self, EmptyListError> { + let items = one_or_many_items + .into_iter() + .flat_map(|one_or_many| one_or_many.into_iter()) + .collect::<Vec<_>>(); + + OneOrMany::many(items) + } +} + +/// Implement Iterator for OneOrMany. +/// Iterates over all items in both `first` and `rest`. +/// Borrows the OneOrMany object that is being iterator over. +pub struct OneOrManyIterator<'a, T> { + one_or_many: &'a OneOrMany<T>, + index: usize, +} + +impl<'a, T> Iterator for OneOrManyIterator<'a, T> { + type Item = &'a T; + + fn next(&mut self) -> Option<Self::Item> { + let mut item = None; + if self.index == 0 { + item = Some(&self.one_or_many.first) + } else if self.index - 1 < self.one_or_many.rest.len() { + item = Some(&self.one_or_many.rest[self.index - 1]); + }; + + self.index += 1; + item + } +} + +/// Implement IntoIterator for OneOrMany. +/// Iterates over all items in both `first` and `rest`. 
+/// Takes ownership the OneOrMany object that is being iterator over. +impl<T: Clone> IntoIterator for OneOrMany<T> { + type Item = T; + type IntoIter = std::iter::Chain<std::iter::Once<T>, std::vec::IntoIter<T>>; + + fn into_iter(self) -> Self::IntoIter { + std::iter::once(self.first).chain(self.rest) + } +} + +#[cfg(test)] +mod test { + use super::OneOrMany; + + #[test] + fn test_one_or_many_iter_single() { + let one_or_many = OneOrMany::one("hello".to_string()); + + assert_eq!(one_or_many.iter().count(), 1); + + one_or_many.iter().for_each(|i| { + assert_eq!(i, "hello"); + }); + } + + #[test] + fn test_one_or_many_iter() { + let one_or_many = OneOrMany::many(vec!["hello".to_string(), "word".to_string()]).unwrap(); + + assert_eq!(one_or_many.iter().count(), 2); + + one_or_many.iter().enumerate().for_each(|(i, item)| { + if i == 0 { + assert_eq!(item, "hello"); + } + if i == 1 { + assert_eq!(item, "word"); + } + }); + } + + #[test] + fn test_one_or_many_into_iter_single() { + let one_or_many = OneOrMany::one("hello".to_string()); + + assert_eq!(one_or_many.clone().into_iter().count(), 1); + + one_or_many.into_iter().for_each(|i| { + assert_eq!(i, "hello".to_string()); + }); + } + + #[test] + fn test_one_or_many_into_iter() { + let one_or_many = OneOrMany::many(vec!["hello".to_string(), "word".to_string()]).unwrap(); + + assert_eq!(one_or_many.clone().into_iter().count(), 2); + + one_or_many.into_iter().enumerate().for_each(|(i, item)| { + if i == 0 { + assert_eq!(item, "hello".to_string()); + } + if i == 1 { + assert_eq!(item, "word".to_string()); + } + }); + } + + #[test] + fn test_one_or_many_merge() { + let one_or_many_1 = OneOrMany::many(vec!["hello".to_string(), "word".to_string()]).unwrap(); + + let one_or_many_2 = OneOrMany::one("sup".to_string()); + + let merged = OneOrMany::merge(vec![one_or_many_1, one_or_many_2]).unwrap(); + + assert_eq!(merged.iter().count(), 3); + + merged.iter().enumerate().for_each(|(i, item)| { + if i == 0 { + assert_eq!(item, "hello"); + } + if i == 1 { + assert_eq!(item, "word"); + } + if i == 2 { + assert_eq!(item, "sup"); + } + }); + } +} diff --git a/rig-core/src/vector_store/in_memory_store.rs b/rig-core/src/vector_store/in_memory_store.rs index 9bbc85f1..bfe2bd29 100644 --- a/rig-core/src/vector_store/in_memory_store.rs +++ b/rig-core/src/vector_store/in_memory_store.rs @@ -8,7 +8,7 @@ use ordered_float::OrderedFloat; use serde::{Deserialize, Serialize}; use super::{VectorStoreError, VectorStoreIndex}; -use crate::embeddings::embedding::{Embedding, EmbeddingModel}; +use crate::embeddings::{Embedding, EmbeddingModel}; /// InMemoryVectorStore is a simple in-memory vector store that stores embeddings /// in-memory using a HashMap. 
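The `vec_utils` module added above is built around `OneOrMany<T>`. A minimal usage sketch (not part of the patch) of the constructors, accessors and iteration it exposes, assuming the type is reachable from calling code (it is later re-exported as `rig::OneOrMany`):

```rust
use rig::OneOrMany;

fn main() {
    // A single item: `first` holds it and `rest` stays empty.
    let one = OneOrMany::one("flurbo".to_string());
    assert_eq!(one.first(), "flurbo");
    assert!(one.rest().is_empty());

    // Building from a Vec fails up front on empty input rather than
    // allowing an element-less OneOrMany to exist.
    let many = OneOrMany::many(vec!["a".to_string(), "b".to_string()]).unwrap();
    assert!(OneOrMany::<String>::many(vec![]).is_err());

    // Iteration always yields `first`, then the items in `rest`.
    assert_eq!(many.iter().count(), 2);

    // `merge` flattens several OneOrMany values into a single one.
    let merged = OneOrMany::merge(vec![one, many]).unwrap();
    assert_eq!(merged.into_iter().count(), 3);
}
```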
diff --git a/rig-core/src/vector_store/mod.rs b/rig-core/src/vector_store/mod.rs index 6f112b81..396b5514 100644 --- a/rig-core/src/vector_store/mod.rs +++ b/rig-core/src/vector_store/mod.rs @@ -2,7 +2,7 @@ use futures::future::BoxFuture; use serde::Deserialize; use serde_json::Value; -use crate::embeddings::embedding::EmbeddingError; +use crate::embeddings::EmbeddingError; pub mod in_memory_store; diff --git a/rig-lancedb/examples/vector_search_local_ann.rs b/rig-lancedb/examples/vector_search_local_ann.rs index 8f3ceee6..6466c2b9 100644 --- a/rig-lancedb/examples/vector_search_local_ann.rs +++ b/rig-lancedb/examples/vector_search_local_ann.rs @@ -3,11 +3,10 @@ use std::{env, sync::Arc}; use arrow_array::RecordBatchIterator; use fixture::{as_record_batch, schema}; use lancedb::index::vector::IvfPqIndexBuilder; -use rig::vector_store::VectorStoreIndex; use rig::{ - embeddings::embedding::EmbeddingModel, + embeddings::{EmbeddingModel, EmbeddingsBuilder}, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, - EmbeddingsBuilder, + vector_store::VectorStoreIndex, }; use rig_lancedb::{LanceDbVectorStore, SearchParams}; use serde::Deserialize; diff --git a/rig-lancedb/examples/vector_search_local_enn.rs b/rig-lancedb/examples/vector_search_local_enn.rs index 90a14fe0..5932dcd0 100644 --- a/rig-lancedb/examples/vector_search_local_enn.rs +++ b/rig-lancedb/examples/vector_search_local_enn.rs @@ -3,10 +3,9 @@ use std::{env, sync::Arc}; use arrow_array::RecordBatchIterator; use fixture::{as_record_batch, schema}; use rig::{ - embeddings::embedding::EmbeddingModel, + embeddings::{EmbeddingModel, EmbeddingsBuilder}, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, vector_store::VectorStoreIndexDyn, - EmbeddingsBuilder, }; use rig_lancedb::{LanceDbVectorStore, SearchParams}; diff --git a/rig-lancedb/examples/vector_search_s3_ann.rs b/rig-lancedb/examples/vector_search_s3_ann.rs index 2d97ed80..70f0c8c5 100644 --- a/rig-lancedb/examples/vector_search_s3_ann.rs +++ b/rig-lancedb/examples/vector_search_s3_ann.rs @@ -4,10 +4,9 @@ use arrow_array::RecordBatchIterator; use fixture::{as_record_batch, schema}; use lancedb::{index::vector::IvfPqIndexBuilder, DistanceType}; use rig::{ - embeddings::embedding::EmbeddingModel, + embeddings::{EmbeddingModel, EmbeddingsBuilder}, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, vector_store::VectorStoreIndex, - EmbeddingsBuilder, }; use rig_lancedb::{LanceDbVectorStore, SearchParams}; use serde::Deserialize; diff --git a/rig-mongodb/examples/vector_search_mongodb.rs b/rig-mongodb/examples/vector_search_mongodb.rs index 517bfea6..caba89d8 100644 --- a/rig-mongodb/examples/vector_search_mongodb.rs +++ b/rig-mongodb/examples/vector_search_mongodb.rs @@ -3,9 +3,9 @@ use std::env; use rig::{ embeddings::builder::DocumentEmbeddings, + embeddings::EmbeddingsBuilder, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, vector_store::VectorStoreIndex, - EmbeddingsBuilder, }; use rig_mongodb::{MongoDbVectorStore, SearchParams}; diff --git a/rig-mongodb/src/lib.rs b/rig-mongodb/src/lib.rs index 17dda463..c3973092 100644 --- a/rig-mongodb/src/lib.rs +++ b/rig-mongodb/src/lib.rs @@ -2,7 +2,8 @@ use futures::StreamExt; use mongodb::bson::{self, doc}; use rig::{ - embeddings::{builder::DocumentEmbeddings, embedding::Embedding, embedding::EmbeddingModel}, + embeddings::builder::DocumentEmbeddings, + embeddings::{Embedding, EmbeddingModel}, vector_store::{VectorStoreError, VectorStoreIndex}, }; use serde::Deserialize; From c8f6646076ce27d3f7a4513ac603aba09436cd81 Mon Sep 17 
00:00:00 2001 From: Garance <garance.mary@gmail.com> Date: Thu, 17 Oct 2024 15:00:32 -0400 Subject: [PATCH 47/91] fix: fix tests failing --- rig-core/rig-core-derive/src/embeddable.rs | 6 +++--- rig-core/src/embeddings/embeddable.rs | 13 ++++++++++--- rig-core/src/lib.rs | 3 ++- rig-core/tests/embeddable_macro.rs | 8 +++++--- 4 files changed, 20 insertions(+), 10 deletions(-) diff --git a/rig-core/rig-core-derive/src/embeddable.rs b/rig-core/rig-core-derive/src/embeddable.rs index 563b8c3c..4c533d1e 100644 --- a/rig-core/rig-core-derive/src/embeddable.rs +++ b/rig-core/rig-core-derive/src/embeddable.rs @@ -48,13 +48,13 @@ pub(crate) fn expand_derive_embedding(input: &mut syn::DeriveInput) -> syn::Resu impl #impl_generics Embeddable for #name #ty_generics #where_clause { type Error = rig::embeddings::embeddable::EmbeddableError; - fn embeddable(&self) -> Result<rig::embeddings::embeddable::OneOrMany<String>, Self::Error> { + fn embeddable(&self) -> Result<rig::OneOrMany<String>, Self::Error> { #target_stream; - Ok(rig::embeddings::embeddable::OneOrMany::from( + rig::OneOrMany::merge( embed_targets.into_iter() .collect::<Result<Vec<_>, _>>()? - )) + ).map_err(rig::embeddings::embeddable::EmbeddableError::new) } } }; diff --git a/rig-core/src/embeddings/embeddable.rs b/rig-core/src/embeddings/embeddable.rs index e303ecda..accdf402 100644 --- a/rig-core/src/embeddings/embeddable.rs +++ b/rig-core/src/embeddings/embeddable.rs @@ -4,6 +4,7 @@ //! use std::env; //! //! use serde::{Deserialize, Serialize}; +//! use rig::OneOrMany; //! //! struct FakeDefinition { //! id: String, @@ -36,6 +37,12 @@ use crate::vec_utils::OneOrMany; #[error("{0}")] pub struct EmbeddableError(#[from] Box<dyn std::error::Error + Send + Sync>); +impl EmbeddableError { + pub fn new<E: std::error::Error + Send + Sync + 'static>(error: E) -> Self { + EmbeddableError(Box::new(error)) + } +} + /// Trait for types that can be embedded. /// The `embeddable` method returns a `OneOrMany<String>` which contains strings for which embeddings will be generated by the embeddings builder. /// If there is an error generating the list of strings, the method should return an error that implements `std::error::Error`. 
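The trait contract described just above boils down to a single `embeddable` method returning `OneOrMany<String>`. A hand-written implementation might look like the sketch below; the `Definition` type is hypothetical and only assumes the trait shape plus the `EmbeddableError::new` constructor introduced in this patch:

```rust
use rig::embeddings::embeddable::{Embeddable, EmbeddableError};
use rig::OneOrMany;

struct Definition {
    // Carried along as metadata; not embedded.
    word: String,
    // Each meaning becomes its own embedding target.
    meanings: Vec<String>,
}

impl Embeddable for Definition {
    type Error = EmbeddableError;

    fn embeddable(&self) -> Result<OneOrMany<String>, Self::Error> {
        // An empty list surfaces as an error instead of an invalid OneOrMany.
        OneOrMany::many(self.meanings.clone()).map_err(EmbeddableError::new)
    }
}
```

This mirrors the shape the derive macro now generates: collect the embedding targets, then `OneOrMany::merge(...).map_err(EmbeddableError::new)`.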
@@ -133,7 +140,7 @@ impl Embeddable for serde_json::Value { fn embeddable(&self) -> Result<OneOrMany<String>, Self::Error> { Ok(OneOrMany::one( - serde_json::to_string(self).map_err(|e| EmbeddableError(Box::new(e)))?, + serde_json::to_string(self).map_err(EmbeddableError::new)?, )) } } @@ -146,8 +153,8 @@ impl<T: Embeddable> Embeddable for Vec<T> { .iter() .map(|item| item.embeddable()) .collect::<Result<Vec<_>, _>>() - .map_err(|e| EmbeddableError(Box::new(e)))?; + .map_err(EmbeddableError::new)?; - OneOrMany::merge(items).map_err(|e| EmbeddableError(Box::new(e))) + OneOrMany::merge(items).map_err(EmbeddableError::new) } } diff --git a/rig-core/src/lib.rs b/rig-core/src/lib.rs index 219d9be0..6997c6ff 100644 --- a/rig-core/src/lib.rs +++ b/rig-core/src/lib.rs @@ -74,11 +74,12 @@ pub mod extractor; pub mod json_utils; pub mod providers; pub mod tool; -mod vec_utils; +pub mod vec_utils; pub mod vector_store; // Re-export commonly used types and traits pub use embeddings::embeddable::Embeddable; +pub use vec_utils::OneOrMany; #[cfg(feature = "derive")] pub use rig_derive::Embeddable; diff --git a/rig-core/tests/embeddable_macro.rs b/rig-core/tests/embeddable_macro.rs index 96e42c7f..d30c5bfa 100644 --- a/rig-core/tests/embeddable_macro.rs +++ b/rig-core/tests/embeddable_macro.rs @@ -1,10 +1,10 @@ -use rig::embeddings::embeddable::{EmbeddableError, OneOrMany}; -use rig::Embeddable; +use rig::embeddings::embeddable::EmbeddableError; +use rig::{Embeddable, OneOrMany}; use serde::Serialize; fn serialize(definition: Definition) -> Result<OneOrMany<String>, EmbeddableError> { Ok(OneOrMany::one( - serde_json::to_string(&definition).map_err(EmbeddableError::SerdeError)?, + serde_json::to_string(&definition).map_err(EmbeddableError::new)?, )) } @@ -143,6 +143,7 @@ fn test_multiple_embed_strings() { "35".to_string(), "40".to_string() ]) + .unwrap() ); assert_eq!(result.first(), "25".to_string()); @@ -181,5 +182,6 @@ fn test_multiple_embed_tags() { "35".to_string(), "40".to_string() ]) + .unwrap() ); } From 40f3c1868137e5ad52d97de544c7c22a0792f349 Mon Sep 17 00:00:00 2001 From: Garance <garance.mary@gmail.com> Date: Thu, 17 Oct 2024 15:21:55 -0400 Subject: [PATCH 48/91] test: add test on OneOrMany --- rig-core/rig-core-derive/src/basic.rs | 2 +- rig-core/rig-core-derive/src/embeddable.rs | 1 - rig-core/src/vec_utils.rs | 7 +++++++ 3 files changed, 8 insertions(+), 2 deletions(-) diff --git a/rig-core/rig-core-derive/src/basic.rs b/rig-core/rig-core-derive/src/basic.rs index 7ac5bb47..942e6c7e 100644 --- a/rig-core/rig-core-derive/src/basic.rs +++ b/rig-core/rig-core-derive/src/basic.rs @@ -19,7 +19,7 @@ pub(crate) fn basic_embed_fields(data_struct: &DataStruct) -> impl Iterator<Item }) } -// Adds bounds to where clause that force all fields tagged with #[embed] to implement the Embeddable trait. +/// Adds bounds to where clause that force all fields tagged with #[embed] to implement the Embeddable trait. pub(crate) fn add_struct_bounds(generics: &mut syn::Generics, field_type: &syn::Type) { let where_clause = generics.make_where_clause(); diff --git a/rig-core/rig-core-derive/src/embeddable.rs b/rig-core/rig-core-derive/src/embeddable.rs index 4c533d1e..13299ba6 100644 --- a/rig-core/rig-core-derive/src/embeddable.rs +++ b/rig-core/rig-core-derive/src/embeddable.rs @@ -26,7 +26,6 @@ pub(crate) fn expand_derive_embedding(input: &mut syn::DeriveInput) -> syn::Resu )); } - // Determine whether the Embeddable::Kind should be SingleEmbedding or ManyEmbedding quote! 
{ let mut embed_targets = #basic_targets; embed_targets.extend(#custom_targets) diff --git a/rig-core/src/vec_utils.rs b/rig-core/src/vec_utils.rs index d10a8608..d8435685 100644 --- a/rig-core/src/vec_utils.rs +++ b/rig-core/src/vec_utils.rs @@ -188,4 +188,11 @@ mod test { } }); } + + #[test] + fn test_one_or_many_error() { + assert!( + OneOrMany::<String>::many(vec![]).is_err() + ) + } } From 68d88b6c7cf8594d49d770a1125783592d8e80cc Mon Sep 17 00:00:00 2001 From: Garance <garance.mary@gmail.com> Date: Thu, 17 Oct 2024 15:22:14 -0400 Subject: [PATCH 49/91] style: cargo fmt --- rig-core/src/vec_utils.rs | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/rig-core/src/vec_utils.rs b/rig-core/src/vec_utils.rs index d8435685..a487e71c 100644 --- a/rig-core/src/vec_utils.rs +++ b/rig-core/src/vec_utils.rs @@ -191,8 +191,6 @@ mod test { #[test] fn test_one_or_many_error() { - assert!( - OneOrMany::<String>::many(vec![]).is_err() - ) + assert!(OneOrMany::<String>::many(vec![]).is_err()) } } From 4bc7d07e41de81c562472db848e613d2fdb0292d Mon Sep 17 00:00:00 2001 From: Garance <garance.mary@gmail.com> Date: Thu, 17 Oct 2024 18:02:11 -0400 Subject: [PATCH 50/91] docs&fix: fix doc strings, implement iter_mut for OneOrMany --- rig-core/src/embeddings/embeddable.rs | 2 +- rig-core/src/embeddings/embedding.rs | 2 +- rig-core/src/lib.rs | 8 +- rig-core/src/{vec_utils.rs => one_or_many.rs} | 143 ++++++++++++++---- 4 files changed, 118 insertions(+), 37 deletions(-) rename rig-core/src/{vec_utils.rs => one_or_many.rs} (59%) diff --git a/rig-core/src/embeddings/embeddable.rs b/rig-core/src/embeddings/embeddable.rs index accdf402..f5a69fd6 100644 --- a/rig-core/src/embeddings/embeddable.rs +++ b/rig-core/src/embeddings/embeddable.rs @@ -29,7 +29,7 @@ //! } //! ``` -use crate::vec_utils::OneOrMany; +use crate::one_or_many::OneOrMany; /// Error type used for when the `embeddable` method fails. /// Used by default implementations of `Embeddable` for common types. diff --git a/rig-core/src/embeddings/embedding.rs b/rig-core/src/embeddings/embedding.rs index 81a980ef..735284a8 100644 --- a/rig-core/src/embeddings/embedding.rs +++ b/rig-core/src/embeddings/embedding.rs @@ -1,5 +1,5 @@ //! The module defines the [EmbeddingModel] trait, which represents an embedding model that can -//! generate embeddings for documents. It also provides an implementation of the [embeddings::EmbeddingsBuilder] +//! generate embeddings for documents. It also provides an implementation of the [crate::embeddings::EmbeddingsBuilder] //! struct, which allows users to build collections of document embeddings using different embedding //! models and document sources. //! diff --git a/rig-core/src/lib.rs b/rig-core/src/lib.rs index 6997c6ff..a4850791 100644 --- a/rig-core/src/lib.rs +++ b/rig-core/src/lib.rs @@ -52,8 +52,8 @@ //! //! ## Vector stores and indexes //! Rig provides a common interface for working with vector stores and indexes. Specifically, the library -//! provides the [VectorStore](crate::vector_store::VectorStore) and [VectorStoreIndex](crate::vector_store::VectorStoreIndex) -//! traits, which can be implemented to define vector stores and indices respectively. +//! provides the [VectorStoreIndex](crate::vector_store::VectorStoreIndex) +//! trait, which can be implemented to define vector stores and indices. //! Those can then be used as the knowledgebase for a [RagAgent](crate::rag::RagAgent), or //! as a source of context documents in a custom architecture that use multiple LLMs or agents. //! 
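The `one_or_many.rs` hunks further down in this patch replace the index-based iterator with dedicated `Iter`, `IntoIter` and `IterMut` types. A short sketch of how the three access modes are intended to be used, assuming the `rig::OneOrMany` re-export:

```rust
use rig::OneOrMany;

fn main() {
    let mut greetings =
        OneOrMany::many(vec!["hello".to_string(), "hey".to_string()]).unwrap();

    // iter(): shared references, visiting `first` then `rest`.
    assert_eq!(greetings.iter().count(), 2);

    // iter_mut(): mutate the items in place.
    for item in greetings.iter_mut() {
        item.push_str(" world");
    }
    assert_eq!(greetings.first(), "hello world");

    // into_iter(): consume the container and take ownership of the items.
    let owned: Vec<String> = greetings.into_iter().collect();
    assert_eq!(owned, vec!["hello world", "hey world"]);
}
```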
@@ -72,14 +72,14 @@ pub mod completion; pub mod embeddings; pub mod extractor; pub mod json_utils; +pub mod one_or_many; pub mod providers; pub mod tool; -pub mod vec_utils; pub mod vector_store; // Re-export commonly used types and traits pub use embeddings::embeddable::Embeddable; -pub use vec_utils::OneOrMany; +pub use one_or_many::OneOrMany; #[cfg(feature = "derive")] pub use rig_derive::Embeddable; diff --git a/rig-core/src/vec_utils.rs b/rig-core/src/one_or_many.rs similarity index 59% rename from rig-core/src/vec_utils.rs rename to rig-core/src/one_or_many.rs index a487e71c..23ece94f 100644 --- a/rig-core/src/vec_utils.rs +++ b/rig-core/src/one_or_many.rs @@ -33,14 +33,6 @@ impl<T: Clone> OneOrMany<T> { self.rest.clone() } - /// Use the Iterator trait on OneOrMany - pub fn iter(&self) -> OneOrManyIterator<T> { - OneOrManyIterator { - one_or_many: self, - index: 0, - } - } - /// Create a OneOrMany object with a single item of any type. pub fn one(item: T) -> Self { OneOrMany { @@ -70,41 +62,102 @@ impl<T: Clone> OneOrMany<T> { OneOrMany::many(items) } + + pub fn iter(&self) -> Iter<T> { + Iter { + first: Some(&self.first), + rest: self.rest.iter(), + } + } + + pub fn iter_mut(&mut self) -> IterMut<'_, T> { + IterMut { + first: Some(&mut self.first), + rest: self.rest.iter_mut(), + } + } } -/// Implement Iterator for OneOrMany. -/// Iterates over all items in both `first` and `rest`. -/// Borrows the OneOrMany object that is being iterator over. -pub struct OneOrManyIterator<'a, T> { - one_or_many: &'a OneOrMany<T>, - index: usize, +// ================================================================ +// Implementations of Iterator for OneOrMany +// - OneOrMany<T>::iter() -> iterate over references of T objects +// - OneOrMany<T>::into_iter() -> iterate over owned T objects +// - OneOrMany<T>::iter_mut() -> iterate over mutable references of T objects +// ================================================================ + +/// Struct returned by call to `OneOrMany::iter()`. +pub struct Iter<'a, T> { + // References. + first: Option<&'a T>, + rest: std::slice::Iter<'a, T>, } -impl<'a, T> Iterator for OneOrManyIterator<'a, T> { +/// Implement `Iterator` for `Iter<T>`. +/// The Item type of the `Iterator` trait is a reference of `T`. +impl<'a, T> Iterator for Iter<'a, T> { type Item = &'a T; fn next(&mut self) -> Option<Self::Item> { - let mut item = None; - if self.index == 0 { - item = Some(&self.one_or_many.first) - } else if self.index - 1 < self.one_or_many.rest.len() { - item = Some(&self.one_or_many.rest[self.index - 1]); - }; - - self.index += 1; - item + if let Some(first) = self.first.take() { + Some(first) + } else { + self.rest.next() + } } } -/// Implement IntoIterator for OneOrMany. -/// Iterates over all items in both `first` and `rest`. -/// Takes ownership the OneOrMany object that is being iterator over. +/// Struct returned by call to `OneOrMany::into_iter()`. +pub struct IntoIter<T> { + // Owned. + first: Option<T>, + rest: std::vec::IntoIter<T>, +} + +/// Implement `Iterator` for `IntoIter<T>`. impl<T: Clone> IntoIterator for OneOrMany<T> { type Item = T; - type IntoIter = std::iter::Chain<std::iter::Once<T>, std::vec::IntoIter<T>>; + type IntoIter = IntoIter<T>; fn into_iter(self) -> Self::IntoIter { - std::iter::once(self.first).chain(self.rest) + IntoIter { + first: Some(self.first), + rest: self.rest.into_iter(), + } + } +} + +/// Implement `Iterator` for `IntoIter<T>`. +/// The Item type of the `Iterator` trait is an owned `T`. 
+impl<T: Clone> Iterator for IntoIter<T> { + type Item = T; + + fn next(&mut self) -> Option<Self::Item> { + if let Some(first) = self.first.take() { + Some(first) + } else { + self.rest.next() + } + } +} + +/// Struct returned by call to `OneOrMany::iter_mut()`. +pub struct IterMut<'a, T> { + // Mutable references. + first: Option<&'a mut T>, + rest: std::slice::IterMut<'a, T>, +} + +// Implement `Iterator` for `IterMut<T>`. +// The Item type of the `Iterator` trait is a mutable reference of `OneOrMany<T>`. +impl<'a, T> Iterator for IterMut<'a, T> { + type Item = &'a mut T; + + fn next(&mut self) -> Option<Self::Item> { + if let Some(first) = self.first.take() { + Some(first) + } else { + self.rest.next() + } } } @@ -113,7 +166,7 @@ mod test { use super::OneOrMany; #[test] - fn test_one_or_many_iter_single() { + fn test_single() { let one_or_many = OneOrMany::one("hello".to_string()); assert_eq!(one_or_many.iter().count(), 1); @@ -124,7 +177,7 @@ mod test { } #[test] - fn test_one_or_many_iter() { + fn test() { let one_or_many = OneOrMany::many(vec!["hello".to_string(), "word".to_string()]).unwrap(); assert_eq!(one_or_many.iter().count(), 2); @@ -189,6 +242,34 @@ mod test { }); } + #[test] + fn test_mut_single() { + let mut one_or_many = OneOrMany::one("hello".to_string()); + + assert_eq!(one_or_many.iter_mut().count(), 1); + + one_or_many.iter_mut().for_each(|i| { + assert_eq!(i, "hello"); + }); + } + + #[test] + fn test_mut() { + let mut one_or_many = + OneOrMany::many(vec!["hello".to_string(), "word".to_string()]).unwrap(); + + assert_eq!(one_or_many.iter_mut().count(), 2); + + one_or_many.iter_mut().enumerate().for_each(|(i, item)| { + if i == 0 { + assert_eq!(item, "hello"); + } + if i == 1 { + assert_eq!(item, "word"); + } + }); + } + #[test] fn test_one_or_many_error() { assert!(OneOrMany::<String>::many(vec![]).is_err()) From 4d2ffdb0f8a8c6e425a99028c62816f7bbd5ff3f Mon Sep 17 00:00:00 2001 From: Garance <garance.mary@gmail.com> Date: Fri, 18 Oct 2024 09:57:17 -0400 Subject: [PATCH 51/91] fix: update borrow and owning of macro --- rig-core/rig-core-derive/src/basic.rs | 22 +++++++++------------- rig-core/rig-core-derive/src/custom.rs | 10 ++++------ rig-core/rig-core-derive/src/embeddable.rs | 4 ++-- rig-core/src/one_or_many.rs | 3 ++- rig-core/tests/embeddable_macro.rs | 19 ------------------- 5 files changed, 17 insertions(+), 41 deletions(-) diff --git a/rig-core/rig-core-derive/src/basic.rs b/rig-core/rig-core-derive/src/basic.rs index 942e6c7e..86bb13ad 100644 --- a/rig-core/rig-core-derive/src/basic.rs +++ b/rig-core/rig-core-derive/src/basic.rs @@ -3,19 +3,15 @@ use syn::{parse_quote, Attribute, DataStruct, Meta}; use crate::EMBED; /// Finds and returns fields with simple #[embed] attribute tags only. -pub(crate) fn basic_embed_fields(data_struct: &DataStruct) -> impl Iterator<Item = syn::Field> { - data_struct.fields.clone().into_iter().filter(|field| { - field - .attrs - .clone() - .into_iter() - .any(|attribute| match attribute { - Attribute { - meta: Meta::Path(path), - .. - } => path.is_ident(EMBED), - _ => false, - }) +pub(crate) fn basic_embed_fields(data_struct: &DataStruct) -> impl Iterator<Item = &syn::Field> { + data_struct.fields.iter().filter(|field| { + field.attrs.iter().any(|attribute| match attribute { + Attribute { + meta: Meta::Path(path), + .. 
+ } => path.is_ident(EMBED), + _ => false, + }) }) } diff --git a/rig-core/rig-core-derive/src/custom.rs b/rig-core/rig-core-derive/src/custom.rs index de754372..f29ed20e 100644 --- a/rig-core/rig-core-derive/src/custom.rs +++ b/rig-core/rig-core-derive/src/custom.rs @@ -9,19 +9,17 @@ const EMBED_WITH: &str = "embed_with"; /// Also returns the "..." part of the tag (ie. the custom function). pub(crate) fn custom_embed_fields( data_struct: &syn::DataStruct, -) -> syn::Result<Vec<(syn::Field, syn::ExprPath)>> { +) -> syn::Result<Vec<(&syn::Field, syn::ExprPath)>> { data_struct .fields - .clone() - .into_iter() + .iter() .filter_map(|field| { field .attrs - .clone() - .into_iter() + .iter() .filter_map(|attribute| match attribute.is_custom() { Ok(true) => match attribute.expand_tag() { - Ok(path) => Some(Ok((field.clone(), path))), + Ok(path) => Some(Ok((field, path))), Err(e) => Some(Err(e)), }, Ok(false) => None, diff --git a/rig-core/rig-core-derive/src/embeddable.rs b/rig-core/rig-core-derive/src/embeddable.rs index 13299ba6..8336c9e0 100644 --- a/rig-core/rig-core-derive/src/embeddable.rs +++ b/rig-core/rig-core-derive/src/embeddable.rs @@ -76,7 +76,7 @@ impl StructParser for DataStruct { .map(|field| { add_struct_bounds(generics, &field.ty); - let field_name = field.ident; + let field_name = &field.ident; quote! { self.#field_name @@ -106,7 +106,7 @@ impl StructParser for DataStruct { // Iterate over every field tagged with #[embed(embed_with = "...")] .into_iter() .map(|(field, custom_func_path)| { - let field_name = field.ident; + let field_name = &field.ident; quote! { #custom_func_path(self.#field_name.clone()) diff --git a/rig-core/src/one_or_many.rs b/rig-core/src/one_or_many.rs index 23ece94f..08165873 100644 --- a/rig-core/src/one_or_many.rs +++ b/rig-core/src/one_or_many.rs @@ -262,7 +262,8 @@ mod test { one_or_many.iter_mut().enumerate().for_each(|(i, item)| { if i == 0 { - assert_eq!(item, "hello"); + item.push_str(" world"); + assert_eq!(item, "hello world"); } if i == 1 { assert_eq!(item, "word"); diff --git a/rig-core/tests/embeddable_macro.rs b/rig-core/tests/embeddable_macro.rs index d30c5bfa..d3f93966 100644 --- a/rig-core/tests/embeddable_macro.rs +++ b/rig-core/tests/embeddable_macro.rs @@ -35,11 +35,6 @@ fn test_custom_embed() { }, }; - println!( - "FakeDefinition: {}, {}", - fake_definition.id, fake_definition.word - ); - assert_eq!( fake_definition.embeddable().unwrap(), OneOrMany::one( @@ -70,11 +65,6 @@ fn test_custom_and_basic_embed() { }, }; - println!( - "FakeDefinition: {}, {}", - fake_definition.id, fake_definition.word - ); - assert_eq!( fake_definition.embeddable().unwrap().first(), "house".to_string() @@ -104,11 +94,6 @@ fn test_single_embed() { definition: definition.clone(), }; - println!( - "FakeDefinition3: {}, {}", - fake_definition.id, fake_definition.word - ); - assert_eq!( fake_definition.embeddable().unwrap(), OneOrMany::one(definition) @@ -131,8 +116,6 @@ fn test_multiple_embed_strings() { employee_ages: vec![25, 30, 35, 40], }; - println!("Company: {}, {}", company.id, company.company); - let result = company.embeddable().unwrap(); assert_eq!( @@ -171,8 +154,6 @@ fn test_multiple_embed_tags() { employee_ages: vec![25, 30, 35, 40], }; - println!("Company2: {}", company.id); - assert_eq!( company.embeddable().unwrap(), OneOrMany::many(vec![ From 6f0422567d1d206e1cf50f8649f68043dd720d73 Mon Sep 17 00:00:00 2001 From: Garance <garance.mary@gmail.com> Date: Fri, 18 Oct 2024 10:24:04 -0400 Subject: [PATCH 52/91] clippy: add back print statements --- 
rig-core/tests/embeddable_macro.rs | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/rig-core/tests/embeddable_macro.rs b/rig-core/tests/embeddable_macro.rs index d3f93966..cbc76c80 100644 --- a/rig-core/tests/embeddable_macro.rs +++ b/rig-core/tests/embeddable_macro.rs @@ -35,6 +35,11 @@ fn test_custom_embed() { }, }; + println!( + "FakeDefinition: {}, {}", + fake_definition.id, fake_definition.word + ); + assert_eq!( fake_definition.embeddable().unwrap(), OneOrMany::one( @@ -65,6 +70,11 @@ fn test_custom_and_basic_embed() { }, }; + println!( + "FakeDefinition: {}, {}", + fake_definition.id, fake_definition.word + ); + assert_eq!( fake_definition.embeddable().unwrap().first(), "house".to_string() @@ -93,6 +103,10 @@ fn test_single_embed() { word: "house".to_string(), definition: definition.clone(), }; + println!( + "FakeDefinition3: {}, {}", + fake_definition.id, fake_definition.word + ); assert_eq!( fake_definition.embeddable().unwrap(), @@ -116,6 +130,8 @@ fn test_multiple_embed_strings() { employee_ages: vec![25, 30, 35, 40], }; + println!("Company: {}, {}", company.id, company.company); + let result = company.embeddable().unwrap(); assert_eq!( @@ -154,6 +170,8 @@ fn test_multiple_embed_tags() { employee_ages: vec![25, 30, 35, 40], }; + println!("Company: {}", company.id); + assert_eq!( company.embeddable().unwrap(), OneOrMany::many(vec![ From 5897e22bd2f91d7afff6806c28ec2d27b9d47961 Mon Sep 17 00:00:00 2001 From: Garance <garance.mary@gmail.com> Date: Fri, 18 Oct 2024 11:47:14 -0400 Subject: [PATCH 53/91] fix: fix issues caused by merge of derive macro branch --- rig-core/Cargo.toml | 14 +- rig-core/examples/calculator_chatbot.rs | 13 +- rig-core/examples/rag.rs | 54 +++- rig-core/examples/rag_dynamic_tools.rs | 9 +- rig-core/examples/vector_search.rs | 62 +++- rig-core/examples/vector_search_cohere.rs | 65 +++- rig-core/src/embeddings/builder.rs | 290 +++++++----------- rig-core/src/embeddings/mod.rs | 1 + rig-core/src/embeddings/tool.rs | 18 +- rig-core/src/one_or_many.rs | 5 + rig-core/src/providers/cohere.rs | 4 +- rig-core/src/providers/openai.rs | 4 +- rig-core/src/tool.rs | 18 +- rig-lancedb/Cargo.toml | 1 + rig-lancedb/examples/fixtures/lib.rs | 56 ++-- .../examples/vector_search_local_ann.rs | 35 +-- .../examples/vector_search_local_enn.rs | 9 +- rig-lancedb/examples/vector_search_s3_ann.rs | 33 +- rig-mongodb/Cargo.toml | 2 + rig-mongodb/examples/vector_search_mongodb.rs | 83 ++++- rig-mongodb/src/lib.rs | 40 ++- 21 files changed, 460 insertions(+), 356 deletions(-) diff --git a/rig-core/Cargo.toml b/rig-core/Cargo.toml index bc65409a..ea910406 100644 --- a/rig-core/Cargo.toml +++ b/rig-core/Cargo.toml @@ -35,4 +35,16 @@ derive = ["dep:rig-derive"] [[test]] name = "embeddable_macro" -required-features = ["derive"] \ No newline at end of file +required-features = ["derive"] + +[[example]] +name = "rag" +required-features = ["derive"] + +[[example]] +name = "vector_search" +required-features = ["derive"] + +[[example]] +name = "vector_search_cohere" +required-features = ["derive"] \ No newline at end of file diff --git a/rig-core/examples/calculator_chatbot.rs b/rig-core/examples/calculator_chatbot.rs index 4c8601d6..30c3741f 100644 --- a/rig-core/examples/calculator_chatbot.rs +++ b/rig-core/examples/calculator_chatbot.rs @@ -2,8 +2,7 @@ use anyhow::Result; use rig::{ cli_chatbot::cli_chatbot, completion::ToolDefinition, - embeddings::builder::DocumentEmbeddings, - embeddings::EmbeddingsBuilder, + embeddings::builder::EmbeddingsBuilder, 
providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, tool::{Tool, ToolEmbedding, ToolSet}, vector_store::in_memory_store::InMemoryVectorStore, @@ -26,7 +25,7 @@ struct MathError; #[error("Init error")] struct InitError; -#[derive(Deserialize, Serialize)] +#[derive(Deserialize, Serialize, Clone)] struct Add; impl Tool for Add { const NAME: &'static str = "add"; @@ -78,7 +77,7 @@ impl ToolEmbedding for Add { fn context(&self) -> Self::Context {} } -#[derive(Deserialize, Serialize)] +#[derive(Deserialize, Serialize, Clone)] struct Subtract; impl Tool for Subtract { const NAME: &'static str = "subtract"; @@ -248,7 +247,7 @@ async fn main() -> Result<(), anyhow::Error> { let embedding_model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); let embeddings = EmbeddingsBuilder::new(embedding_model.clone()) - .tools(&toolset)? + .documents(toolset.embedabble_tools()?)? .build() .await?; @@ -257,7 +256,7 @@ async fn main() -> Result<(), anyhow::Error> { embeddings .into_iter() .enumerate() - .map(|(i, (tool, embedding))| (i.to_string(), tool, embedding)) + .map(|(i, (tool, embedding))| (i.to_string(), tool, vec![embedding])) .collect(), )? .index(embedding_model); @@ -287,4 +286,4 @@ async fn main() -> Result<(), anyhow::Error> { cli_chatbot(calculator_rag).await?; Ok(()) -} +} \ No newline at end of file diff --git a/rig-core/examples/rag.rs b/rig-core/examples/rag.rs index 674d028c..4cf6b05a 100644 --- a/rig-core/examples/rag.rs +++ b/rig-core/examples/rag.rs @@ -1,12 +1,22 @@ -use std::env; +use std::{env, vec}; use rig::{ completion::Prompt, - embeddings::builder::DocumentEmbeddings, - embeddings::EmbeddingsBuilder, + embeddings::builder::EmbeddingsBuilder, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, vector_store::in_memory_store::InMemoryVectorStore, + Embeddable, }; +use serde::Serialize; + +// Shape of data that needs to be RAG'ed. +// The definition field will be used to generate embeddings. 
+#[derive(Embeddable, Clone, Debug, Serialize, Eq, PartialEq, Default)] +struct FakeDefinition { + id: String, + #[embed] + definitions: Vec<String>, +} #[tokio::main] async fn main() -> Result<(), anyhow::Error> { @@ -17,9 +27,29 @@ async fn main() -> Result<(), anyhow::Error> { let embedding_model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); let embeddings = EmbeddingsBuilder::new(embedding_model.clone()) - .simple_document("doc0", "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets") - .simple_document("doc1", "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.") - .simple_document("doc2", "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.") + .documents(vec![ + FakeDefinition { + id: "doc0".to_string(), + definitions: vec![ + "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets.".to_string(), + "Definition of a *flurbo*: A fictional digital currency that originated in the animated series Rick and Morty.".to_string() + ] + }, + FakeDefinition { + id: "doc1".to_string(), + definitions: vec![ + "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string(), + "Definition of a *glarb-glarb*: A fictional creature found in the distant, swampy marshlands of the planet Glibbo in the Andromeda galaxy.".to_string() + ] + }, + FakeDefinition { + id: "doc2".to_string(), + definitions: vec![ + "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.".to_string(), + "Definition of a *linglingdong*: A rare, mystical instrument crafted by the ancient monks of the Nebulon Mountain Ranges on the planet Quarm.".to_string() + ] + }, + ])? .build() .await?; @@ -27,13 +57,9 @@ async fn main() -> Result<(), anyhow::Error> { .add_documents( embeddings .into_iter() - .map( - |DocumentEmbeddings { - id, - document, - embeddings, - }| { (id, document, embeddings) }, - ) + .map(|(fake_definition, embedding_vec)| { + (fake_definition.id.clone(), fake_definition, embedding_vec) + }) .collect(), )? .index(embedding_model); @@ -52,4 +78,4 @@ async fn main() -> Result<(), anyhow::Error> { println!("{}", response); Ok(()) -} +} \ No newline at end of file diff --git a/rig-core/examples/rag_dynamic_tools.rs b/rig-core/examples/rag_dynamic_tools.rs index 97938614..0ddc210b 100644 --- a/rig-core/examples/rag_dynamic_tools.rs +++ b/rig-core/examples/rag_dynamic_tools.rs @@ -1,8 +1,7 @@ use anyhow::Result; use rig::{ completion::{Prompt, ToolDefinition}, - embeddings::builder::DocumentEmbeddings, - embeddings::EmbeddingsBuilder, + embeddings::builder::EmbeddingsBuilder, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, tool::{Tool, ToolEmbedding, ToolSet}, vector_store::in_memory_store::InMemoryVectorStore, @@ -157,7 +156,7 @@ async fn main() -> Result<(), anyhow::Error> { .build(); let embeddings = EmbeddingsBuilder::new(embedding_model.clone()) - .tools(&toolset)? + .documents(toolset.embedabble_tools()?)? .build() .await?; @@ -166,7 +165,7 @@ async fn main() -> Result<(), anyhow::Error> { embeddings .into_iter() .enumerate() - .map(|(i, (tool, embedding))| (i.to_string(), tool, embedding)) + .map(|(i, (tool, embedding))| (i.to_string(), tool, vec![embedding])) .collect(), )? 
.index(embedding_model); @@ -185,4 +184,4 @@ async fn main() -> Result<(), anyhow::Error> { println!("{}", response); Ok(()) -} +} \ No newline at end of file diff --git a/rig-core/examples/vector_search.rs b/rig-core/examples/vector_search.rs index feadbfb4..0766c99a 100644 --- a/rig-core/examples/vector_search.rs +++ b/rig-core/examples/vector_search.rs @@ -1,11 +1,22 @@ use std::env; use rig::{ - embeddings::builder::DocumentEmbeddings, - embeddings::EmbeddingsBuilder, + embeddings::builder::EmbeddingsBuilder, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, vector_store::{in_memory_store::InMemoryVectorStore, VectorStoreIndex}, + Embeddable, }; +use serde::{Deserialize, Serialize}; + +// Shape of data that needs to be RAG'ed. +// The definition field will be used to generate embeddings. +#[derive(Embeddable, Clone, Deserialize, Debug, Serialize, Eq, PartialEq, Default)] +struct FakeDefinition { + id: String, + word: String, + #[embed] + definitions: Vec<String>, +} #[tokio::main] async fn main() -> Result<(), anyhow::Error> { @@ -16,9 +27,32 @@ async fn main() -> Result<(), anyhow::Error> { let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); let embeddings = EmbeddingsBuilder::new(model.clone()) - .simple_document("doc0", "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets") - .simple_document("doc1", "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.") - .simple_document("doc2", "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.") + .documents(vec![ + FakeDefinition { + id: "doc0".to_string(), + word: "flurbo".to_string(), + definitions: vec![ + "A green alien that lives on cold planets.".to_string(), + "A fictional digital currency that originated in the animated series Rick and Morty.".to_string() + ] + }, + FakeDefinition { + id: "doc1".to_string(), + word: "glarb-glarb".to_string(), + definitions: vec![ + "An ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string(), + "A fictional creature found in the distant, swampy marshlands of the planet Glibbo in the Andromeda galaxy.".to_string() + ] + }, + FakeDefinition { + id: "doc2".to_string(), + word: "linglingdong".to_string(), + definitions: vec![ + "A term used by inhabitants of the sombrero galaxy to describe humans.".to_string(), + "A rare, mystical instrument crafted by the ancient monks of the Nebulon Mountain Ranges on the planet Quarm.".to_string() + ] + }, + ])? .build() .await?; @@ -26,28 +60,24 @@ async fn main() -> Result<(), anyhow::Error> { .add_documents( embeddings .into_iter() - .map( - |DocumentEmbeddings { - id, - document, - embeddings, - }| { (id, document, embeddings) }, - ) + .map(|(fake_definition, embedding_vec)| { + (fake_definition.id.clone(), fake_definition, embedding_vec) + }) .collect(), )? .index(model); let results = index - .top_n::<String>("What is a linglingdong?", 1) + .top_n::<FakeDefinition>("I need to buy something in a fictional universe. What type of money can I use for this?", 1) .await? .into_iter() - .map(|(score, id, doc)| (score, id, doc)) + .map(|(score, id, doc)| (score, id, doc.word)) .collect::<Vec<_>>(); println!("Results: {:?}", results); let id_results = index - .top_n_ids("What is a linglingdong?", 1) + .top_n_ids("I need to buy something in a fictional universe. What type of money can I use for this?", 1) .await? 
.into_iter() .map(|(score, id)| (score, id)) @@ -56,4 +86,4 @@ async fn main() -> Result<(), anyhow::Error> { println!("ID results: {:?}", id_results); Ok(()) -} +} \ No newline at end of file diff --git a/rig-core/examples/vector_search_cohere.rs b/rig-core/examples/vector_search_cohere.rs index 6b93bcdd..6524ceab 100644 --- a/rig-core/examples/vector_search_cohere.rs +++ b/rig-core/examples/vector_search_cohere.rs @@ -1,11 +1,22 @@ use std::env; use rig::{ - embeddings::builder::DocumentEmbeddings, - embeddings::EmbeddingsBuilder, + embeddings::builder::EmbeddingsBuilder, providers::cohere::{Client, EMBED_ENGLISH_V3}, vector_store::{in_memory_store::InMemoryVectorStore, VectorStoreIndex}, + Embeddable, }; +use serde::{Deserialize, Serialize}; + +// Shape of data that needs to be RAG'ed. +// The definition field will be used to generate embeddings. +#[derive(Embeddable, Clone, Deserialize, Debug, Serialize, Eq, PartialEq, Default)] +struct FakeDefinition { + id: String, + word: String, + #[embed] + definitions: Vec<String>, +} #[tokio::main] async fn main() -> Result<(), anyhow::Error> { @@ -16,10 +27,33 @@ async fn main() -> Result<(), anyhow::Error> { let document_model = cohere_client.embedding_model(EMBED_ENGLISH_V3, "search_document"); let search_model = cohere_client.embedding_model(EMBED_ENGLISH_V3, "search_query"); - let embeddings = EmbeddingsBuilder::new(document_model) - .simple_document("doc0", "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets") - .simple_document("doc1", "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.") - .simple_document("doc2", "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.") + let embeddings = EmbeddingsBuilder::new(document_model.clone()) + .documents(vec![ + FakeDefinition { + id: "doc0".to_string(), + word: "flurbo".to_string(), + definitions: vec![ + "A green alien that lives on cold planets.".to_string(), + "A fictional digital currency that originated in the animated series Rick and Morty.".to_string() + ] + }, + FakeDefinition { + id: "doc1".to_string(), + word: "glarb-glarb".to_string(), + definitions: vec![ + "An ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string(), + "A fictional creature found in the distant, swampy marshlands of the planet Glibbo in the Andromeda galaxy.".to_string() + ] + }, + FakeDefinition { + id: "doc2".to_string(), + word: "linglingdong".to_string(), + definitions: vec![ + "A term used by inhabitants of the sombrero galaxy to describe humans.".to_string(), + "A rare, mystical instrument crafted by the ancient monks of the Nebulon Mountain Ranges on the planet Quarm.".to_string() + ] + }, + ])? .build() .await?; @@ -27,25 +61,24 @@ async fn main() -> Result<(), anyhow::Error> { .add_documents( embeddings .into_iter() - .map( - |DocumentEmbeddings { - id, - document, - embeddings, - }| { (id, document, embeddings) }, - ) + .map(|(fake_definition, embedding_vec)| { + (fake_definition.id.clone(), fake_definition, embedding_vec) + }) .collect(), )? .index(search_model); let results = index - .top_n::<String>("What is a linglingdong?", 1) + .top_n::<FakeDefinition>( + "Which instrument is found in the Nebulon Mountain Ranges?", + 1, + ) .await? 
.into_iter() - .map(|(score, id, doc)| (score, id, doc)) + .map(|(score, id, doc)| (score, id, doc.word)) .collect::<Vec<_>>(); println!("Results: {:?}", results); Ok(()) -} +} \ No newline at end of file diff --git a/rig-core/src/embeddings/builder.rs b/rig-core/src/embeddings/builder.rs index 4582f1e3..d7024e8e 100644 --- a/rig-core/src/embeddings/builder.rs +++ b/rig-core/src/embeddings/builder.rs @@ -1,25 +1,63 @@ -//! The module provides an implementation of the [EmbeddingsBuilder] -//! struct, which allows users to build collections of document embeddings using different embedding -//! models and document sources. +//! The module defines the [EmbeddingsBuilder] struct which accumulates objects to be embedded and generates the embeddings for each object when built. +//! Only types that implement the [Embeddable] trait can be added to the [EmbeddingsBuilder]. //! //! # Example //! ```rust -//! use rig::providers::openai::{Client, self}; -//! use rig::embeddings::{EmbeddingModel, EmbeddingsBuilder}; +//! use std::env; //! -//! // Initialize the OpenAI client -//! let openai = Client::new("your-openai-api-key"); +//! use rig::{ +//! embeddings::builder::EmbeddingsBuilder, +//! providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, +//! vector_store::{in_memory_store::InMemoryVectorStore, VectorStoreIndex}, +//! Embeddable, +//! }; +//! use serde::{Deserialize, Serialize}; //! -//! // Create an instance of the `text-embedding-ada-002` model -//! let embedding_model = openai.embedding_model(openai::TEXT_EMBEDDING_ADA_002); +//! // Shape of data that needs to be RAG'ed. +//! // The definition field will be used to generate embeddings. +//! #[derive(Embeddable, Clone, Deserialize, Debug, Serialize, Eq, PartialEq, Default)] +//! struct FakeDefinition { +//! id: String, +//! word: String, +//! #[embed] +//! definitions: Vec<String>, +//! } //! -//! // Create an embeddings builder and add documents -//! let embeddings = EmbeddingsBuilder::new(embedding_model) -//! .simple_document("doc1", "This is the first document.") -//! .simple_document("doc2", "This is the second document.") +//! // Create OpenAI client +//! let openai_api_key = env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set"); +//! let openai_client = Client::new(&openai_api_key); +//! +//! let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); +//! +//! let embeddings = EmbeddingsBuilder::new(model.clone()) +//! .documents(vec![ +//! FakeDefinition { +//! id: "doc0".to_string(), +//! word: "flurbo".to_string(), +//! definitions: vec![ +//! "A green alien that lives on cold planets.".to_string(), +//! "A fictional digital currency that originated in the animated series Rick and Morty.".to_string() +//! ] +//! }, +//! FakeDefinition { +//! id: "doc1".to_string(), +//! word: "glarb-glarb".to_string(), +//! definitions: vec![ +//! "An ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string(), +//! "A fictional creature found in the distant, swampy marshlands of the planet Glibbo in the Andromeda galaxy.".to_string() +//! ] +//! }, +//! FakeDefinition { +//! id: "doc2".to_string(), +//! word: "linglingdong".to_string(), +//! definitions: vec![ +//! "A term used by inhabitants of the sombrero galaxy to describe humans.".to_string(), +//! "A rare, mystical instrument crafted by the ancient monks of the Nebulon Mountain Ranges on the planet Quarm.".to_string() +//! ] +//! }, +//! ])? //! .build() -//! .await -//! .expect("Failed to build embeddings."); +//! .await?; //! //! 
// Use the generated embeddings //! // ... @@ -28,39 +66,19 @@ use std::{cmp::max, collections::HashMap}; use futures::{stream, StreamExt, TryStreamExt}; -use serde::{Deserialize, Serialize}; use crate::{ - embeddings::{Embedding, EmbeddingError, EmbeddingModel}, - tool::{ToolEmbedding, ToolSet, ToolType}, + embeddings::{Embeddable, Embedding, EmbeddingError, EmbeddingModel}, + OneOrMany, }; -/// Struct that holds a document and its embeddings. -/// -/// The struct is designed to model any kind of documents that can be serialized to JSON -/// (including a simple string). -/// -/// Moreover, it can hold multiple embeddings for the same document, thus allowing a -/// large document to be retrieved from a query that matches multiple smaller and -/// distinct text documents. For example, if the document is a textbook, a summary of -/// each chapter could serve as the book's embeddings. -#[derive(Clone, Eq, PartialEq, Serialize, Deserialize)] -pub struct DocumentEmbeddings { - #[serde(rename = "_id")] - pub id: String, - pub document: serde_json::Value, - pub embeddings: Vec<Embedding>, -} - -type Embeddings = Vec<DocumentEmbeddings>; - -/// Builder for creating a collection of embeddings -pub struct EmbeddingsBuilder<M: EmbeddingModel> { +/// Builder for creating a collection of embeddings. +pub struct EmbeddingsBuilder<M: EmbeddingModel, D: Embeddable> { model: M, - documents: Vec<(String, serde_json::Value, Vec<String>)>, + documents: Vec<(D, OneOrMany<String>)>, } -impl<M: EmbeddingModel> EmbeddingsBuilder<M> { +impl<M: EmbeddingModel, D: Embeddable> EmbeddingsBuilder<M, D> { /// Create a new embedding builder with the given embedding model pub fn new(model: M) -> Self { Self { @@ -69,169 +87,81 @@ impl<M: EmbeddingModel> EmbeddingsBuilder<M> { } } - /// Add a simple document to the embedding collection. - /// The provided document string will be used for the embedding. - pub fn simple_document(mut self, id: &str, document: &str) -> Self { - self.documents.push(( - id.to_string(), - serde_json::Value::String(document.to_string()), - vec![document.to_string()], - )); - self - } + /// Add a document that implements `Embeddable` to the builder. + pub fn document(mut self, document: D) -> Result<Self, D::Error> { + let embed_targets = document.embeddable()?; - /// Add multiple documents to the embedding collection. - /// Each element of the vector is a tuple of the form (id, document). - pub fn simple_documents(mut self, documents: Vec<(String, String)>) -> Self { - self.documents - .extend(documents.into_iter().map(|(id, document)| { - ( - id, - serde_json::Value::String(document.clone()), - vec![document], - ) - })); - self - } - - /// Add a tool to the embedding collection. - /// The `tool.context()` corresponds to the document being stored while - /// `tool.embedding_docs()` corresponds to the documents that will be used to generate the embeddings. - pub fn tool(mut self, tool: impl ToolEmbedding + 'static) -> Result<Self, EmbeddingError> { - self.documents.push(( - tool.name(), - serde_json::to_value(tool.context())?, - tool.embedding_docs(), - )); - Ok(self) - } - - /// Add the tools from the given toolset to the embedding collection. 
- pub fn tools(mut self, toolset: &ToolSet) -> Result<Self, EmbeddingError> { - for (name, tool) in toolset.tools.iter() { - if let ToolType::Embedding(tool) = tool { - self.documents.push(( - name.clone(), - tool.context().map_err(|e| { - EmbeddingError::DocumentError(format!( - "Failed to generate context for tool {}: {}", - name, e - )) - })?, - tool.embedding_docs(), - )); - } - } + self.documents.push((document, embed_targets)); Ok(self) } - /// Add a document to the embedding collection. - /// `embed_documents` are the documents that will be used to generate the embeddings - /// for `document`. - pub fn document<T: Serialize>( - mut self, - id: &str, - document: T, - embed_documents: Vec<String>, - ) -> Self { - self.documents.push(( - id.to_string(), - serde_json::to_value(document).expect("Document should serialize"), - embed_documents, - )); - self - } - - /// Add multiple documents to the embedding collection. - /// Each element of the vector is a tuple of the form (id, document, embed_documents). - pub fn documents<T: Serialize>(mut self, documents: Vec<(String, T, Vec<String>)>) -> Self { - self.documents.extend( - documents - .into_iter() - .map(|(id, document, embed_documents)| { - ( - id, - serde_json::to_value(document).expect("Document should serialize"), - embed_documents, - ) - }), - ); - self - } + /// Add many documents that implement `Embeddable` to the builder. + pub fn documents(mut self, documents: Vec<D>) -> Result<Self, D::Error> { + for doc in documents.into_iter() { + let embed_targets = doc.embeddable()?; - /// Add a json document to the embedding collection. - pub fn json_document( - mut self, - id: &str, - document: serde_json::Value, - embed_documents: Vec<String>, - ) -> Self { - self.documents - .push((id.to_string(), document, embed_documents)); - self - } + self.documents.push((doc, embed_targets)); + } - /// Add multiple json documents to the embedding collection. - pub fn json_documents( - mut self, - documents: Vec<(String, serde_json::Value, Vec<String>)>, - ) -> Self { - self.documents.extend(documents); - self + Ok(self) } +} - /// Generate the embeddings for the given documents - pub async fn build(self) -> Result<Embeddings, EmbeddingError> { - // Create a temporary store for the documents +impl<M: EmbeddingModel, D: Embeddable + Send + Sync + Clone> EmbeddingsBuilder<M, D> { + /// Generate embeddings for all documents in the builder. + /// The method only applies when documents in the builder each contain multiple embedding targets. + /// Returns a vector of tuples, where the first element is the document and the second element is the vector of embeddings. + pub async fn build(&self) -> Result<Vec<(D, OneOrMany<Embedding>)>, EmbeddingError> { + // Use this for reference later to merge a document back with its embeddings. let documents_map = self .documents + .clone() .into_iter() - .map(|(id, document, docs)| (id, (document, docs))) + .enumerate() + .map(|(id, (document, _))| (id, document)) .collect::<HashMap<_, _>>(); - let embeddings = stream::iter(documents_map.iter()) - // Flatten the documents - .flat_map(|(id, (_, docs))| { - stream::iter(docs.iter().map(|doc| (id.clone(), doc.clone()))) + let embeddings = stream::iter(self.documents.iter().enumerate()) + // Merge the embedding targets of each document into a single list of embedding targets. + .flat_map(|(i, (_, embed_targets))| { + stream::iter(embed_targets.clone().into_iter().map(move |target| (i, target))) }) - // Chunk them into N (the embedding API limit per request). 
+ // Chunk them into N (the emebdding API limit per request). .chunks(M::MAX_DOCUMENTS) - // Generate the embeddings + // Generate the embeddings for a chunk at a time. .map(|docs| async { - let (ids, docs): (Vec<_>, Vec<_>) = docs.into_iter().unzip(); + let (document_indices, embed_targets): (Vec<_>, Vec<_>) = docs.into_iter().unzip(); + Ok::<_, EmbeddingError>( - ids.into_iter() - .zip(self.model.embed_documents(docs).await?.into_iter()) + document_indices + .into_iter() + .zip(self.model.embed_documents(embed_targets).await?.into_iter()) .collect::<Vec<_>>(), ) }) .boxed() // Parallelize the embeddings generation over 10 concurrent requests .buffer_unordered(max(1, 1024 / M::MAX_DOCUMENTS)) - .try_fold(vec![], |mut acc, mut embeddings| async move { - Ok({ - acc.append(&mut embeddings); - acc - }) - }) - .await?; - - // Assemble the DocumentEmbeddings - let mut document_embeddings: HashMap<String, DocumentEmbeddings> = HashMap::new(); - embeddings.into_iter().for_each(|(id, embedding)| { - let (document, _) = documents_map.get(&id).expect("Document not found"); - let document_embedding = - document_embeddings - .entry(id.clone()) - .or_insert_with(|| DocumentEmbeddings { - id: id.clone(), - document: document.clone(), - embeddings: vec![], + .try_fold( + HashMap::new(), + |mut acc: HashMap<_, OneOrMany<Embedding>>, embeddings| async move { + embeddings.into_iter().for_each(|(i, embedding)| { + acc.entry(i).or_insert(OneOrMany::one(embedding.clone())).add(embedding.clone()); }); - document_embedding.embeddings.push(embedding); - }); + Ok(acc) + }, + ) + .await? + .iter() + .fold(vec![], |mut acc, (i, embeddings_vec)| { + acc.push(( + documents_map.get(i).cloned().unwrap(), + embeddings_vec.clone(), + )); + acc + }); - Ok(document_embeddings.into_values().collect()) + Ok(embeddings) } -} +} \ No newline at end of file diff --git a/rig-core/src/embeddings/mod.rs b/rig-core/src/embeddings/mod.rs index a9eda7c3..763e0f30 100644 --- a/rig-core/src/embeddings/mod.rs +++ b/rig-core/src/embeddings/mod.rs @@ -6,6 +6,7 @@ pub mod builder; pub mod embeddable; pub mod embedding; +pub mod tool; pub use builder::EmbeddingsBuilder; pub use embeddable::Embeddable; diff --git a/rig-core/src/embeddings/tool.rs b/rig-core/src/embeddings/tool.rs index 278efb20..0769c82c 100644 --- a/rig-core/src/embeddings/tool.rs +++ b/rig-core/src/embeddings/tool.rs @@ -1,26 +1,30 @@ -use crate::{self as rig, tool::ToolEmbeddingDyn}; -use rig::embeddings::embeddable::Embeddable; -use rig_derive::Embeddable; +use crate::{tool::ToolEmbeddingDyn, Embeddable, OneOrMany}; use serde::Serialize; use super::embeddable::EmbeddableError; /// Used by EmbeddingsBuilder to embed anything that implements ToolEmbedding. -#[derive(Embeddable, Clone, Serialize, Default, Eq, PartialEq)] +#[derive(Clone, Serialize, Default, Eq, PartialEq)] pub struct EmbeddableTool { name: String, context: serde_json::Value, - #[embed] embedding_docs: Vec<String> } +impl Embeddable for EmbeddableTool { + type Error = EmbeddableError; + + fn embeddable(&self) -> Result<OneOrMany<String>, Self::Error> { + OneOrMany::many(self.embedding_docs.clone()).map_err(EmbeddableError::new) + } +} + impl EmbeddableTool { /// Convert item that implements ToolEmbedding to an EmbeddableTool. pub fn try_from(tool: &dyn ToolEmbeddingDyn) -> Result<Self, EmbeddableError> { Ok(EmbeddableTool { name: tool.name(), - context: serde_json::to_value(tool.context().map_err(EmbeddableError::SerdeError)?) 
- .map_err(EmbeddableError::SerdeError)?, + context: tool.context().map_err(EmbeddableError::new)?, embedding_docs: tool.embedding_docs(), }) } diff --git a/rig-core/src/one_or_many.rs b/rig-core/src/one_or_many.rs index 08165873..6a1607be 100644 --- a/rig-core/src/one_or_many.rs +++ b/rig-core/src/one_or_many.rs @@ -33,6 +33,11 @@ impl<T: Clone> OneOrMany<T> { self.rest.clone() } + /// After `OneOrMany<T>` is created, add an item of type T to the `rest`. + pub fn add(&mut self, item: T) { + self.rest.push(item); + } + /// Create a OneOrMany object with a single item of any type. pub fn one(item: T) -> Self { OneOrMany { diff --git a/rig-core/src/providers/cohere.rs b/rig-core/src/providers/cohere.rs index ae874b21..2d2d5bf5 100644 --- a/rig-core/src/providers/cohere.rs +++ b/rig-core/src/providers/cohere.rs @@ -15,7 +15,7 @@ use crate::{ completion::{self, CompletionError}, embeddings::{self, EmbeddingError, EmbeddingsBuilder}, extractor::ExtractorBuilder, - json_utils, + json_utils, Embeddable, }; use schemars::JsonSchema; @@ -85,7 +85,7 @@ impl Client { EmbeddingModel::new(self.clone(), model, input_type, ndims) } - pub fn embeddings(&self, model: &str, input_type: &str) -> EmbeddingsBuilder<EmbeddingModel> { + pub fn embeddings<D: Embeddable>(&self, model: &str, input_type: &str) -> EmbeddingsBuilder<EmbeddingModel, D> { EmbeddingsBuilder::new(self.embedding_model(model, input_type)) } diff --git a/rig-core/src/providers/openai.rs b/rig-core/src/providers/openai.rs index c9ba9afa..b20df22f 100644 --- a/rig-core/src/providers/openai.rs +++ b/rig-core/src/providers/openai.rs @@ -13,7 +13,7 @@ use crate::{ completion::{self, CompletionError, CompletionRequest}, embeddings::{self, EmbeddingError, EmbeddingsBuilder}, extractor::ExtractorBuilder, - json_utils, + json_utils, Embeddable, }; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; @@ -121,7 +121,7 @@ impl Client { /// .await /// .expect("Failed to embed documents"); /// ``` - pub fn embeddings(&self, model: &str) -> EmbeddingsBuilder<EmbeddingModel> { + pub fn embeddings<D: Embeddable>(&self, model: &str) -> EmbeddingsBuilder<EmbeddingModel, D> { EmbeddingsBuilder::new(self.embedding_model(model)) } diff --git a/rig-core/src/tool.rs b/rig-core/src/tool.rs index 98394ecf..1104d5fb 100644 --- a/rig-core/src/tool.rs +++ b/rig-core/src/tool.rs @@ -3,7 +3,7 @@ use std::{collections::HashMap, pin::Pin}; use futures::Future; use serde::{Deserialize, Serialize}; -use crate::completion::{self, ToolDefinition}; +use crate::{completion::{self, ToolDefinition}, embeddings::{embeddable::EmbeddableError, tool::EmbeddableTool}}; #[derive(Debug, thiserror::Error)] pub enum ToolError { @@ -323,6 +323,22 @@ impl ToolSet { } Ok(docs) } + + /// Convert tools in self to objects of type EmbeddableTool. + /// This is necessary because when adding tools to the EmbeddingBuilder because all + /// documents added to the builder must all be of the same type. 
+ pub fn embedabble_tools(&self) -> Result<Vec<EmbeddableTool>, EmbeddableError> { + self.tools + .values() + .filter_map(|tool_type| { + if let ToolType::Embedding(tool) = tool_type { + Some(EmbeddableTool::try_from(&**tool)) + } else { + None + } + }) + .collect::<Result<Vec<_>, _>>() + } } #[derive(Default)] diff --git a/rig-lancedb/Cargo.toml b/rig-lancedb/Cargo.toml index 031df2d3..6ee41d8d 100644 --- a/rig-lancedb/Cargo.toml +++ b/rig-lancedb/Cargo.toml @@ -6,6 +6,7 @@ edition = "2021" [dependencies] lancedb = "0.10.0" rig-core = { path = "../rig-core", version = "0.2.1" } +rig-derive = { path = "../rig-core/rig-core-derive" } arrow-array = "52.2.0" serde_json = "1.0.128" serde = "1.0.210" diff --git a/rig-lancedb/examples/fixtures/lib.rs b/rig-lancedb/examples/fixtures/lib.rs index d95a42e4..415422f0 100644 --- a/rig-lancedb/examples/fixtures/lib.rs +++ b/rig-lancedb/examples/fixtures/lib.rs @@ -2,13 +2,39 @@ use std::sync::Arc; use arrow_array::{types::Float64Type, ArrayRef, FixedSizeListArray, RecordBatch, StringArray}; use lancedb::arrow::arrow_schema::{DataType, Field, Fields, Schema}; -use rig::embeddings::builder::DocumentEmbeddings; +use rig::embeddings::embedding::Embedding; +use rig::Embeddable; +use serde::Deserialize; + +#[derive(Embeddable, Clone, Deserialize, Debug)] +pub struct FakeDefinition { + pub id: String, + #[embed] + pub definition: String, +} + +pub fn fake_definitions() -> Vec<FakeDefinition> { + vec![ + FakeDefinition { + id: "doc0".to_string(), + definition: "Definition of *flumbrel (noun)*: a small, seemingly insignificant item that you constantly lose or misplace, such as a pen, hair tie, or remote control.".to_string() + }, + FakeDefinition { + id: "doc1".to_string(), + definition: "Definition of *zindle (verb)*: to pretend to be working on something important while actually doing something completely unrelated or unproductive.".to_string() + }, + FakeDefinition { + id: "doc2".to_string(), + definition: "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.".to_string() + } + ] +} // Schema of table in LanceDB. pub fn schema(dims: usize) -> Schema { Schema::new(Fields::from(vec![ Field::new("id", DataType::Utf8, false), - Field::new("content", DataType::Utf8, false), + Field::new("definition", DataType::Utf8, false), Field::new( "embedding", DataType::FixedSizeList( @@ -20,48 +46,36 @@ pub fn schema(dims: usize) -> Schema { ])) } -// Convert DocumentEmbeddings objects to a RecordBatch. +// Convert FakeDefinition objects and their embedding to a RecordBatch. pub fn as_record_batch( - records: Vec<DocumentEmbeddings>, + records: Vec<(FakeDefinition, Embedding)>, dims: usize, ) -> Result<RecordBatch, lancedb::arrow::arrow_schema::ArrowError> { let id = StringArray::from_iter_values( records .iter() - .flat_map(|record| (0..record.embeddings.len()).map(|i| format!("{}-{i}", record.id))) + .map(|(FakeDefinition { id, .. }, _)| id) .collect::<Vec<_>>(), ); - let content = StringArray::from_iter_values( + let definition = StringArray::from_iter_values( records .iter() - .flat_map(|record| { - record - .embeddings - .iter() - .map(|embedding| embedding.document.clone()) - }) + .map(|(FakeDefinition { definition, .. 
}, _)| definition) .collect::<Vec<_>>(), ); let embedding = FixedSizeListArray::from_iter_primitive::<Float64Type, _, _>( records .into_iter() - .flat_map(|record| { - record - .embeddings - .into_iter() - .map(|embedding| embedding.vec.into_iter().map(Some).collect::<Vec<_>>()) - .map(Some) - .collect::<Vec<_>>() - }) + .map(|(_, Embedding { vec, .. })| Some(vec.into_iter().map(Some).collect::<Vec<_>>())) .collect::<Vec<_>>(), dims as i32, ); RecordBatch::try_from_iter(vec![ ("id", Arc::new(id) as ArrayRef), - ("content", Arc::new(content) as ArrayRef), + ("definition", Arc::new(definition) as ArrayRef), ("embedding", Arc::new(embedding) as ArrayRef), ]) } diff --git a/rig-lancedb/examples/vector_search_local_ann.rs b/rig-lancedb/examples/vector_search_local_ann.rs index 6466c2b9..1b7870fb 100644 --- a/rig-lancedb/examples/vector_search_local_ann.rs +++ b/rig-lancedb/examples/vector_search_local_ann.rs @@ -1,25 +1,18 @@ use std::{env, sync::Arc}; use arrow_array::RecordBatchIterator; -use fixture::{as_record_batch, schema}; +use fixture::{as_record_batch, fake_definitions, schema, FakeDefinition}; use lancedb::index::vector::IvfPqIndexBuilder; +use rig::vector_store::VectorStoreIndex; use rig::{ - embeddings::{EmbeddingModel, EmbeddingsBuilder}, + embeddings::{builder::EmbeddingsBuilder, embedding::EmbeddingModel}, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, - vector_store::VectorStoreIndex, }; use rig_lancedb::{LanceDbVectorStore, SearchParams}; -use serde::Deserialize; #[path = "./fixtures/lib.rs"] mod fixture; -#[derive(Deserialize, Debug)] -pub struct VectorSearchResult { - pub id: String, - pub content: String, -} - #[tokio::main] async fn main() -> Result<(), anyhow::Error> { // Initialize OpenAI client. Use this to generate embeddings (and generate test data for RAG demo). @@ -32,18 +25,18 @@ async fn main() -> Result<(), anyhow::Error> { // Initialize LanceDB locally. let db = lancedb::connect("data/lancedb-store").execute().await?; - // Set up test data for RAG demo - let definition = "Definition of *flumbuzzle (verb)*: to bewilder or confuse someone completely, often by using nonsensical or overly complex explanations or instructions.".to_string(); - - // Note: need at least 256 rows in order to create an index so copy the definition 256 times for testing purposes. - let definitions = vec![definition; 256]; - // Generate embeddings for the test data. let embeddings = EmbeddingsBuilder::new(model.clone()) - .simple_document("doc0", "Definition of *flumbrel (noun)*: a small, seemingly insignificant item that you constantly lose or misplace, such as a pen, hair tie, or remote control.") - .simple_document("doc1", "Definition of *zindle (verb)*: to pretend to be working on something important while actually doing something completely unrelated or unproductive") - .simple_document("doc2", "Definition of *glimber (adjective)*: describing a state of excitement mixed with nervousness, often experienced before an important event or decision.") - .simple_documents(definitions.clone().into_iter().enumerate().map(|(i, def)| (format!("doc{}", i+3), def)).collect()) + .documents(fake_definitions())? + // Note: need at least 256 rows in order to create an index so copy the definition 256 times for testing purposes. 
+ .documents( + (0..256) + .map(|i| FakeDefinition { + id: format!("doc{}", i), + definition: "Definition of *flumbuzzle (noun)*: A sudden, inexplicable urge to rearrange or reorganize small objects, such as desk items or books, for no apparent reason.".to_string() + }) + .collect(), + )? .build() .await?; @@ -72,7 +65,7 @@ async fn main() -> Result<(), anyhow::Error> { // Query the index let results = vector_store - .top_n::<VectorSearchResult>("My boss says I zindle too much, what does that mean?", 1) + .top_n::<FakeDefinition>("My boss says I zindle too much, what does that mean?", 1) .await?; println!("Results: {:?}", results); diff --git a/rig-lancedb/examples/vector_search_local_enn.rs b/rig-lancedb/examples/vector_search_local_enn.rs index 5932dcd0..630acc1a 100644 --- a/rig-lancedb/examples/vector_search_local_enn.rs +++ b/rig-lancedb/examples/vector_search_local_enn.rs @@ -1,9 +1,9 @@ use std::{env, sync::Arc}; use arrow_array::RecordBatchIterator; -use fixture::{as_record_batch, schema}; +use fixture::{as_record_batch, fake_definitions, schema}; use rig::{ - embeddings::{EmbeddingModel, EmbeddingsBuilder}, + embeddings::{builder::EmbeddingsBuilder, embedding::EmbeddingModel}, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, vector_store::VectorStoreIndexDyn, }; @@ -21,10 +21,9 @@ async fn main() -> Result<(), anyhow::Error> { // Select the embedding model and generate our embeddings let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); + // Generate embeddings for the test data. let embeddings = EmbeddingsBuilder::new(model.clone()) - .simple_document("doc0", "Definition of *flumbrel (noun)*: a small, seemingly insignificant item that you constantly lose or misplace, such as a pen, hair tie, or remote control.") - .simple_document("doc1", "Definition of *zindle (verb)*: to pretend to be working on something important while actually doing something completely unrelated or unproductive") - .simple_document("doc2", "Definition of *glimber (adjective)*: describing a state of excitement mixed with nervousness, often experienced before an important event or decision.") + .documents(fake_definitions())? .build() .await?; diff --git a/rig-lancedb/examples/vector_search_s3_ann.rs b/rig-lancedb/examples/vector_search_s3_ann.rs index 70f0c8c5..8c10409b 100644 --- a/rig-lancedb/examples/vector_search_s3_ann.rs +++ b/rig-lancedb/examples/vector_search_s3_ann.rs @@ -1,25 +1,18 @@ use std::{env, sync::Arc}; use arrow_array::RecordBatchIterator; -use fixture::{as_record_batch, schema}; +use fixture::{as_record_batch, fake_definitions, schema, FakeDefinition}; use lancedb::{index::vector::IvfPqIndexBuilder, DistanceType}; use rig::{ - embeddings::{EmbeddingModel, EmbeddingsBuilder}, + embeddings::{builder::EmbeddingsBuilder, embedding::EmbeddingModel}, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, vector_store::VectorStoreIndex, }; use rig_lancedb::{LanceDbVectorStore, SearchParams}; -use serde::Deserialize; #[path = "./fixtures/lib.rs"] mod fixture; -#[derive(Deserialize, Debug)] -pub struct VectorSearchResult { - pub id: String, - pub content: String, -} - // Note: see docs to deploy LanceDB on other cloud providers such as google and azure. 
// https://lancedb.github.io/lancedb/guides/storage/ #[tokio::main] @@ -38,18 +31,18 @@ async fn main() -> Result<(), anyhow::Error> { .execute() .await?; - // Set up test data for RAG demo - let definition = "Definition of *flumbuzzle (verb)*: to bewilder or confuse someone completely, often by using nonsensical or overly complex explanations or instructions.".to_string(); - - // Note: need at least 256 rows in order to create an index so copy the definition 256 times for testing purposes. - let definitions = vec![definition; 256]; - // Generate embeddings for the test data. let embeddings = EmbeddingsBuilder::new(model.clone()) - .simple_document("doc0", "Definition of *flumbrel (noun)*: a small, seemingly insignificant item that you constantly lose or misplace, such as a pen, hair tie, or remote control.") - .simple_document("doc1", "Definition of *zindle (verb)*: to pretend to be working on something important while actually doing something completely unrelated or unproductive") - .simple_document("doc2", "Definition of *glimber (adjective)*: describing a state of excitement mixed with nervousness, often experienced before an important event or decision.") - .simple_documents(definitions.clone().into_iter().enumerate().map(|(i, def)| (format!("doc{}", i+3), def)).collect()) + .documents(fake_definitions())? + // Note: need at least 256 rows in order to create an index so copy the definition 256 times for testing purposes. + .documents( + (0..256) + .map(|i| FakeDefinition { + id: format!("doc{}", i), + definition: "Definition of *flumbuzzle (noun)*: A sudden, inexplicable urge to rearrange or reorganize small objects, such as desk items or books, for no apparent reason.".to_string() + }) + .collect(), + )? .build() .await?; @@ -84,7 +77,7 @@ async fn main() -> Result<(), anyhow::Error> { // Query the index let results = vector_store - .top_n::<VectorSearchResult>("I'm always looking for my phone, I always seem to forget it in the most counterintuitive places. What's the word for this feeling?", 1) + .top_n::<FakeDefinition>("I'm always looking for my phone, I always seem to forget it in the most counterintuitive places. 
What's the word for this feeling?", 1) .await?; println!("Results: {:?}", results); diff --git a/rig-mongodb/Cargo.toml b/rig-mongodb/Cargo.toml index 6f313838..78b48892 100644 --- a/rig-mongodb/Cargo.toml +++ b/rig-mongodb/Cargo.toml @@ -13,6 +13,8 @@ repository = "https://github.com/0xPlaygrounds/rig" futures = "0.3.30" mongodb = "2.8.2" rig-core = { path = "../rig-core", version = "0.2.1" } +rig-derive = { path = "../rig-core/rig-core-derive" } + serde = { version = "1.0.203", features = ["derive"] } serde_json = "1.0.117" tracing = "0.1.40" diff --git a/rig-mongodb/examples/vector_search_mongodb.rs b/rig-mongodb/examples/vector_search_mongodb.rs index caba89d8..a816c9c0 100644 --- a/rig-mongodb/examples/vector_search_mongodb.rs +++ b/rig-mongodb/examples/vector_search_mongodb.rs @@ -1,14 +1,39 @@ -use mongodb::{options::ClientOptions, Client as MongoClient, Collection}; +use mongodb::{bson::doc, options::ClientOptions, Client as MongoClient, Collection}; +use rig::providers::openai::TEXT_EMBEDDING_ADA_002; +use serde::{Deserialize, Serialize}; use std::env; +use rig::Embeddable; use rig::{ - embeddings::builder::DocumentEmbeddings, - embeddings::EmbeddingsBuilder, - providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, + embeddings::builder::EmbeddingsBuilder, providers::openai::Client, vector_store::VectorStoreIndex, }; use rig_mongodb::{MongoDbVectorStore, SearchParams}; +// Shape of data that needs to be RAG'ed. +// The definition field will be used to generate embeddings. +#[derive(Embeddable, Clone, Deserialize, Debug)] +struct FakeDefinition { + id: String, + #[embed] + definition: String, +} + +#[derive(Clone, Deserialize, Debug, Serialize)] +struct Link { + word: String, + link: String, +} + +// Shape of the document to be stored in MongoDB, with embeddings. 
+#[derive(Serialize, Debug)] +struct Document { + #[serde(rename = "_id")] + id: String, + definition: String, + embedding: Vec<f64>, +} + #[tokio::main] async fn main() -> Result<(), anyhow::Error> { // Initialize OpenAI client @@ -26,37 +51,61 @@ async fn main() -> Result<(), anyhow::Error> { MongoClient::with_options(options).expect("MongoDB client options should be valid"); // Initialize MongoDB vector store - let collection: Collection<DocumentEmbeddings> = mongodb_client + let collection: Collection<Document> = mongodb_client .database("knowledgebase") .collection("context"); // Select the embedding model and generate our embeddings let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); + let fake_definitions = vec![ + FakeDefinition { + id: "doc0".to_string(), + definition: "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets".to_string(), + }, + FakeDefinition { + id: "doc1".to_string(), + definition: "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string(), + }, + FakeDefinition { + id: "doc2".to_string(), + definition: "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.".to_string(), + } + ]; + let embeddings = EmbeddingsBuilder::new(model.clone()) - .simple_document("doc0", "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets") - .simple_document("doc1", "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.") - .simple_document("doc2", "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.") + .documents(fake_definitions)? .build() .await?; - match collection.insert_many(embeddings, None).await { + let mongo_documents = embeddings + .iter() + .map( + |(FakeDefinition { id, definition, .. }, embedding)| Document { + id: id.clone(), + definition: definition.clone(), + embedding: embedding.vec.clone(), + }, + ) + .collect::<Vec<_>>(); + + match collection.insert_many(mongo_documents, None).await { Ok(_) => println!("Documents added successfully"), Err(e) => println!("Error adding documents: {:?}", e), - } + }; // Create a vector index on our vector store // IMPORTANT: Reuse the same model that was used to generate the embeddings - let index = - MongoDbVectorStore::new(collection).index(model, "vector_index", SearchParams::default()); + let index = MongoDbVectorStore::new(collection).index( + model, + "definitions_vector_index", + SearchParams::new("embedding"), + ); // Query the index let results = index - .top_n::<DocumentEmbeddings>("What is a linglingdong?", 1) - .await? - .into_iter() - .map(|(score, id, doc)| (score, id, doc.document)) - .collect::<Vec<_>>(); + .top_n::<FakeDefinition>("What is a linglingdong?", 1) + .await?; println!("Results: {:?}", results); diff --git a/rig-mongodb/src/lib.rs b/rig-mongodb/src/lib.rs index c3973092..4778e454 100644 --- a/rig-mongodb/src/lib.rs +++ b/rig-mongodb/src/lib.rs @@ -2,24 +2,23 @@ use futures::StreamExt; use mongodb::bson::{self, doc}; use rig::{ - embeddings::builder::DocumentEmbeddings, - embeddings::{Embedding, EmbeddingModel}, + embeddings::embedding::{Embedding, EmbeddingModel}, vector_store::{VectorStoreError, VectorStoreIndex}, }; use serde::Deserialize; /// A MongoDB vector store. 
-pub struct MongoDbVectorStore { - collection: mongodb::Collection<DocumentEmbeddings>, +pub struct MongoDbVectorStore<C> { + collection: mongodb::Collection<C>, } fn mongodb_to_rig_error(e: mongodb::error::Error) -> VectorStoreError { VectorStoreError::DatastoreError(Box::new(e)) } -impl MongoDbVectorStore { +impl<C> MongoDbVectorStore<C> { /// Create a new `MongoDbVectorStore` from a MongoDB collection. - pub fn new(collection: mongodb::Collection<DocumentEmbeddings>) -> Self { + pub fn new(collection: mongodb::Collection<C>) -> Self { Self { collection } } @@ -32,20 +31,20 @@ impl MongoDbVectorStore { model: M, index_name: &str, search_params: SearchParams, - ) -> MongoDbVectorIndex<M> { + ) -> MongoDbVectorIndex<M, C> { MongoDbVectorIndex::new(self.collection.clone(), model, index_name, search_params) } } /// A vector index for a MongoDB collection. -pub struct MongoDbVectorIndex<M: EmbeddingModel> { - collection: mongodb::Collection<DocumentEmbeddings>, +pub struct MongoDbVectorIndex<M: EmbeddingModel, C> { + collection: mongodb::Collection<C>, model: M, index_name: String, search_params: SearchParams, } -impl<M: EmbeddingModel> MongoDbVectorIndex<M> { +impl<M: EmbeddingModel, C> MongoDbVectorIndex<M, C> { /// Vector search stage of aggregation pipeline of mongoDB collection. /// To be used by implementations of top_n and top_n_ids methods on VectorStoreIndex trait for MongoDbVectorIndex. fn pipeline_search_stage(&self, prompt_embedding: &Embedding, n: usize) -> bson::Document { @@ -53,12 +52,13 @@ impl<M: EmbeddingModel> MongoDbVectorIndex<M> { filter, exact, num_candidates, + path, } = &self.search_params; doc! { "$vectorSearch": { "index": &self.index_name, - "path": "embeddings.vec", + "path": path, "queryVector": &prompt_embedding.vec, "numCandidates": num_candidates.unwrap_or((n * 10) as u32), "limit": n as u32, @@ -79,9 +79,9 @@ impl<M: EmbeddingModel> MongoDbVectorIndex<M> { } } -impl<M: EmbeddingModel> MongoDbVectorIndex<M> { +impl<M: EmbeddingModel, C> MongoDbVectorIndex<M, C> { pub fn new( - collection: mongodb::Collection<DocumentEmbeddings>, + collection: mongodb::Collection<C>, model: M, index_name: &str, search_params: SearchParams, @@ -99,17 +99,19 @@ impl<M: EmbeddingModel> MongoDbVectorIndex<M> { /// on each of the fields pub struct SearchParams { filter: mongodb::bson::Document, + path: String, exact: Option<bool>, num_candidates: Option<u32>, } impl SearchParams { /// Initializes a new `SearchParams` with default values. - pub fn new() -> Self { + pub fn new(path: &str) -> Self { Self { filter: doc! 
{}, exact: None, num_candidates: None, + path: path.to_string(), } } @@ -139,13 +141,9 @@ impl SearchParams { } } -impl Default for SearchParams { - fn default() -> Self { - Self::new() - } -} - -impl<M: EmbeddingModel + std::marker::Sync + Send> VectorStoreIndex for MongoDbVectorIndex<M> { +impl<M: EmbeddingModel + std::marker::Sync + Send, C: std::marker::Sync + Send> VectorStoreIndex + for MongoDbVectorIndex<M, C> +{ async fn top_n<T: for<'a> Deserialize<'a> + std::marker::Send>( &self, query: &str, From 2477af8fbabd3acddb41097a0ee5273b877ebca5 Mon Sep 17 00:00:00 2001 From: Garance <garance.mary@gmail.com> Date: Fri, 18 Oct 2024 11:47:37 -0400 Subject: [PATCH 54/91] fix: fix cargo toml of lancedb and mongodb --- rig-lancedb/Cargo.toml | 3 +-- rig-lancedb/examples/fixtures/lib.rs | 6 +++--- rig-mongodb/Cargo.toml | 3 +-- rig-mongodb/examples/vector_search_mongodb.rs | 2 +- 4 files changed, 6 insertions(+), 8 deletions(-) diff --git a/rig-lancedb/Cargo.toml b/rig-lancedb/Cargo.toml index 6ee41d8d..91877f1c 100644 --- a/rig-lancedb/Cargo.toml +++ b/rig-lancedb/Cargo.toml @@ -5,8 +5,7 @@ edition = "2021" [dependencies] lancedb = "0.10.0" -rig-core = { path = "../rig-core", version = "0.2.1" } -rig-derive = { path = "../rig-core/rig-core-derive" } +rig-core = { path = "../rig-core", version = "0.2.1", features = ["derive"] } arrow-array = "52.2.0" serde_json = "1.0.128" serde = "1.0.210" diff --git a/rig-lancedb/examples/fixtures/lib.rs b/rig-lancedb/examples/fixtures/lib.rs index 415422f0..cf82fc7c 100644 --- a/rig-lancedb/examples/fixtures/lib.rs +++ b/rig-lancedb/examples/fixtures/lib.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use arrow_array::{types::Float64Type, ArrayRef, FixedSizeListArray, RecordBatch, StringArray}; use lancedb::arrow::arrow_schema::{DataType, Field, Fields, Schema}; use rig::embeddings::embedding::Embedding; -use rig::Embeddable; +use rig::{Embeddable, OneOrMany}; use serde::Deserialize; #[derive(Embeddable, Clone, Deserialize, Debug)] @@ -48,7 +48,7 @@ pub fn schema(dims: usize) -> Schema { // Convert FakeDefinition objects and their embedding to a RecordBatch. pub fn as_record_batch( - records: Vec<(FakeDefinition, Embedding)>, + records: Vec<(FakeDefinition, OneOrMany<Embedding>)>, dims: usize, ) -> Result<RecordBatch, lancedb::arrow::arrow_schema::ArrowError> { let id = StringArray::from_iter_values( @@ -68,7 +68,7 @@ pub fn as_record_batch( let embedding = FixedSizeListArray::from_iter_primitive::<Float64Type, _, _>( records .into_iter() - .map(|(_, Embedding { vec, .. 
})| Some(vec.into_iter().map(Some).collect::<Vec<_>>())) + .map(|(_, embeddings)| Some(embeddings.first().vec.into_iter().map(Some).collect::<Vec<_>>())) .collect::<Vec<_>>(), dims as i32, ); diff --git a/rig-mongodb/Cargo.toml b/rig-mongodb/Cargo.toml index 78b48892..8673bda8 100644 --- a/rig-mongodb/Cargo.toml +++ b/rig-mongodb/Cargo.toml @@ -12,8 +12,7 @@ repository = "https://github.com/0xPlaygrounds/rig" [dependencies] futures = "0.3.30" mongodb = "2.8.2" -rig-core = { path = "../rig-core", version = "0.2.1" } -rig-derive = { path = "../rig-core/rig-core-derive" } +rig-core = { path = "../rig-core", version = "0.2.1", features = ["derive"] } serde = { version = "1.0.203", features = ["derive"] } serde_json = "1.0.117" diff --git a/rig-mongodb/examples/vector_search_mongodb.rs b/rig-mongodb/examples/vector_search_mongodb.rs index a816c9c0..4a369001 100644 --- a/rig-mongodb/examples/vector_search_mongodb.rs +++ b/rig-mongodb/examples/vector_search_mongodb.rs @@ -84,7 +84,7 @@ async fn main() -> Result<(), anyhow::Error> { |(FakeDefinition { id, definition, .. }, embedding)| Document { id: id.clone(), definition: definition.clone(), - embedding: embedding.vec.clone(), + embedding: embedding.first().vec.clone(), }, ) .collect::<Vec<_>>(); From 485ad3bbe55af7d2dcf0b0256aa9b3e6dfb7d47b Mon Sep 17 00:00:00 2001 From: Garance <garance.mary@gmail.com> Date: Fri, 18 Oct 2024 11:49:19 -0400 Subject: [PATCH 55/91] refactor: use thiserror for OneOtMany::EmptyListError --- rig-core/src/one_or_many.rs | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/rig-core/src/one_or_many.rs b/rig-core/src/one_or_many.rs index 08165873..87435537 100644 --- a/rig-core/src/one_or_many.rs +++ b/rig-core/src/one_or_many.rs @@ -12,16 +12,10 @@ pub struct OneOrMany<T> { } /// Error type for when trying to create a OneOrMany object with an empty vector. -#[derive(Debug)] +#[derive(Debug, thiserror::Error)] +#[error("Cannot create OneOrMany with an empty vector.")] pub struct EmptyListError; -impl std::fmt::Display for EmptyListError { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - write!(f, "Cannot create OneOrMany with an empty vector.") - } -} -impl std::error::Error for EmptyListError {} - impl<T: Clone> OneOrMany<T> { /// Get the first item in the list. pub fn first(&self) -> T { From 763c3647a1ebe8caf5db35c86c7bf13126b79ebe Mon Sep 17 00:00:00 2001 From: Garance <garance.mary@gmail.com> Date: Fri, 18 Oct 2024 12:04:21 -0400 Subject: [PATCH 56/91] feat: add OneOrMany to in memory vector store --- rig-core/examples/calculator_chatbot.rs | 2 +- rig-core/examples/rag_dynamic_tools.rs | 2 +- rig-core/src/vector_store/in_memory_store.rs | 36 ++++++++++---------- 3 files changed, 20 insertions(+), 20 deletions(-) diff --git a/rig-core/examples/calculator_chatbot.rs b/rig-core/examples/calculator_chatbot.rs index 30c3741f..0482cdb3 100644 --- a/rig-core/examples/calculator_chatbot.rs +++ b/rig-core/examples/calculator_chatbot.rs @@ -256,7 +256,7 @@ async fn main() -> Result<(), anyhow::Error> { embeddings .into_iter() .enumerate() - .map(|(i, (tool, embedding))| (i.to_string(), tool, vec![embedding])) + .map(|(i, (tool, embedding))| (i.to_string(), tool, embedding)) .collect(), )? 
.index(embedding_model); diff --git a/rig-core/examples/rag_dynamic_tools.rs b/rig-core/examples/rag_dynamic_tools.rs index 0ddc210b..71a72a75 100644 --- a/rig-core/examples/rag_dynamic_tools.rs +++ b/rig-core/examples/rag_dynamic_tools.rs @@ -165,7 +165,7 @@ async fn main() -> Result<(), anyhow::Error> { embeddings .into_iter() .enumerate() - .map(|(i, (tool, embedding))| (i.to_string(), tool, vec![embedding])) + .map(|(i, (tool, embedding))| (i.to_string(), tool, embedding)) .collect(), )? .index(embedding_model); diff --git a/rig-core/src/vector_store/in_memory_store.rs b/rig-core/src/vector_store/in_memory_store.rs index bfe2bd29..a41ff168 100644 --- a/rig-core/src/vector_store/in_memory_store.rs +++ b/rig-core/src/vector_store/in_memory_store.rs @@ -8,7 +8,7 @@ use ordered_float::OrderedFloat; use serde::{Deserialize, Serialize}; use super::{VectorStoreError, VectorStoreIndex}; -use crate::embeddings::{Embedding, EmbeddingModel}; +use crate::{embeddings::{Embedding, EmbeddingModel}, OneOrMany}; /// InMemoryVectorStore is a simple in-memory vector store that stores embeddings /// in-memory using a HashMap. @@ -17,7 +17,7 @@ pub struct InMemoryVectorStore<D: Serialize> { /// The embeddings are stored in a HashMap. /// Hashmap key is the document id. /// Hashmap value is a tuple of the serializable document and its corresponding embeddings. - embeddings: HashMap<String, (D, Vec<Embedding>)>, + embeddings: HashMap<String, (D, OneOrMany<Embedding>)>, } impl<D: Serialize + Eq> InMemoryVectorStore<D> { @@ -64,7 +64,7 @@ impl<D: Serialize + Eq> InMemoryVectorStore<D> { /// Returns the store with the added documents. pub fn add_documents( mut self, - documents: Vec<(String, D, Vec<Embedding>)>, + documents: Vec<(String, D, OneOrMany<Embedding>)>, ) -> Result<Self, VectorStoreError> { for (id, doc, embeddings) in documents { self.embeddings.insert(id, (doc, embeddings)); @@ -109,7 +109,7 @@ impl<D: Serialize> InMemoryVectorStore<D> { InMemoryVectorIndex::new(model, self) } - pub fn iter(&self) -> impl Iterator<Item = (&String, &(D, Vec<Embedding>))> { + pub fn iter(&self) -> impl Iterator<Item = (&String, &(D, OneOrMany<Embedding>))> { self.embeddings.iter() } @@ -132,7 +132,7 @@ impl<M: EmbeddingModel, D: Serialize> InMemoryVectorIndex<M, D> { Self { model, store } } - pub fn iter(&self) -> impl Iterator<Item = (&String, &(D, Vec<Embedding>))> { + pub fn iter(&self) -> impl Iterator<Item = (&String, &(D, OneOrMany<Embedding>))> { self.store.iter() } @@ -189,7 +189,7 @@ impl<M: EmbeddingModel + std::marker::Sync, D: Serialize + Sync + Send + Eq> Vec mod tests { use std::cmp::Reverse; - use crate::embeddings::embedding::Embedding; + use crate::{embeddings::embedding::Embedding, OneOrMany}; use super::{InMemoryVectorStore, RankingItem}; @@ -200,26 +200,26 @@ mod tests { ( "doc1".to_string(), "glarb-garb", - vec![Embedding { + OneOrMany::one(Embedding { document: "glarb-garb".to_string(), vec: vec![0.1, 0.1, 0.5], - }], + }), ), ( "doc2".to_string(), "marble-marble", - vec![Embedding { + OneOrMany::one(Embedding { document: "marble-marble".to_string(), vec: vec![0.7, -0.3, 0.0], - }], + }), ), ( "doc3".to_string(), "flumb-flumb", - vec![Embedding { + OneOrMany::one(Embedding { document: "flumb-flumb".to_string(), vec: vec![0.3, 0.7, 0.1], - }], + }), ), ]) .unwrap(); @@ -258,7 +258,7 @@ mod tests { ( "doc1".to_string(), "glarb-garb", - vec![ + OneOrMany::many(vec![ Embedding { document: "glarb-garb".to_string(), vec: vec![0.1, 0.1, 0.5], @@ -267,12 +267,12 @@ mod tests { document: 
"don't-choose-me".to_string(), vec: vec![-0.5, 0.9, 0.1], }, - ], + ]).unwrap(), ), ( "doc2".to_string(), "marble-marble", - vec![ + OneOrMany::many(vec![ Embedding { document: "marble-marble".to_string(), vec: vec![0.7, -0.3, 0.0], @@ -281,12 +281,12 @@ mod tests { document: "sandwich".to_string(), vec: vec![0.5, 0.5, -0.7], }, - ], + ],).unwrap() ), ( "doc3".to_string(), "flumb-flumb", - vec![ + OneOrMany::many(vec![ Embedding { document: "flumb-flumb".to_string(), vec: vec![0.3, 0.7, 0.1], @@ -295,7 +295,7 @@ mod tests { document: "banana".to_string(), vec: vec![0.1, -0.5, -0.5], }, - ], + ],).unwrap() ), ]) .unwrap(); From 3c5e59d440426c85deeaa17ff35274811da5ce29 Mon Sep 17 00:00:00 2001 From: Garance <garance.mary@gmail.com> Date: Fri, 18 Oct 2024 12:04:41 -0400 Subject: [PATCH 57/91] style: cargo fmt --- rig-core/examples/calculator_chatbot.rs | 2 +- rig-core/examples/rag.rs | 2 +- rig-core/examples/rag_dynamic_tools.rs | 2 +- rig-core/examples/vector_search.rs | 2 +- rig-core/examples/vector_search_cohere.rs | 2 +- rig-core/src/embeddings/builder.rs | 13 ++++++++++--- rig-core/src/embeddings/tool.rs | 4 ++-- rig-core/src/providers/cohere.rs | 6 +++++- rig-core/src/tool.rs | 5 ++++- rig-core/src/vector_store/in_memory_store.rs | 14 ++++++++++---- rig-lancedb/examples/fixtures/lib.rs | 11 ++++++++++- 11 files changed, 46 insertions(+), 17 deletions(-) diff --git a/rig-core/examples/calculator_chatbot.rs b/rig-core/examples/calculator_chatbot.rs index 0482cdb3..813c48eb 100644 --- a/rig-core/examples/calculator_chatbot.rs +++ b/rig-core/examples/calculator_chatbot.rs @@ -286,4 +286,4 @@ async fn main() -> Result<(), anyhow::Error> { cli_chatbot(calculator_rag).await?; Ok(()) -} \ No newline at end of file +} diff --git a/rig-core/examples/rag.rs b/rig-core/examples/rag.rs index 4cf6b05a..43270a7b 100644 --- a/rig-core/examples/rag.rs +++ b/rig-core/examples/rag.rs @@ -78,4 +78,4 @@ async fn main() -> Result<(), anyhow::Error> { println!("{}", response); Ok(()) -} \ No newline at end of file +} diff --git a/rig-core/examples/rag_dynamic_tools.rs b/rig-core/examples/rag_dynamic_tools.rs index 71a72a75..cc04c209 100644 --- a/rig-core/examples/rag_dynamic_tools.rs +++ b/rig-core/examples/rag_dynamic_tools.rs @@ -184,4 +184,4 @@ async fn main() -> Result<(), anyhow::Error> { println!("{}", response); Ok(()) -} \ No newline at end of file +} diff --git a/rig-core/examples/vector_search.rs b/rig-core/examples/vector_search.rs index 0766c99a..a97ef8b0 100644 --- a/rig-core/examples/vector_search.rs +++ b/rig-core/examples/vector_search.rs @@ -86,4 +86,4 @@ async fn main() -> Result<(), anyhow::Error> { println!("ID results: {:?}", id_results); Ok(()) -} \ No newline at end of file +} diff --git a/rig-core/examples/vector_search_cohere.rs b/rig-core/examples/vector_search_cohere.rs index 6524ceab..16ddb775 100644 --- a/rig-core/examples/vector_search_cohere.rs +++ b/rig-core/examples/vector_search_cohere.rs @@ -81,4 +81,4 @@ async fn main() -> Result<(), anyhow::Error> { println!("Results: {:?}", results); Ok(()) -} \ No newline at end of file +} diff --git a/rig-core/src/embeddings/builder.rs b/rig-core/src/embeddings/builder.rs index d7024e8e..451a0788 100644 --- a/rig-core/src/embeddings/builder.rs +++ b/rig-core/src/embeddings/builder.rs @@ -124,7 +124,12 @@ impl<M: EmbeddingModel, D: Embeddable + Send + Sync + Clone> EmbeddingsBuilder<M let embeddings = stream::iter(self.documents.iter().enumerate()) // Merge the embedding targets of each document into a single list of embedding targets. 
.flat_map(|(i, (_, embed_targets))| { - stream::iter(embed_targets.clone().into_iter().map(move |target| (i, target))) + stream::iter( + embed_targets + .clone() + .into_iter() + .map(move |target| (i, target)), + ) }) // Chunk them into N (the emebdding API limit per request). .chunks(M::MAX_DOCUMENTS) @@ -146,7 +151,9 @@ impl<M: EmbeddingModel, D: Embeddable + Send + Sync + Clone> EmbeddingsBuilder<M HashMap::new(), |mut acc: HashMap<_, OneOrMany<Embedding>>, embeddings| async move { embeddings.into_iter().for_each(|(i, embedding)| { - acc.entry(i).or_insert(OneOrMany::one(embedding.clone())).add(embedding.clone()); + acc.entry(i) + .or_insert(OneOrMany::one(embedding.clone())) + .add(embedding.clone()); }); Ok(acc) @@ -164,4 +171,4 @@ impl<M: EmbeddingModel, D: Embeddable + Send + Sync + Clone> EmbeddingsBuilder<M Ok(embeddings) } -} \ No newline at end of file +} diff --git a/rig-core/src/embeddings/tool.rs b/rig-core/src/embeddings/tool.rs index 0769c82c..1efe37c9 100644 --- a/rig-core/src/embeddings/tool.rs +++ b/rig-core/src/embeddings/tool.rs @@ -8,7 +8,7 @@ use super::embeddable::EmbeddableError; pub struct EmbeddableTool { name: String, context: serde_json::Value, - embedding_docs: Vec<String> + embedding_docs: Vec<String>, } impl Embeddable for EmbeddableTool { @@ -28,4 +28,4 @@ impl EmbeddableTool { embedding_docs: tool.embedding_docs(), }) } -} \ No newline at end of file +} diff --git a/rig-core/src/providers/cohere.rs b/rig-core/src/providers/cohere.rs index 2d2d5bf5..8f8eefd4 100644 --- a/rig-core/src/providers/cohere.rs +++ b/rig-core/src/providers/cohere.rs @@ -85,7 +85,11 @@ impl Client { EmbeddingModel::new(self.clone(), model, input_type, ndims) } - pub fn embeddings<D: Embeddable>(&self, model: &str, input_type: &str) -> EmbeddingsBuilder<EmbeddingModel, D> { + pub fn embeddings<D: Embeddable>( + &self, + model: &str, + input_type: &str, + ) -> EmbeddingsBuilder<EmbeddingModel, D> { EmbeddingsBuilder::new(self.embedding_model(model, input_type)) } diff --git a/rig-core/src/tool.rs b/rig-core/src/tool.rs index 1104d5fb..e92896b8 100644 --- a/rig-core/src/tool.rs +++ b/rig-core/src/tool.rs @@ -3,7 +3,10 @@ use std::{collections::HashMap, pin::Pin}; use futures::Future; use serde::{Deserialize, Serialize}; -use crate::{completion::{self, ToolDefinition}, embeddings::{embeddable::EmbeddableError, tool::EmbeddableTool}}; +use crate::{ + completion::{self, ToolDefinition}, + embeddings::{embeddable::EmbeddableError, tool::EmbeddableTool}, +}; #[derive(Debug, thiserror::Error)] pub enum ToolError { diff --git a/rig-core/src/vector_store/in_memory_store.rs b/rig-core/src/vector_store/in_memory_store.rs index a41ff168..31a0ef7f 100644 --- a/rig-core/src/vector_store/in_memory_store.rs +++ b/rig-core/src/vector_store/in_memory_store.rs @@ -8,7 +8,10 @@ use ordered_float::OrderedFloat; use serde::{Deserialize, Serialize}; use super::{VectorStoreError, VectorStoreIndex}; -use crate::{embeddings::{Embedding, EmbeddingModel}, OneOrMany}; +use crate::{ + embeddings::{Embedding, EmbeddingModel}, + OneOrMany, +}; /// InMemoryVectorStore is a simple in-memory vector store that stores embeddings /// in-memory using a HashMap. 
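As an aside, a minimal sketch of how the `OneOrMany` container used above behaves at this point in the series, based only on the constructors and accessors shown in the preceding hunks (`one`, `many`, `add`, `first`, `rest`); this is illustrative usage, not code from any patch:

use rig::OneOrMany;

fn main() {
    // `one` wraps a single item; `add` appends further items to `rest`.
    let mut embeddings = OneOrMany::one("first".to_string());
    embeddings.add("second".to_string());

    assert_eq!(embeddings.first(), "first".to_string());
    assert_eq!(embeddings.rest(), vec!["second".to_string()]);

    // `many` refuses an empty vector, so a OneOrMany can never be empty.
    assert!(OneOrMany::<String>::many(vec![]).is_err());
    let both = OneOrMany::many(vec!["a".to_string(), "b".to_string()]).unwrap();
    assert_eq!(both.first(), "a".to_string());
}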
@@ -267,7 +270,8 @@ mod tests { document: "don't-choose-me".to_string(), vec: vec![-0.5, 0.9, 0.1], }, - ]).unwrap(), + ]) + .unwrap(), ), ( "doc2".to_string(), @@ -281,7 +285,8 @@ mod tests { document: "sandwich".to_string(), vec: vec![0.5, 0.5, -0.7], }, - ],).unwrap() + ]) + .unwrap(), ), ( "doc3".to_string(), @@ -295,7 +300,8 @@ mod tests { document: "banana".to_string(), vec: vec![0.1, -0.5, -0.5], }, - ],).unwrap() + ]) + .unwrap(), ), ]) .unwrap(); diff --git a/rig-lancedb/examples/fixtures/lib.rs b/rig-lancedb/examples/fixtures/lib.rs index cf82fc7c..780b42ca 100644 --- a/rig-lancedb/examples/fixtures/lib.rs +++ b/rig-lancedb/examples/fixtures/lib.rs @@ -68,7 +68,16 @@ pub fn as_record_batch( let embedding = FixedSizeListArray::from_iter_primitive::<Float64Type, _, _>( records .into_iter() - .map(|(_, embeddings)| Some(embeddings.first().vec.into_iter().map(Some).collect::<Vec<_>>())) + .map(|(_, embeddings)| { + Some( + embeddings + .first() + .vec + .into_iter() + .map(Some) + .collect::<Vec<_>>(), + ) + }) .collect::<Vec<_>>(), dims as i32, ); From dc332c34a1dfbe3846f1260bc63a2a6302461bc0 Mon Sep 17 00:00:00 2001 From: Garance <garance.mary@gmail.com> Date: Fri, 18 Oct 2024 12:07:45 -0400 Subject: [PATCH 58/91] fix: update embeddingsbuilder import path --- rig-core/examples/calculator_chatbot.rs | 2 +- rig-core/examples/rag.rs | 2 +- rig-core/examples/rag_dynamic_tools.rs | 2 +- rig-core/examples/vector_search.rs | 4 ++-- rig-core/examples/vector_search_cohere.rs | 2 +- rig-core/src/embeddings/builder.rs | 2 +- rig-mongodb/examples/vector_search_mongodb.rs | 2 +- 7 files changed, 8 insertions(+), 8 deletions(-) diff --git a/rig-core/examples/calculator_chatbot.rs b/rig-core/examples/calculator_chatbot.rs index 813c48eb..b8739dfe 100644 --- a/rig-core/examples/calculator_chatbot.rs +++ b/rig-core/examples/calculator_chatbot.rs @@ -2,7 +2,7 @@ use anyhow::Result; use rig::{ cli_chatbot::cli_chatbot, completion::ToolDefinition, - embeddings::builder::EmbeddingsBuilder, + embeddings::EmbeddingsBuilder, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, tool::{Tool, ToolEmbedding, ToolSet}, vector_store::in_memory_store::InMemoryVectorStore, diff --git a/rig-core/examples/rag.rs b/rig-core/examples/rag.rs index 43270a7b..ab1387a1 100644 --- a/rig-core/examples/rag.rs +++ b/rig-core/examples/rag.rs @@ -2,7 +2,7 @@ use std::{env, vec}; use rig::{ completion::Prompt, - embeddings::builder::EmbeddingsBuilder, + embeddings::EmbeddingsBuilder, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, vector_store::in_memory_store::InMemoryVectorStore, Embeddable, diff --git a/rig-core/examples/rag_dynamic_tools.rs b/rig-core/examples/rag_dynamic_tools.rs index cc04c209..77c3f8a4 100644 --- a/rig-core/examples/rag_dynamic_tools.rs +++ b/rig-core/examples/rag_dynamic_tools.rs @@ -1,7 +1,7 @@ use anyhow::Result; use rig::{ completion::{Prompt, ToolDefinition}, - embeddings::builder::EmbeddingsBuilder, + embeddings::EmbeddingsBuilder, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, tool::{Tool, ToolEmbedding, ToolSet}, vector_store::in_memory_store::InMemoryVectorStore, diff --git a/rig-core/examples/vector_search.rs b/rig-core/examples/vector_search.rs index a97ef8b0..f550a68d 100644 --- a/rig-core/examples/vector_search.rs +++ b/rig-core/examples/vector_search.rs @@ -1,7 +1,7 @@ use std::env; use rig::{ - embeddings::builder::EmbeddingsBuilder, + embeddings::EmbeddingsBuilder, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, vector_store::{in_memory_store::InMemoryVectorStore, 
VectorStoreIndex}, Embeddable, @@ -86,4 +86,4 @@ async fn main() -> Result<(), anyhow::Error> { println!("ID results: {:?}", id_results); Ok(()) -} +} \ No newline at end of file diff --git a/rig-core/examples/vector_search_cohere.rs b/rig-core/examples/vector_search_cohere.rs index 16ddb775..54adc598 100644 --- a/rig-core/examples/vector_search_cohere.rs +++ b/rig-core/examples/vector_search_cohere.rs @@ -1,7 +1,7 @@ use std::env; use rig::{ - embeddings::builder::EmbeddingsBuilder, + embeddings::EmbeddingsBuilder, providers::cohere::{Client, EMBED_ENGLISH_V3}, vector_store::{in_memory_store::InMemoryVectorStore, VectorStoreIndex}, Embeddable, diff --git a/rig-core/src/embeddings/builder.rs b/rig-core/src/embeddings/builder.rs index 451a0788..0fbd2fcc 100644 --- a/rig-core/src/embeddings/builder.rs +++ b/rig-core/src/embeddings/builder.rs @@ -6,7 +6,7 @@ //! use std::env; //! //! use rig::{ -//! embeddings::builder::EmbeddingsBuilder, +//! embeddings::EmbeddingsBuilder, //! providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, //! vector_store::{in_memory_store::InMemoryVectorStore, VectorStoreIndex}, //! Embeddable, diff --git a/rig-mongodb/examples/vector_search_mongodb.rs b/rig-mongodb/examples/vector_search_mongodb.rs index 4a369001..a5ec8517 100644 --- a/rig-mongodb/examples/vector_search_mongodb.rs +++ b/rig-mongodb/examples/vector_search_mongodb.rs @@ -5,7 +5,7 @@ use std::env; use rig::Embeddable; use rig::{ - embeddings::builder::EmbeddingsBuilder, providers::openai::Client, + embeddings::EmbeddingsBuilder, providers::openai::Client, vector_store::VectorStoreIndex, }; use rig_mongodb::{MongoDbVectorStore, SearchParams}; From db2ec987831fc857d78ed9483e92909a50e68b47 Mon Sep 17 00:00:00 2001 From: Garance <garance.mary@gmail.com> Date: Fri, 18 Oct 2024 13:29:48 -0400 Subject: [PATCH 59/91] tests: add tests for embeddingsbuilder --- rig-core/examples/calculator_chatbot.rs | 3 +- rig-core/examples/rag_dynamic_tools.rs | 3 +- rig-core/examples/vector_search.rs | 2 +- rig-core/src/embeddings/builder.rs | 216 +++++++++++++++++- rig-core/src/embeddings/tool.rs | 6 +- rig-core/src/one_or_many.rs | 19 ++ rig-mongodb/examples/vector_search_mongodb.rs | 3 +- 7 files changed, 240 insertions(+), 12 deletions(-) diff --git a/rig-core/examples/calculator_chatbot.rs b/rig-core/examples/calculator_chatbot.rs index b8739dfe..654f27a2 100644 --- a/rig-core/examples/calculator_chatbot.rs +++ b/rig-core/examples/calculator_chatbot.rs @@ -255,8 +255,7 @@ async fn main() -> Result<(), anyhow::Error> { .add_documents( embeddings .into_iter() - .enumerate() - .map(|(i, (tool, embedding))| (i.to_string(), tool, embedding)) + .map(|(tool, embedding)| (tool.name.clone(), tool, embedding)) .collect(), )? .index(embedding_model); diff --git a/rig-core/examples/rag_dynamic_tools.rs b/rig-core/examples/rag_dynamic_tools.rs index 77c3f8a4..bdad5109 100644 --- a/rig-core/examples/rag_dynamic_tools.rs +++ b/rig-core/examples/rag_dynamic_tools.rs @@ -164,8 +164,7 @@ async fn main() -> Result<(), anyhow::Error> { .add_documents( embeddings .into_iter() - .enumerate() - .map(|(i, (tool, embedding))| (i.to_string(), tool, embedding)) + .map(|(tool, embedding)| (tool.name.clone(), tool, embedding)) .collect(), )? 
.index(embedding_model); diff --git a/rig-core/examples/vector_search.rs b/rig-core/examples/vector_search.rs index f550a68d..5aebe12d 100644 --- a/rig-core/examples/vector_search.rs +++ b/rig-core/examples/vector_search.rs @@ -86,4 +86,4 @@ async fn main() -> Result<(), anyhow::Error> { println!("ID results: {:?}", id_results); Ok(()) -} \ No newline at end of file +} diff --git a/rig-core/src/embeddings/builder.rs b/rig-core/src/embeddings/builder.rs index 0fbd2fcc..bb7db98b 100644 --- a/rig-core/src/embeddings/builder.rs +++ b/rig-core/src/embeddings/builder.rs @@ -152,8 +152,8 @@ impl<M: EmbeddingModel, D: Embeddable + Send + Sync + Clone> EmbeddingsBuilder<M |mut acc: HashMap<_, OneOrMany<Embedding>>, embeddings| async move { embeddings.into_iter().for_each(|(i, embedding)| { acc.entry(i) - .or_insert(OneOrMany::one(embedding.clone())) - .add(embedding.clone()); + .and_modify(|embeddings| embeddings.add(embedding.clone())) + .or_insert(OneOrMany::one(embedding.clone())); }); Ok(acc) @@ -172,3 +172,215 @@ impl<M: EmbeddingModel, D: Embeddable + Send + Sync + Clone> EmbeddingsBuilder<M Ok(embeddings) } } + +#[cfg(test)] +mod tests { + use crate::{ + embeddings::{embeddable::EmbeddableError, Embedding, EmbeddingModel}, + Embeddable, + }; + + use super::EmbeddingsBuilder; + + #[derive(Clone)] + struct FakeModel; + + impl EmbeddingModel for FakeModel { + const MAX_DOCUMENTS: usize = 5; + + fn ndims(&self) -> usize { + 10 + } + + async fn embed_documents( + &self, + documents: Vec<String>, + ) -> Result<Vec<crate::embeddings::Embedding>, crate::embeddings::EmbeddingError> { + Ok(documents + .iter() + .map(|doc| Embedding { + document: doc.to_string(), + vec: vec![0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], + }) + .collect()) + } + } + + #[derive(Clone, Debug)] + struct FakeDefinition { + id: String, + definitions: Vec<String>, + } + + impl Embeddable for FakeDefinition { + type Error = EmbeddableError; + + fn embeddable(&self) -> Result<crate::OneOrMany<String>, Self::Error> { + crate::OneOrMany::many(self.definitions.clone()).map_err(EmbeddableError::new) + } + } + + fn fake_definitions() -> Vec<FakeDefinition> { + vec![ + FakeDefinition { + id: "doc0".to_string(), + definitions: vec![ + "A green alien that lives on cold planets.".to_string(), + "A fictional digital currency that originated in the animated series Rick and Morty.".to_string() + ] + }, + FakeDefinition { + id: "doc1".to_string(), + definitions: vec![ + "An ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string(), + "A fictional creature found in the distant, swampy marshlands of the planet Glibbo in the Andromeda galaxy.".to_string() + ] + } + ] + } + + fn fake_definitions_2() -> Vec<FakeDefinition> { + vec![ + FakeDefinition { + id: "doc2".to_string(), + definitions: vec!["Another fake definitions".to_string()], + }, + FakeDefinition { + id: "doc3".to_string(), + definitions: vec!["Some fake definition".to_string()], + }, + ] + } + + #[derive(Clone, Debug)] + struct FakeDefinitionSingle { + id: String, + definition: String, + } + + impl Embeddable for FakeDefinitionSingle { + type Error = EmbeddableError; + + fn embeddable(&self) -> Result<crate::OneOrMany<String>, Self::Error> { + Ok(crate::OneOrMany::one(self.definition.clone())) + } + } + + fn fake_definitions_single() -> Vec<FakeDefinitionSingle> { + vec![ + FakeDefinitionSingle { + id: "doc0".to_string(), + definition: "A green alien that lives on cold planets.".to_string(), + }, + FakeDefinitionSingle { + id: 
"doc1".to_string(), + definition: "An ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string(), + } + ] + } + + #[tokio::test] + async fn test_build_many() { + let fake_definitions = fake_definitions(); + + let fake_model = FakeModel; + let mut result = EmbeddingsBuilder::new(fake_model) + .documents(fake_definitions) + .unwrap() + .build() + .await + .unwrap(); + + result.sort_by(|(fake_definition_1, _), (fake_definition_2, _)| { + fake_definition_1.id.cmp(&fake_definition_2.id) + }); + + assert_eq!(result.len(), 2); + + let first_definition = &result[0]; + assert_eq!(first_definition.0.id, "doc0"); + assert_eq!(first_definition.1.len(), 2); + assert_eq!( + first_definition.1.first().document, + "A green alien that lives on cold planets.".to_string() + ); + + let second_definition = &result[1]; + assert_eq!(second_definition.0.id, "doc1"); + assert_eq!(second_definition.1.len(), 2); + assert_eq!( + second_definition.1.rest()[0].document, "A fictional creature found in the distant, swampy marshlands of the planet Glibbo in the Andromeda galaxy.".to_string() + ) + } + + #[tokio::test] + async fn test_build_single() { + let fake_definitions = fake_definitions_single(); + + let fake_model = FakeModel; + let mut result = EmbeddingsBuilder::new(fake_model) + .documents(fake_definitions) + .unwrap() + .build() + .await + .unwrap(); + + result.sort_by(|(fake_definition_1, _), (fake_definition_2, _)| { + fake_definition_1.id.cmp(&fake_definition_2.id) + }); + + assert_eq!(result.len(), 2); + + let first_definition = &result[0]; + assert_eq!(first_definition.0.id, "doc0"); + assert_eq!(first_definition.1.len(), 1); + assert_eq!( + first_definition.1.first().document, + "A green alien that lives on cold planets.".to_string() + ); + + let second_definition = &result[1]; + assert_eq!(second_definition.0.id, "doc1"); + assert_eq!(second_definition.1.len(), 1); + assert_eq!( + second_definition.1.first().document, "An ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string() + ) + } + + #[tokio::test] + async fn test_build_many_and_single() { + let fake_definitions = fake_definitions(); + let fake_definitions_single = fake_definitions_2(); + + let fake_model = FakeModel; + let mut result = EmbeddingsBuilder::new(fake_model) + .documents(fake_definitions) + .unwrap() + .documents(fake_definitions_single) + .unwrap() + .build() + .await + .unwrap(); + + result.sort_by(|(fake_definition_1, _), (fake_definition_2, _)| { + fake_definition_1.id.cmp(&fake_definition_2.id) + }); + + assert_eq!(result.len(), 4); + + let second_definition = &result[1]; + assert_eq!(second_definition.0.id, "doc1"); + assert_eq!(second_definition.1.len(), 2); + assert_eq!( + second_definition.1.first().document, "An ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string() + ); + + let third_definition = &result[2]; + assert_eq!(third_definition.0.id, "doc2"); + assert_eq!(third_definition.1.len(), 1); + assert_eq!( + third_definition.1.first().document, + "Another fake definitions".to_string() + ) + } +} diff --git a/rig-core/src/embeddings/tool.rs b/rig-core/src/embeddings/tool.rs index 1efe37c9..139b11b8 100644 --- a/rig-core/src/embeddings/tool.rs +++ b/rig-core/src/embeddings/tool.rs @@ -6,9 +6,9 @@ use super::embeddable::EmbeddableError; /// Used by EmbeddingsBuilder to embed anything that implements ToolEmbedding. 
#[derive(Clone, Serialize, Default, Eq, PartialEq)] pub struct EmbeddableTool { - name: String, - context: serde_json::Value, - embedding_docs: Vec<String>, + pub name: String, + pub context: serde_json::Value, + pub embedding_docs: Vec<String>, } impl Embeddable for EmbeddableTool { diff --git a/rig-core/src/one_or_many.rs b/rig-core/src/one_or_many.rs index 99adf20f..aaa33b97 100644 --- a/rig-core/src/one_or_many.rs +++ b/rig-core/src/one_or_many.rs @@ -32,6 +32,11 @@ impl<T: Clone> OneOrMany<T> { self.rest.push(item); } + /// Length of all items in `OneOrMany<T>`. + pub fn len(&self) -> usize { + 1 + self.rest.len() + } + /// Create a OneOrMany object with a single item of any type. pub fn one(item: T) -> Self { OneOrMany { @@ -274,4 +279,18 @@ mod test { fn test_one_or_many_error() { assert!(OneOrMany::<String>::many(vec![]).is_err()) } + + #[test] + fn test_len_single() { + let one_or_many = OneOrMany::one("hello".to_string()); + + assert_eq!(one_or_many.len(), 1); + } + + #[test] + fn test_len_many() { + let one_or_many = OneOrMany::many(vec!["hello".to_string(), "word".to_string()]).unwrap(); + + assert_eq!(one_or_many.len(), 2); + } } diff --git a/rig-mongodb/examples/vector_search_mongodb.rs b/rig-mongodb/examples/vector_search_mongodb.rs index a5ec8517..78535c78 100644 --- a/rig-mongodb/examples/vector_search_mongodb.rs +++ b/rig-mongodb/examples/vector_search_mongodb.rs @@ -5,8 +5,7 @@ use std::env; use rig::Embeddable; use rig::{ - embeddings::EmbeddingsBuilder, providers::openai::Client, - vector_store::VectorStoreIndex, + embeddings::EmbeddingsBuilder, providers::openai::Client, vector_store::VectorStoreIndex, }; use rig_mongodb::{MongoDbVectorStore, SearchParams}; From 3bb2231e707800b48db6efed9760cdebc101ce1c Mon Sep 17 00:00:00 2001 From: Garance <garance.mary@gmail.com> Date: Fri, 18 Oct 2024 13:39:19 -0400 Subject: [PATCH 60/91] clippy: add is empty method --- rig-core/src/one_or_many.rs | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/rig-core/src/one_or_many.rs b/rig-core/src/one_or_many.rs index aaa33b97..e9c3b007 100644 --- a/rig-core/src/one_or_many.rs +++ b/rig-core/src/one_or_many.rs @@ -37,6 +37,12 @@ impl<T: Clone> OneOrMany<T> { 1 + self.rest.len() } + /// If `OneOrMany<T>` is empty. This will always be false because you cannot create an empty `OneOrMany<T>`. + /// This methos is required when the method `len` exists. + pub fn is_empty(&self) -> bool { + false + } + /// Create a OneOrMany object with a single item of any type. 
pub fn one(item: T) -> Self { OneOrMany { From 3688b78e2f162d598381c53254164de73be4dac3 Mon Sep 17 00:00:00 2001 From: Garance <garance.mary@gmail.com> Date: Fri, 18 Oct 2024 15:22:54 -0400 Subject: [PATCH 61/91] fix: add feature flag to examples in mongodb and lancedb crates --- Cargo.toml | 3 +-- rig-lancedb/Cargo.toml | 18 +++++++++++++++++- .../examples/fixtures/{lib.rs => main.rs} | 2 +- .../examples/vector_search_local_ann.rs | 2 +- .../examples/vector_search_local_enn.rs | 2 +- rig-lancedb/examples/vector_search_s3_ann.rs | 2 +- rig-mongodb/Cargo.toml | 6 +++++- 7 files changed, 27 insertions(+), 8 deletions(-) rename rig-lancedb/examples/fixtures/{lib.rs => main.rs} (98%) diff --git a/Cargo.toml b/Cargo.toml index faafb5f4..82134804 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -3,5 +3,4 @@ resolver = "2" members = [ "rig-core", "rig-core/rig-core-derive", "rig-mongodb", - "rig-lancedb" -] + "rig-lancedb"] diff --git a/rig-lancedb/Cargo.toml b/rig-lancedb/Cargo.toml index 91877f1c..1020e36e 100644 --- a/rig-lancedb/Cargo.toml +++ b/rig-lancedb/Cargo.toml @@ -5,7 +5,7 @@ edition = "2021" [dependencies] lancedb = "0.10.0" -rig-core = { path = "../rig-core", version = "0.2.1", features = ["derive"] } +rig-core = { path = "../rig-core", version = "0.2.1" } arrow-array = "52.2.0" serde_json = "1.0.128" serde = "1.0.210" @@ -14,3 +14,19 @@ futures = "0.3.30" [dev-dependencies] tokio = "1.40.0" anyhow = "1.0.89" + +[[example]] +name = "fixtures" +required-features = ["rig-core/derive"] + +[[example]] +name = "vector_search_local_ann" +required-features = ["rig-core/derive"] + +[[example]] +name = "vector_search_local_enn" +required-features = ["rig-core/derive"] + +[[example]] +name = "vector_search_s3_ann" +required-features = ["rig-core/derive"] diff --git a/rig-lancedb/examples/fixtures/lib.rs b/rig-lancedb/examples/fixtures/main.rs similarity index 98% rename from rig-lancedb/examples/fixtures/lib.rs rename to rig-lancedb/examples/fixtures/main.rs index 780b42ca..d6e02a5a 100644 --- a/rig-lancedb/examples/fixtures/lib.rs +++ b/rig-lancedb/examples/fixtures/main.rs @@ -2,7 +2,7 @@ use std::sync::Arc; use arrow_array::{types::Float64Type, ArrayRef, FixedSizeListArray, RecordBatch, StringArray}; use lancedb::arrow::arrow_schema::{DataType, Field, Fields, Schema}; -use rig::embeddings::embedding::Embedding; +use rig::embeddings::Embedding; use rig::{Embeddable, OneOrMany}; use serde::Deserialize; diff --git a/rig-lancedb/examples/vector_search_local_ann.rs b/rig-lancedb/examples/vector_search_local_ann.rs index 1b7870fb..5518f0cb 100644 --- a/rig-lancedb/examples/vector_search_local_ann.rs +++ b/rig-lancedb/examples/vector_search_local_ann.rs @@ -10,7 +10,7 @@ use rig::{ }; use rig_lancedb::{LanceDbVectorStore, SearchParams}; -#[path = "./fixtures/lib.rs"] +#[path = "./fixtures/main.rs"] mod fixture; #[tokio::main] diff --git a/rig-lancedb/examples/vector_search_local_enn.rs b/rig-lancedb/examples/vector_search_local_enn.rs index 630acc1a..7d0bec9c 100644 --- a/rig-lancedb/examples/vector_search_local_enn.rs +++ b/rig-lancedb/examples/vector_search_local_enn.rs @@ -9,7 +9,7 @@ use rig::{ }; use rig_lancedb::{LanceDbVectorStore, SearchParams}; -#[path = "./fixtures/lib.rs"] +#[path = "./fixtures/main.rs"] mod fixture; #[tokio::main] diff --git a/rig-lancedb/examples/vector_search_s3_ann.rs b/rig-lancedb/examples/vector_search_s3_ann.rs index 8c10409b..b6bdcb3e 100644 --- a/rig-lancedb/examples/vector_search_s3_ann.rs +++ b/rig-lancedb/examples/vector_search_s3_ann.rs @@ -10,7 +10,7 @@ use 
rig::{ }; use rig_lancedb::{LanceDbVectorStore, SearchParams}; -#[path = "./fixtures/lib.rs"] +#[path = "./fixtures/main.rs"] mod fixture; // Note: see docs to deploy LanceDB on other cloud providers such as google and azure. diff --git a/rig-mongodb/Cargo.toml b/rig-mongodb/Cargo.toml index 8673bda8..66fef77a 100644 --- a/rig-mongodb/Cargo.toml +++ b/rig-mongodb/Cargo.toml @@ -12,7 +12,7 @@ repository = "https://github.com/0xPlaygrounds/rig" [dependencies] futures = "0.3.30" mongodb = "2.8.2" -rig-core = { path = "../rig-core", version = "0.2.1", features = ["derive"] } +rig-core = { path = "../rig-core", version = "0.2.1" } serde = { version = "1.0.203", features = ["derive"] } serde_json = "1.0.117" @@ -21,3 +21,7 @@ tracing = "0.1.40" [dev-dependencies] anyhow = "1.0.86" tokio = { version = "1.38.0", features = ["macros"] } + +[[example]] +name = "vector_search_mongodb" +required-features = ["rig-core/derive"] \ No newline at end of file From db8d188df933e787b2896a15554732600e64ffc9 Mon Sep 17 00:00:00 2001 From: Garance <garance.mary@gmail.com> Date: Fri, 18 Oct 2024 15:27:36 -0400 Subject: [PATCH 62/91] fix: move lancedb fixtures into it's own file --- rig-lancedb/examples/{fixtures/main.rs => fixtures.rs} | 0 rig-lancedb/examples/vector_search_local_ann.rs | 2 +- rig-lancedb/examples/vector_search_local_enn.rs | 2 +- rig-lancedb/examples/vector_search_s3_ann.rs | 2 +- 4 files changed, 3 insertions(+), 3 deletions(-) rename rig-lancedb/examples/{fixtures/main.rs => fixtures.rs} (100%) diff --git a/rig-lancedb/examples/fixtures/main.rs b/rig-lancedb/examples/fixtures.rs similarity index 100% rename from rig-lancedb/examples/fixtures/main.rs rename to rig-lancedb/examples/fixtures.rs diff --git a/rig-lancedb/examples/vector_search_local_ann.rs b/rig-lancedb/examples/vector_search_local_ann.rs index 5518f0cb..09294a67 100644 --- a/rig-lancedb/examples/vector_search_local_ann.rs +++ b/rig-lancedb/examples/vector_search_local_ann.rs @@ -10,7 +10,7 @@ use rig::{ }; use rig_lancedb::{LanceDbVectorStore, SearchParams}; -#[path = "./fixtures/main.rs"] +#[path = "./fixtures.rs"] mod fixture; #[tokio::main] diff --git a/rig-lancedb/examples/vector_search_local_enn.rs b/rig-lancedb/examples/vector_search_local_enn.rs index 7d0bec9c..f3e37d1b 100644 --- a/rig-lancedb/examples/vector_search_local_enn.rs +++ b/rig-lancedb/examples/vector_search_local_enn.rs @@ -9,7 +9,7 @@ use rig::{ }; use rig_lancedb::{LanceDbVectorStore, SearchParams}; -#[path = "./fixtures/main.rs"] +#[path = "./fixtures.rs"] mod fixture; #[tokio::main] diff --git a/rig-lancedb/examples/vector_search_s3_ann.rs b/rig-lancedb/examples/vector_search_s3_ann.rs index b6bdcb3e..cb4f8e0f 100644 --- a/rig-lancedb/examples/vector_search_s3_ann.rs +++ b/rig-lancedb/examples/vector_search_s3_ann.rs @@ -10,7 +10,7 @@ use rig::{ }; use rig_lancedb::{LanceDbVectorStore, SearchParams}; -#[path = "./fixtures/main.rs"] +#[path = "./fixtures.rs"] mod fixture; // Note: see docs to deploy LanceDB on other cloud providers such as google and azure. 
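For reference, a minimal sketch of the end-to-end flow that the reworked examples follow, assuming the rig-core API as it stands at this point in the series (the `Embeddable` derive with `#[embed]`, `EmbeddingsBuilder::documents`/`build` returning `(document, OneOrMany<Embedding>)` pairs, and the in-memory store) and assuming rig-core's `derive` feature is enabled as the Cargo.toml changes above require. `WordDefinition`, the query string, and the `InMemoryVectorStore::default()` construction are illustrative assumptions, not code from any patch:

use std::env;

use rig::{
    embeddings::EmbeddingsBuilder,
    providers::openai::{Client, TEXT_EMBEDDING_ADA_002},
    vector_store::{in_memory_store::InMemoryVectorStore, VectorStoreIndex},
    Embeddable,
};
use serde::{Deserialize, Serialize};

// Illustrative document type: only the #[embed] field is sent to the embedding model.
#[derive(Embeddable, Clone, Serialize, Deserialize, PartialEq, Eq, Debug)]
struct WordDefinition {
    id: String,
    #[embed]
    definition: String,
}

#[tokio::main]
async fn main() -> Result<(), anyhow::Error> {
    let openai_client = Client::new(&env::var("OPENAI_API_KEY")?);
    let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002);

    // build() yields one (document, OneOrMany<Embedding>) pair per input document.
    let embeddings = EmbeddingsBuilder::new(model.clone())
        .documents(vec![WordDefinition {
            id: "doc0".to_string(),
            definition: "Definition of *flumbrel (noun)*: a small item you constantly misplace.".to_string(),
        }])?
        .build()
        .await?;

    // Key each document by its id and keep all of its embeddings in the store
    // (assumes the store's Default impl, as used by the core examples).
    let index = InMemoryVectorStore::default()
        .add_documents(
            embeddings
                .into_iter()
                .map(|(doc, embeddings)| (doc.id.clone(), doc, embeddings))
                .collect(),
        )?
        .index(model);

    let results = index
        .top_n::<WordDefinition>("what do you call something you keep losing?", 1)
        .await?;
    println!("{:?}", results);

    Ok(())
}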
From 803b79211620b2f251f85a2c5a69cf2ea7f87f00 Mon Sep 17 00:00:00 2001 From: Garance <garance.mary@gmail.com> Date: Fri, 18 Oct 2024 15:29:55 -0400 Subject: [PATCH 63/91] fix: add dummy main function in fextures.rs for compiler --- rig-lancedb/examples/fixtures.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/rig-lancedb/examples/fixtures.rs b/rig-lancedb/examples/fixtures.rs index d6e02a5a..8ed72dba 100644 --- a/rig-lancedb/examples/fixtures.rs +++ b/rig-lancedb/examples/fixtures.rs @@ -13,6 +13,8 @@ pub struct FakeDefinition { pub definition: String, } +fn main() {} + pub fn fake_definitions() -> Vec<FakeDefinition> { vec![ FakeDefinition { From 97775d65b13e9453629960191a555068e010b1e7 Mon Sep 17 00:00:00 2001 From: Garance <garance.mary@gmail.com> Date: Fri, 18 Oct 2024 15:38:34 -0400 Subject: [PATCH 64/91] fix: revert fixture file, remove fixtures from cargo toml examples --- rig-lancedb/Cargo.toml | 4 ---- rig-lancedb/examples/{fixtures.rs => fixtures/lib.rs} | 2 -- rig-lancedb/examples/vector_search_local_ann.rs | 4 ++-- rig-lancedb/examples/vector_search_local_enn.rs | 4 ++-- rig-lancedb/examples/vector_search_s3_ann.rs | 4 ++-- 5 files changed, 6 insertions(+), 12 deletions(-) rename rig-lancedb/examples/{fixtures.rs => fixtures/lib.rs} (99%) diff --git a/rig-lancedb/Cargo.toml b/rig-lancedb/Cargo.toml index 1020e36e..b93644e9 100644 --- a/rig-lancedb/Cargo.toml +++ b/rig-lancedb/Cargo.toml @@ -15,10 +15,6 @@ futures = "0.3.30" tokio = "1.40.0" anyhow = "1.0.89" -[[example]] -name = "fixtures" -required-features = ["rig-core/derive"] - [[example]] name = "vector_search_local_ann" required-features = ["rig-core/derive"] diff --git a/rig-lancedb/examples/fixtures.rs b/rig-lancedb/examples/fixtures/lib.rs similarity index 99% rename from rig-lancedb/examples/fixtures.rs rename to rig-lancedb/examples/fixtures/lib.rs index 8ed72dba..d6e02a5a 100644 --- a/rig-lancedb/examples/fixtures.rs +++ b/rig-lancedb/examples/fixtures/lib.rs @@ -13,8 +13,6 @@ pub struct FakeDefinition { pub definition: String, } -fn main() {} - pub fn fake_definitions() -> Vec<FakeDefinition> { vec![ FakeDefinition { diff --git a/rig-lancedb/examples/vector_search_local_ann.rs b/rig-lancedb/examples/vector_search_local_ann.rs index 09294a67..23ad47b5 100644 --- a/rig-lancedb/examples/vector_search_local_ann.rs +++ b/rig-lancedb/examples/vector_search_local_ann.rs @@ -1,7 +1,7 @@ use std::{env, sync::Arc}; use arrow_array::RecordBatchIterator; -use fixture::{as_record_batch, fake_definitions, schema, FakeDefinition}; +use fixtures::{as_record_batch, fake_definitions, schema, FakeDefinition}; use lancedb::index::vector::IvfPqIndexBuilder; use rig::vector_store::VectorStoreIndex; use rig::{ @@ -10,7 +10,7 @@ use rig::{ }; use rig_lancedb::{LanceDbVectorStore, SearchParams}; -#[path = "./fixtures.rs"] +#[path = "./fixtures/lib.rs"] mod fixture; #[tokio::main] diff --git a/rig-lancedb/examples/vector_search_local_enn.rs b/rig-lancedb/examples/vector_search_local_enn.rs index f3e37d1b..dc492df4 100644 --- a/rig-lancedb/examples/vector_search_local_enn.rs +++ b/rig-lancedb/examples/vector_search_local_enn.rs @@ -1,7 +1,7 @@ use std::{env, sync::Arc}; use arrow_array::RecordBatchIterator; -use fixture::{as_record_batch, fake_definitions, schema}; +use fixtures::{as_record_batch, fake_definitions, schema}; use rig::{ embeddings::{builder::EmbeddingsBuilder, embedding::EmbeddingModel}, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, @@ -9,7 +9,7 @@ use rig::{ }; use rig_lancedb::{LanceDbVectorStore, SearchParams}; 
-#[path = "./fixtures.rs"] +#[path = "./fixtures/lib.rs"] mod fixture; #[tokio::main] diff --git a/rig-lancedb/examples/vector_search_s3_ann.rs b/rig-lancedb/examples/vector_search_s3_ann.rs index cb4f8e0f..a521153f 100644 --- a/rig-lancedb/examples/vector_search_s3_ann.rs +++ b/rig-lancedb/examples/vector_search_s3_ann.rs @@ -1,7 +1,7 @@ use std::{env, sync::Arc}; use arrow_array::RecordBatchIterator; -use fixture::{as_record_batch, fake_definitions, schema, FakeDefinition}; +use fixtures::{as_record_batch, fake_definitions, schema, FakeDefinition}; use lancedb::{index::vector::IvfPqIndexBuilder, DistanceType}; use rig::{ embeddings::{builder::EmbeddingsBuilder, embedding::EmbeddingModel}, @@ -10,7 +10,7 @@ use rig::{ }; use rig_lancedb::{LanceDbVectorStore, SearchParams}; -#[path = "./fixtures.rs"] +#[path = "./fixtures/lib.rs"] mod fixture; // Note: see docs to deploy LanceDB on other cloud providers such as google and azure. From 05ef716e4623b8f4d24c4781492113ead90ce4a9 Mon Sep 17 00:00:00 2001 From: Garance <garance.mary@gmail.com> Date: Fri, 18 Oct 2024 15:49:31 -0400 Subject: [PATCH 65/91] fix: update fixture import in lancedb examples --- rig-lancedb/examples/vector_search_local_ann.rs | 2 +- rig-lancedb/examples/vector_search_local_enn.rs | 2 +- rig-lancedb/examples/vector_search_s3_ann.rs | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/rig-lancedb/examples/vector_search_local_ann.rs b/rig-lancedb/examples/vector_search_local_ann.rs index 23ad47b5..1b7870fb 100644 --- a/rig-lancedb/examples/vector_search_local_ann.rs +++ b/rig-lancedb/examples/vector_search_local_ann.rs @@ -1,7 +1,7 @@ use std::{env, sync::Arc}; use arrow_array::RecordBatchIterator; -use fixtures::{as_record_batch, fake_definitions, schema, FakeDefinition}; +use fixture::{as_record_batch, fake_definitions, schema, FakeDefinition}; use lancedb::index::vector::IvfPqIndexBuilder; use rig::vector_store::VectorStoreIndex; use rig::{ diff --git a/rig-lancedb/examples/vector_search_local_enn.rs b/rig-lancedb/examples/vector_search_local_enn.rs index dc492df4..630acc1a 100644 --- a/rig-lancedb/examples/vector_search_local_enn.rs +++ b/rig-lancedb/examples/vector_search_local_enn.rs @@ -1,7 +1,7 @@ use std::{env, sync::Arc}; use arrow_array::RecordBatchIterator; -use fixtures::{as_record_batch, fake_definitions, schema}; +use fixture::{as_record_batch, fake_definitions, schema}; use rig::{ embeddings::{builder::EmbeddingsBuilder, embedding::EmbeddingModel}, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, diff --git a/rig-lancedb/examples/vector_search_s3_ann.rs b/rig-lancedb/examples/vector_search_s3_ann.rs index a521153f..8c10409b 100644 --- a/rig-lancedb/examples/vector_search_s3_ann.rs +++ b/rig-lancedb/examples/vector_search_s3_ann.rs @@ -1,7 +1,7 @@ use std::{env, sync::Arc}; use arrow_array::RecordBatchIterator; -use fixtures::{as_record_batch, fake_definitions, schema, FakeDefinition}; +use fixture::{as_record_batch, fake_definitions, schema, FakeDefinition}; use lancedb::{index::vector::IvfPqIndexBuilder, DistanceType}; use rig::{ embeddings::{builder::EmbeddingsBuilder, embedding::EmbeddingModel}, From d75e4bb46bb233c97761afdd51460d6c2947073f Mon Sep 17 00:00:00 2001 From: Garance <garance.mary@gmail.com> Date: Fri, 18 Oct 2024 16:09:03 -0400 Subject: [PATCH 66/91] refactor: rename D to T in embeddingsbuilder generics --- rig-core/src/embeddings/builder.rs | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/rig-core/src/embeddings/builder.rs 
b/rig-core/src/embeddings/builder.rs index bb7db98b..8acbdc85 100644 --- a/rig-core/src/embeddings/builder.rs +++ b/rig-core/src/embeddings/builder.rs @@ -73,12 +73,12 @@ use crate::{ }; /// Builder for creating a collection of embeddings. -pub struct EmbeddingsBuilder<M: EmbeddingModel, D: Embeddable> { +pub struct EmbeddingsBuilder<M: EmbeddingModel, T: Embeddable> { model: M, - documents: Vec<(D, OneOrMany<String>)>, + documents: Vec<(T, OneOrMany<String>)>, } -impl<M: EmbeddingModel, D: Embeddable> EmbeddingsBuilder<M, D> { +impl<M: EmbeddingModel, T: Embeddable> EmbeddingsBuilder<M, T> { /// Create a new embedding builder with the given embedding model pub fn new(model: M) -> Self { Self { @@ -88,7 +88,7 @@ impl<M: EmbeddingModel, D: Embeddable> EmbeddingsBuilder<M, D> { } /// Add a document that implements `Embeddable` to the builder. - pub fn document(mut self, document: D) -> Result<Self, D::Error> { + pub fn document(mut self, document: T) -> Result<Self, T::Error> { let embed_targets = document.embeddable()?; self.documents.push((document, embed_targets)); @@ -96,7 +96,7 @@ impl<M: EmbeddingModel, D: Embeddable> EmbeddingsBuilder<M, D> { } /// Add many documents that implement `Embeddable` to the builder. - pub fn documents(mut self, documents: Vec<D>) -> Result<Self, D::Error> { + pub fn documents(mut self, documents: Vec<T>) -> Result<Self, T::Error> { for doc in documents.into_iter() { let embed_targets = doc.embeddable()?; @@ -107,11 +107,11 @@ impl<M: EmbeddingModel, D: Embeddable> EmbeddingsBuilder<M, D> { } } -impl<M: EmbeddingModel, D: Embeddable + Send + Sync + Clone> EmbeddingsBuilder<M, D> { +impl<M: EmbeddingModel, T: Embeddable + Send + Sync + Clone> EmbeddingsBuilder<M, T> { /// Generate embeddings for all documents in the builder. /// The method only applies when documents in the builder each contain multiple embedding targets. /// Returns a vector of tuples, where the first element is the document and the second element is the vector of embeddings. - pub async fn build(&self) -> Result<Vec<(D, OneOrMany<Embedding>)>, EmbeddingError> { + pub async fn build(&self) -> Result<Vec<(T, OneOrMany<Embedding>)>, EmbeddingError> { // Use this for reference later to merge a document back with its embeddings. 
let documents_map = self .documents From 2865134820f484fa56b4abbb8a5e25c113ef8aad Mon Sep 17 00:00:00 2001 From: Garance <garance.mary@gmail.com> Date: Mon, 21 Oct 2024 09:37:45 -0400 Subject: [PATCH 67/91] refactor: remove clone --- rig-core/examples/calculator_chatbot.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/rig-core/examples/calculator_chatbot.rs b/rig-core/examples/calculator_chatbot.rs index 654f27a2..3f9f0b1b 100644 --- a/rig-core/examples/calculator_chatbot.rs +++ b/rig-core/examples/calculator_chatbot.rs @@ -25,7 +25,7 @@ struct MathError; #[error("Init error")] struct InitError; -#[derive(Deserialize, Serialize, Clone)] +#[derive(Deserialize, Serialize)] struct Add; impl Tool for Add { const NAME: &'static str = "add"; @@ -77,7 +77,7 @@ impl ToolEmbedding for Add { fn context(&self) -> Self::Context {} } -#[derive(Deserialize, Serialize, Clone)] +#[derive(Deserialize, Serialize)] struct Subtract; impl Tool for Subtract { const NAME: &'static str = "subtract"; From 55e240988321ea8ae30ad49591d665ab2ff6c0dc Mon Sep 17 00:00:00 2001 From: Garance <garance.mary@gmail.com> Date: Mon, 21 Oct 2024 13:28:21 -0400 Subject: [PATCH 68/91] PR: update builder, docstrings, and std::markers tags --- rig-core/src/embeddings/builder.rs | 125 +++++++++---------- rig-core/src/vector_store/in_memory_store.rs | 2 +- rig-core/src/vector_store/mod.rs | 2 +- rig-lancedb/src/lib.rs | 4 +- rig-mongodb/src/lib.rs | 4 +- 5 files changed, 66 insertions(+), 71 deletions(-) diff --git a/rig-core/src/embeddings/builder.rs b/rig-core/src/embeddings/builder.rs index 8acbdc85..288b6eb3 100644 --- a/rig-core/src/embeddings/builder.rs +++ b/rig-core/src/embeddings/builder.rs @@ -1,67 +1,5 @@ //! The module defines the [EmbeddingsBuilder] struct which accumulates objects to be embedded and generates the embeddings for each object when built. //! Only types that implement the [Embeddable] trait can be added to the [EmbeddingsBuilder]. -//! -//! # Example -//! ```rust -//! use std::env; -//! -//! use rig::{ -//! embeddings::EmbeddingsBuilder, -//! providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, -//! vector_store::{in_memory_store::InMemoryVectorStore, VectorStoreIndex}, -//! Embeddable, -//! }; -//! use serde::{Deserialize, Serialize}; -//! -//! // Shape of data that needs to be RAG'ed. -//! // The definition field will be used to generate embeddings. -//! #[derive(Embeddable, Clone, Deserialize, Debug, Serialize, Eq, PartialEq, Default)] -//! struct FakeDefinition { -//! id: String, -//! word: String, -//! #[embed] -//! definitions: Vec<String>, -//! } -//! -//! // Create OpenAI client -//! let openai_api_key = env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set"); -//! let openai_client = Client::new(&openai_api_key); -//! -//! let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); -//! -//! let embeddings = EmbeddingsBuilder::new(model.clone()) -//! .documents(vec![ -//! FakeDefinition { -//! id: "doc0".to_string(), -//! word: "flurbo".to_string(), -//! definitions: vec![ -//! "A green alien that lives on cold planets.".to_string(), -//! "A fictional digital currency that originated in the animated series Rick and Morty.".to_string() -//! ] -//! }, -//! FakeDefinition { -//! id: "doc1".to_string(), -//! word: "glarb-glarb".to_string(), -//! definitions: vec![ -//! "An ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string(), -//! 
"A fictional creature found in the distant, swampy marshlands of the planet Glibbo in the Andromeda galaxy.".to_string() -//! ] -//! }, -//! FakeDefinition { -//! id: "doc2".to_string(), -//! word: "linglingdong".to_string(), -//! definitions: vec![ -//! "A term used by inhabitants of the sombrero galaxy to describe humans.".to_string(), -//! "A rare, mystical instrument crafted by the ancient monks of the Nebulon Mountain Ranges on the planet Quarm.".to_string() -//! ] -//! }, -//! ])? -//! .build() -//! .await?; -//! -//! // Use the generated embeddings -//! // ... -//! ``` use std::{cmp::max, collections::HashMap}; @@ -107,11 +45,68 @@ impl<M: EmbeddingModel, T: Embeddable> EmbeddingsBuilder<M, T> { } } +/// # Example +/// ```rust +/// use std::env; +/// +/// use rig::{ +/// embeddings::EmbeddingsBuilder, +/// providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, +/// vector_store::{in_memory_store::InMemoryVectorStore, VectorStoreIndex}, +/// Embeddable, +/// }; +/// use serde::{Deserialize, Serialize}; +/// +/// // Shape of data that needs to be RAG'ed. +/// // The definition field will be used to generate embeddings. +/// #[derive(Embeddable, Clone, Deserialize, Debug, Serialize, Eq, PartialEq, Default)] +/// struct FakeDefinition { +/// id: String, +/// word: String, +/// #[embed] +/// definitions: Vec<String>, +/// } +/// +/// // Create OpenAI client +/// let openai_api_key = env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set"); +/// let openai_client = Client::new(&openai_api_key); +/// +/// let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); +/// +/// let embeddings = EmbeddingsBuilder::new(model.clone()) +/// .documents(vec![ +/// FakeDefinition { +/// id: "doc0".to_string(), +/// word: "flurbo".to_string(), +/// definitions: vec![ +/// "A green alien that lives on cold planets.".to_string(), +/// "A fictional digital currency that originated in the animated series Rick and Morty.".to_string() +/// ] +/// }, +/// FakeDefinition { +/// id: "doc1".to_string(), +/// word: "glarb-glarb".to_string(), +/// definitions: vec![ +/// "An ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string(), +/// "A fictional creature found in the distant, swampy marshlands of the planet Glibbo in the Andromeda galaxy.".to_string() +/// ] +/// }, +/// FakeDefinition { +/// id: "doc2".to_string(), +/// word: "linglingdong".to_string(), +/// definitions: vec![ +/// "A term used by inhabitants of the sombrero galaxy to describe humans.".to_string(), +/// "A rare, mystical instrument crafted by the ancient monks of the Nebulon Mountain Ranges on the planet Quarm.".to_string() +/// ] +/// }, +/// ])? +/// .build() +/// .await?; +/// ``` impl<M: EmbeddingModel, T: Embeddable + Send + Sync + Clone> EmbeddingsBuilder<M, T> { /// Generate embeddings for all documents in the builder. - /// The method only applies when documents in the builder each contain multiple embedding targets. - /// Returns a vector of tuples, where the first element is the document and the second element is the vector of embeddings. - pub async fn build(&self) -> Result<Vec<(T, OneOrMany<Embedding>)>, EmbeddingError> { + /// Returns a vector of tuples, where the first element is the document and the second element is the embeddings (either one embedding or many). + pub async fn build(self) -> Result<Vec<(T, OneOrMany<Embedding>)>, EmbeddingError> { // Use this for reference later to merge a document back with its embeddings. 
let documents_map = self .documents diff --git a/rig-core/src/vector_store/in_memory_store.rs b/rig-core/src/vector_store/in_memory_store.rs index 31a0ef7f..208b5f13 100644 --- a/rig-core/src/vector_store/in_memory_store.rs +++ b/rig-core/src/vector_store/in_memory_store.rs @@ -148,7 +148,7 @@ impl<M: EmbeddingModel, D: Serialize> InMemoryVectorIndex<M, D> { } } -impl<M: EmbeddingModel + std::marker::Sync, D: Serialize + Sync + Send + Eq> VectorStoreIndex +impl<M: EmbeddingModel + Sync, D: Serialize + Sync + Send + Eq> VectorStoreIndex for InMemoryVectorIndex<M, D> { async fn top_n<T: for<'a> Deserialize<'a>>( diff --git a/rig-core/src/vector_store/mod.rs b/rig-core/src/vector_store/mod.rs index 396b5514..38e45d0e 100644 --- a/rig-core/src/vector_store/mod.rs +++ b/rig-core/src/vector_store/mod.rs @@ -23,7 +23,7 @@ pub enum VectorStoreError { pub trait VectorStoreIndex: Send + Sync { /// Get the top n documents based on the distance to the given query. /// The result is a list of tuples of the form (score, id, document) - fn top_n<T: for<'a> Deserialize<'a> + std::marker::Send>( + fn top_n<T: for<'a> Deserialize<'a> + Send>( &self, query: &str, n: usize, diff --git a/rig-lancedb/src/lib.rs b/rig-lancedb/src/lib.rs index edcc51e5..bb4917e8 100644 --- a/rig-lancedb/src/lib.rs +++ b/rig-lancedb/src/lib.rs @@ -230,8 +230,8 @@ impl<M: EmbeddingModel> LanceDbVectorStore<M> { } } -impl<M: EmbeddingModel + std::marker::Sync + Send> VectorStoreIndex for LanceDbVectorStore<M> { - async fn top_n<T: for<'a> Deserialize<'a> + std::marker::Send>( +impl<M: EmbeddingModel + Sync + Send> VectorStoreIndex for LanceDbVectorStore<M> { + async fn top_n<T: for<'a> Deserialize<'a> + Send>( &self, query: &str, n: usize, diff --git a/rig-mongodb/src/lib.rs b/rig-mongodb/src/lib.rs index 4778e454..b2ff5d1b 100644 --- a/rig-mongodb/src/lib.rs +++ b/rig-mongodb/src/lib.rs @@ -141,10 +141,10 @@ impl SearchParams { } } -impl<M: EmbeddingModel + std::marker::Sync + Send, C: std::marker::Sync + Send> VectorStoreIndex +impl<M: EmbeddingModel + Sync + Send, C: Sync + Send> VectorStoreIndex for MongoDbVectorIndex<M, C> { - async fn top_n<T: for<'a> Deserialize<'a> + std::marker::Send>( + async fn top_n<T: for<'a> Deserialize<'a> + Send>( &self, query: &str, n: usize, From 0cbc5aa8eebb5bae2ffb3cac21cff09fec3caf4e Mon Sep 17 00:00:00 2001 From: Garance <garance.mary@gmail.com> Date: Mon, 21 Oct 2024 13:29:38 -0400 Subject: [PATCH 69/91] style: replace add with push --- rig-core/src/embeddings/builder.rs | 2 +- rig-core/src/one_or_many.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/rig-core/src/embeddings/builder.rs b/rig-core/src/embeddings/builder.rs index 288b6eb3..887dbf0e 100644 --- a/rig-core/src/embeddings/builder.rs +++ b/rig-core/src/embeddings/builder.rs @@ -147,7 +147,7 @@ impl<M: EmbeddingModel, T: Embeddable + Send + Sync + Clone> EmbeddingsBuilder<M |mut acc: HashMap<_, OneOrMany<Embedding>>, embeddings| async move { embeddings.into_iter().for_each(|(i, embedding)| { acc.entry(i) - .and_modify(|embeddings| embeddings.add(embedding.clone())) + .and_modify(|embeddings| embeddings.push(embedding.clone())) .or_insert(OneOrMany::one(embedding.clone())); }); diff --git a/rig-core/src/one_or_many.rs b/rig-core/src/one_or_many.rs index e9c3b007..3b92f860 100644 --- a/rig-core/src/one_or_many.rs +++ b/rig-core/src/one_or_many.rs @@ -28,7 +28,7 @@ impl<T: Clone> OneOrMany<T> { } /// After `OneOrMany<T>` is created, add an item of type T to the `rest`. 
- pub fn add(&mut self, item: T) { + pub fn push(&mut self, item: T) { self.rest.push(item); } From 1176e2f480d330af2ad6e102b08a6493c8e0d402 Mon Sep 17 00:00:00 2001 From: Garance <garance.mary@gmail.com> Date: Mon, 21 Oct 2024 13:49:30 -0400 Subject: [PATCH 70/91] fix: fix mongodb example --- rig-mongodb/examples/vector_search_mongodb.rs | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/rig-mongodb/examples/vector_search_mongodb.rs b/rig-mongodb/examples/vector_search_mongodb.rs index 78535c78..b095c060 100644 --- a/rig-mongodb/examples/vector_search_mongodb.rs +++ b/rig-mongodb/examples/vector_search_mongodb.rs @@ -13,6 +13,7 @@ use rig_mongodb::{MongoDbVectorStore, SearchParams}; // The definition field will be used to generate embeddings. #[derive(Embeddable, Clone, Deserialize, Debug)] struct FakeDefinition { + #[serde(rename = "_id")] id: String, #[embed] definition: String, @@ -93,11 +94,12 @@ async fn main() -> Result<(), anyhow::Error> { Err(e) => println!("Error adding documents: {:?}", e), }; - // Create a vector index on our vector store + // Create a vector index on our vector store. + // Note: a vector index called "vector_index" must exist on the MongoDB collection you are querying. // IMPORTANT: Reuse the same model that was used to generate the embeddings let index = MongoDbVectorStore::new(collection).index( model, - "definitions_vector_index", + "vector_index", SearchParams::new("embedding"), ); From f34a5dd28eba407170921a12563af68fdcec2b4c Mon Sep 17 00:00:00 2001 From: Garance <garance.mary@gmail.com> Date: Mon, 21 Oct 2024 15:05:27 -0400 Subject: [PATCH 71/91] fix: update lancedb and mongodb doc example --- rig-lancedb/src/lib.rs | 70 +++++++++++++----------- rig-mongodb/src/lib.rs | 121 ++++++++++++++++++++++++++++++++++++++++- 2 files changed, 157 insertions(+), 34 deletions(-) diff --git a/rig-lancedb/src/lib.rs b/rig-lancedb/src/lib.rs index bb4917e8..18d87d7f 100644 --- a/rig-lancedb/src/lib.rs +++ b/rig-lancedb/src/lib.rs @@ -23,68 +23,72 @@ fn serde_to_rig_error(e: serde_json::Error) -> VectorStoreError { /// # Example /// ``` /// use std::{env, sync::Arc}; -/// + /// use arrow_array::RecordBatchIterator; -/// use fixture::{as_record_batch, schema}; +/// use fixture::{as_record_batch, fake_definitions, schema, FakeDefinition}; +/// use lancedb::index::vector::IvfPqIndexBuilder; +/// use rig::vector_store::VectorStoreIndex; /// use rig::{ -/// embeddings::{EmbeddingModel, EmbeddingsBuilder}, +/// embeddings::{builder::EmbeddingsBuilder, embedding::EmbeddingModel}, /// providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, -/// vector_store::VectorStoreIndexDyn, /// }; /// use rig_lancedb::{LanceDbVectorStore, SearchParams}; -/// use serde::Deserialize; /// -/// #[derive(Deserialize, Debug)] -/// pub struct VectorSearchResult { -/// pub id: String, -/// pub content: String, -/// } +/// #[path = "../examples/fixtures/lib.rs"] +/// mod fixture; /// /// // Initialize OpenAI client. Use this to generate embeddings (and generate test data for RAG demo). /// let openai_api_key = env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set"); /// let openai_client = Client::new(&openai_api_key); /// -/// // Select the embedding model and generate our embeddings +/// // Select an embedding model. /// let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); /// +/// // Initialize LanceDB locally. +/// let db = lancedb::connect("data/lancedb-store").execute().await?; +/// +/// // Generate embeddings for the test data. 
/// let embeddings = EmbeddingsBuilder::new(model.clone()) -/// .simple_document("doc0", "Definition of *flumbrel (noun)*: a small, seemingly insignificant item that you constantly lose or misplace, such as a pen, hair tie, or remote control.") -/// .simple_document("doc1", "Definition of *zindle (verb)*: to pretend to be working on something important while actually doing something completely unrelated or unproductive") -/// .simple_document("doc2", "Definition of *glimber (adjective)*: describing a state of excitement mixed with nervousness, often experienced before an important event or decision.") +/// .documents(fake_definitions())? +/// // Note: need at least 256 rows in order to create an index so copy the definition 256 times for testing purposes. +/// .documents( +/// (0..256) +/// .map(|i| FakeDefinition { +/// id: format!("doc{}", i), +/// definition: "Definition of *flumbuzzle (noun)*: A sudden, inexplicable urge to rearrange or reorganize small objects, such as desk items or books, for no apparent reason.".to_string() +/// }) +/// .collect(), +/// )? /// .build() /// .await?; /// -/// // Define search_params params that will be used by the vector store to perform the vector search. -/// let search_params = SearchParams::default(); -/// -/// // Initialize LanceDB locally. -/// let db = lancedb::connect("data/lancedb-store").execute().await?; -/// /// // Create table with embeddings. /// let record_batch = as_record_batch(embeddings, model.ndims()); /// let table = db /// .create_table( -/// "definitions", -/// RecordBatchIterator::new(vec![record_batch], Arc::new(schema(model.ndims()))), +/// "definitions", +/// RecordBatchIterator::new(vec![record_batch], Arc::new(schema(model.ndims()))), +/// ) +/// .execute() +/// .await?; +/// +/// // See [LanceDB indexing](https://lancedb.github.io/lancedb/concepts/index_ivfpq/#product-quantization) for more information +/// table +/// .create_index( +/// &["embedding"], +/// lancedb::index::Index::IvfPq(IvfPqIndexBuilder::default()), /// ) /// .execute() /// .await?; /// +/// // Define search_params params that will be used by the vector store to perform the vector search. +/// let search_params = SearchParams::default(); /// let vector_store = LanceDbVectorStore::new(table, model, "id", search_params).await?; /// /// // Query the index /// let results = vector_store -/// .top_n("My boss says I zindle too much, what does that mean?", 1) -/// .await? -/// .into_iter() -/// .map(|(score, id, doc)| { -/// anyhow::Ok(( -/// score, -/// id, -/// serde_json::from_value::<VectorSearchResult>(doc)?, -/// )) -/// }) -/// .collect::<Result<Vec<_>, _>>()?; +/// .top_n::<FakeDefinition>("My boss says I zindle too much, what does that mean?", 1) +/// .await?; /// /// println!("Results: {:?}", results); /// ``` diff --git a/rig-mongodb/src/lib.rs b/rig-mongodb/src/lib.rs index b2ff5d1b..3201c648 100644 --- a/rig-mongodb/src/lib.rs +++ b/rig-mongodb/src/lib.rs @@ -7,7 +7,126 @@ use rig::{ }; use serde::Deserialize; -/// A MongoDB vector store. +/// # Example +/// ``` +/// use mongodb::{bson::doc, options::ClientOptions, Client as MongoClient, Collection}; +/// use rig::providers::openai::TEXT_EMBEDDING_ADA_002; +/// use serde::{Deserialize, Serialize}; +/// use std::env; + +/// use rig::Embeddable; +/// use rig::{ +/// embeddings::EmbeddingsBuilder, providers::openai::Client, vector_store::VectorStoreIndex, +/// }; +/// use rig_mongodb::{MongoDbVectorStore, SearchParams}; + +/// // Shape of data that needs to be RAG'ed. 
+/// // The definition field will be used to generate embeddings. +/// #[derive(Embeddable, Clone, Deserialize, Debug)] +/// struct FakeDefinition { +/// #[serde(rename = "_id")] +/// id: String, +/// #[embed] +/// definition: String, +/// } + +/// #[derive(Clone, Deserialize, Debug, Serialize)] +/// struct Link { +/// word: String, +/// link: String, +/// } + +/// // Shape of the document to be stored in MongoDB, with embeddings. +/// #[derive(Serialize, Debug)] +/// struct Document { +/// #[serde(rename = "_id")] +/// id: String, +/// definition: String, +/// embedding: Vec<f64>, +/// } +/// // Initialize OpenAI client +/// let openai_api_key = env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set"); +/// let openai_client = Client::new(&openai_api_key); + +/// // Initialize MongoDB client +/// let mongodb_connection_string = +/// env::var("MONGODB_CONNECTION_STRING").expect("MONGODB_CONNECTION_STRING not set"); +/// let options = ClientOptions::parse(mongodb_connection_string) +/// .await +/// .expect("MongoDB connection string should be valid"); + +/// let mongodb_client = +/// MongoClient::with_options(options).expect("MongoDB client options should be valid"); + +/// // Initialize MongoDB vector store +/// let collection: Collection<Document> = mongodb_client +/// .database("knowledgebase") +/// .collection("context"); + +/// // Select the embedding model and generate our embeddings +/// let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); + +/// let fake_definitions = vec![ +/// FakeDefinition { +/// id: "doc0".to_string(), +/// definition: "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets".to_string(), +/// }, +/// FakeDefinition { +/// id: "doc1".to_string(), +/// definition: "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string(), +/// }, +/// FakeDefinition { +/// id: "doc2".to_string(), +/// definition: "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.".to_string(), +/// } +/// ]; + +/// let embeddings = EmbeddingsBuilder::new(model.clone()) +/// .documents(fake_definitions)? +/// .build() +/// .await?; + +/// let mongo_documents = embeddings +/// .iter() +/// .map( +/// |(FakeDefinition { id, definition, .. }, embedding)| Document { +/// id: id.clone(), +/// definition: definition.clone(), +/// embedding: embedding.first().vec.clone(), +/// }, +/// ) +/// .collect::<Vec<_>>(); + +/// match collection.insert_many(mongo_documents, None).await { +/// Ok(_) => println!("Documents added successfully"), +/// Err(e) => println!("Error adding documents: {:?}", e), +/// }; + +/// // Create a vector index on our vector store. +/// // Note: a vector index called "vector_index" must exist on the MongoDB collection you are querying. +/// // IMPORTANT: Reuse the same model that was used to generate the embeddings +/// let index = MongoDbVectorStore::new(collection).index( +/// model, +/// "vector_index", +/// SearchParams::new("embedding"), +/// ); + +/// // Query the index +/// let results = index +/// .top_n::<FakeDefinition>("What is a linglingdong?", 1) +/// .await?; + +/// println!("Results: {:?}", results); + +/// let id_results = index +/// .top_n_ids("What is a linglingdong?", 1) +/// .await? 
+/// .into_iter() +/// .map(|(score, id)| (score, id)) +/// .collect::<Vec<_>>(); + +/// println!("ID results: {:?}", id_results); +/// ``` pub struct MongoDbVectorStore<C> { collection: mongodb::Collection<C>, } From f796e1265df3f6bac8da3206b2c6ec666ea37a65 Mon Sep 17 00:00:00 2001 From: Garance <garance.mary@gmail.com> Date: Mon, 21 Oct 2024 15:54:38 -0400 Subject: [PATCH 72/91] fix: typo --- rig-core/src/one_or_many.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rig-core/src/one_or_many.rs b/rig-core/src/one_or_many.rs index 3b92f860..64584603 100644 --- a/rig-core/src/one_or_many.rs +++ b/rig-core/src/one_or_many.rs @@ -38,7 +38,7 @@ impl<T: Clone> OneOrMany<T> { } /// If `OneOrMany<T>` is empty. This will always be false because you cannot create an empty `OneOrMany<T>`. - /// This methos is required when the method `len` exists. + /// This method is required when the method `len` exists. pub fn is_empty(&self) -> bool { false } From 223139e8325a3ecca6d04cecc42893def6beb7db Mon Sep 17 00:00:00 2001 From: Garance <garance.mary@gmail.com> Date: Tue, 22 Oct 2024 15:41:53 -0400 Subject: [PATCH 73/91] docs: add and fix docstrings and examples --- rig-core/examples/calculator_chatbot.rs | 7 +- rig-core/examples/rag.rs | 15 +- rig-core/examples/rag_dynamic_tools.rs | 7 +- rig-core/examples/vector_search.rs | 9 +- rig-core/examples/vector_search_cohere.rs | 9 +- rig-core/src/lib.rs | 2 +- rig-core/src/vector_store/in_memory_store.rs | 27 +++ rig-core/src/vector_store/mod.rs | 3 + .../examples/vector_search_local_ann.rs | 17 +- .../examples/vector_search_local_enn.rs | 13 +- rig-lancedb/examples/vector_search_s3_ann.rs | 13 +- rig-lancedb/src/lib.rs | 147 ++++++-------- rig-mongodb/src/lib.rs | 190 +++++++----------- 13 files changed, 196 insertions(+), 263 deletions(-) diff --git a/rig-core/examples/calculator_chatbot.rs b/rig-core/examples/calculator_chatbot.rs index 3f9f0b1b..723bfada 100644 --- a/rig-core/examples/calculator_chatbot.rs +++ b/rig-core/examples/calculator_chatbot.rs @@ -252,12 +252,7 @@ async fn main() -> Result<(), anyhow::Error> { .await?; let index = InMemoryVectorStore::default() - .add_documents( - embeddings - .into_iter() - .map(|(tool, embedding)| (tool.name.clone(), tool, embedding)) - .collect(), - )? + .add_documents_with_id(embeddings, "name")? .index(embedding_model); // Create RAG agent with a single context prompt and a dynamic tool source diff --git a/rig-core/examples/rag.rs b/rig-core/examples/rag.rs index ab1387a1..d03902d1 100644 --- a/rig-core/examples/rag.rs +++ b/rig-core/examples/rag.rs @@ -10,8 +10,9 @@ use rig::{ use serde::Serialize; // Shape of data that needs to be RAG'ed. -// The definition field will be used to generate embeddings. -#[derive(Embeddable, Clone, Debug, Serialize, Eq, PartialEq, Default)] +// A vector search needs to be performed on the definitions, so we derive the `Embeddable` trait for `FakeDefinition` +// and tag that field with `#[embed]`. +#[derive(Embeddable, Serialize, Clone, Debug, Eq, PartialEq, Default)] struct FakeDefinition { id: String, #[embed] @@ -26,6 +27,7 @@ async fn main() -> Result<(), anyhow::Error> { let embedding_model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); + // Generate embeddings for the definitions of all the documents using the specified embedding model. 
let embeddings = EmbeddingsBuilder::new(embedding_model.clone()) .documents(vec![ FakeDefinition { @@ -54,14 +56,7 @@ async fn main() -> Result<(), anyhow::Error> { .await?; let index = InMemoryVectorStore::default() - .add_documents( - embeddings - .into_iter() - .map(|(fake_definition, embedding_vec)| { - (fake_definition.id.clone(), fake_definition, embedding_vec) - }) - .collect(), - )? + .add_documents_with_id(embeddings, "id")? .index(embedding_model); let rag_agent = openai_client.agent("gpt-4") diff --git a/rig-core/examples/rag_dynamic_tools.rs b/rig-core/examples/rag_dynamic_tools.rs index bdad5109..c3a2c251 100644 --- a/rig-core/examples/rag_dynamic_tools.rs +++ b/rig-core/examples/rag_dynamic_tools.rs @@ -161,12 +161,7 @@ async fn main() -> Result<(), anyhow::Error> { .await?; let index = InMemoryVectorStore::default() - .add_documents( - embeddings - .into_iter() - .map(|(tool, embedding)| (tool.name.clone(), tool, embedding)) - .collect(), - )? + .add_documents_with_id(embeddings, "name")? .index(embedding_model); // Create RAG agent with a single context prompt and a dynamic tool source diff --git a/rig-core/examples/vector_search.rs b/rig-core/examples/vector_search.rs index 5aebe12d..b40f271c 100644 --- a/rig-core/examples/vector_search.rs +++ b/rig-core/examples/vector_search.rs @@ -57,14 +57,7 @@ async fn main() -> Result<(), anyhow::Error> { .await?; let index = InMemoryVectorStore::default() - .add_documents( - embeddings - .into_iter() - .map(|(fake_definition, embedding_vec)| { - (fake_definition.id.clone(), fake_definition, embedding_vec) - }) - .collect(), - )? + .add_documents_with_id(embeddings, "id")? .index(model); let results = index diff --git a/rig-core/examples/vector_search_cohere.rs b/rig-core/examples/vector_search_cohere.rs index 54adc598..da1e474b 100644 --- a/rig-core/examples/vector_search_cohere.rs +++ b/rig-core/examples/vector_search_cohere.rs @@ -58,14 +58,7 @@ async fn main() -> Result<(), anyhow::Error> { .await?; let index = InMemoryVectorStore::default() - .add_documents( - embeddings - .into_iter() - .map(|(fake_definition, embedding_vec)| { - (fake_definition.id.clone(), fake_definition, embedding_vec) - }) - .collect(), - )? + .add_documents_with_id(embeddings, "id")? .index(search_model); let results = index diff --git a/rig-core/src/lib.rs b/rig-core/src/lib.rs index 07c59f96..6b337073 100644 --- a/rig-core/src/lib.rs +++ b/rig-core/src/lib.rs @@ -78,7 +78,7 @@ pub mod tool; pub mod vector_store; // Re-export commonly used types and traits -pub use embeddings::embeddable::Embeddable; +pub use embeddings::Embeddable; pub use one_or_many::OneOrMany; #[cfg(feature = "derive")] diff --git a/rig-core/src/vector_store/in_memory_store.rs b/rig-core/src/vector_store/in_memory_store.rs index 208b5f13..5ab9d6e7 100644 --- a/rig-core/src/vector_store/in_memory_store.rs +++ b/rig-core/src/vector_store/in_memory_store.rs @@ -76,6 +76,33 @@ impl<D: Serialize + Eq> InMemoryVectorStore<D> { Ok(self) } + /// Add documents to the store. Define the name of the field in the document that contains the id. + /// Returns the store with the added documents. + pub fn add_documents_with_id( + mut self, + documents: Vec<(D, OneOrMany<Embedding>)>, + id_field: &str, + ) -> Result<Self, VectorStoreError> { + for (doc, embeddings) in documents { + if let serde_json::Value::Object(o) = + serde_json::to_value(&doc).map_err(VectorStoreError::JsonError)? 
+ { + match o.get(id_field) { + Some(serde_json::Value::String(s)) => { + self.embeddings.insert(s.clone(), (doc, embeddings)); + } + _ => { + return Err(VectorStoreError::MissingIdError(format!( + "Document does not have a field {id_field}" + ))); + } + } + }; + } + + Ok(self) + } + /// Get the document by its id and deserialize it into the given type. pub fn get_document<T: for<'a> Deserialize<'a>>( &self, diff --git a/rig-core/src/vector_store/mod.rs b/rig-core/src/vector_store/mod.rs index 38e45d0e..044d8c2a 100644 --- a/rig-core/src/vector_store/mod.rs +++ b/rig-core/src/vector_store/mod.rs @@ -17,6 +17,9 @@ pub enum VectorStoreError { #[error("Datastore error: {0}")] DatastoreError(#[from] Box<dyn std::error::Error + Send + Sync>), + + #[error("Missing Id: {0}")] + MissingIdError(String), } /// Trait for vector store indexes diff --git a/rig-lancedb/examples/vector_search_local_ann.rs b/rig-lancedb/examples/vector_search_local_ann.rs index 1b7870fb..7ffd6b12 100644 --- a/rig-lancedb/examples/vector_search_local_ann.rs +++ b/rig-lancedb/examples/vector_search_local_ann.rs @@ -3,12 +3,12 @@ use std::{env, sync::Arc}; use arrow_array::RecordBatchIterator; use fixture::{as_record_batch, fake_definitions, schema, FakeDefinition}; use lancedb::index::vector::IvfPqIndexBuilder; -use rig::vector_store::VectorStoreIndex; use rig::{ - embeddings::{builder::EmbeddingsBuilder, embedding::EmbeddingModel}, + embeddings::{EmbeddingModel, EmbeddingsBuilder}, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, + vector_store::VectorStoreIndex, }; -use rig_lancedb::{LanceDbVectorStore, SearchParams}; +use rig_lancedb::{LanceDbVectorIndex, SearchParams}; #[path = "./fixtures/lib.rs"] mod fixture; @@ -40,12 +40,13 @@ async fn main() -> Result<(), anyhow::Error> { .build() .await?; - // Create table with embeddings. - let record_batch = as_record_batch(embeddings, model.ndims()); let table = db .create_table( "definitions", - RecordBatchIterator::new(vec![record_batch], Arc::new(schema(model.ndims()))), + RecordBatchIterator::new( + vec![as_record_batch(embeddings, model.ndims())], + Arc::new(schema(model.ndims())), + ), ) .execute() .await?; @@ -61,10 +62,10 @@ async fn main() -> Result<(), anyhow::Error> { // Define search_params params that will be used by the vector store to perform the vector search. 
let search_params = SearchParams::default(); - let vector_store = LanceDbVectorStore::new(table, model, "id", search_params).await?; + let vector_store_index = LanceDbVectorIndex::new(table, model, "id", search_params).await?; // Query the index - let results = vector_store + let results = vector_store_index .top_n::<FakeDefinition>("My boss says I zindle too much, what does that mean?", 1) .await?; diff --git a/rig-lancedb/examples/vector_search_local_enn.rs b/rig-lancedb/examples/vector_search_local_enn.rs index 630acc1a..859442be 100644 --- a/rig-lancedb/examples/vector_search_local_enn.rs +++ b/rig-lancedb/examples/vector_search_local_enn.rs @@ -3,11 +3,11 @@ use std::{env, sync::Arc}; use arrow_array::RecordBatchIterator; use fixture::{as_record_batch, fake_definitions, schema}; use rig::{ - embeddings::{builder::EmbeddingsBuilder, embedding::EmbeddingModel}, + embeddings::{EmbeddingModel, EmbeddingsBuilder}, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, vector_store::VectorStoreIndexDyn, }; -use rig_lancedb::{LanceDbVectorStore, SearchParams}; +use rig_lancedb::{LanceDbVectorIndex, SearchParams}; #[path = "./fixtures/lib.rs"] mod fixture; @@ -33,17 +33,18 @@ async fn main() -> Result<(), anyhow::Error> { // Initialize LanceDB locally. let db = lancedb::connect("data/lancedb-store").execute().await?; - // Create table with embeddings. - let record_batch = as_record_batch(embeddings, model.ndims()); let table = db .create_table( "definitions", - RecordBatchIterator::new(vec![record_batch], Arc::new(schema(model.ndims()))), + RecordBatchIterator::new( + vec![as_record_batch(embeddings, model.ndims())], + Arc::new(schema(model.ndims())), + ), ) .execute() .await?; - let vector_store = LanceDbVectorStore::new(table, model, "id", search_params).await?; + let vector_store = LanceDbVectorIndex::new(table, model, "id", search_params).await?; // Query the index let results = vector_store diff --git a/rig-lancedb/examples/vector_search_s3_ann.rs b/rig-lancedb/examples/vector_search_s3_ann.rs index 8c10409b..824deda0 100644 --- a/rig-lancedb/examples/vector_search_s3_ann.rs +++ b/rig-lancedb/examples/vector_search_s3_ann.rs @@ -4,11 +4,11 @@ use arrow_array::RecordBatchIterator; use fixture::{as_record_batch, fake_definitions, schema, FakeDefinition}; use lancedb::{index::vector::IvfPqIndexBuilder, DistanceType}; use rig::{ - embeddings::{builder::EmbeddingsBuilder, embedding::EmbeddingModel}, + embeddings::{EmbeddingModel, EmbeddingsBuilder}, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, vector_store::VectorStoreIndex, }; -use rig_lancedb::{LanceDbVectorStore, SearchParams}; +use rig_lancedb::{LanceDbVectorIndex, SearchParams}; #[path = "./fixtures/lib.rs"] mod fixture; @@ -46,12 +46,13 @@ async fn main() -> Result<(), anyhow::Error> { .build() .await?; - // Create table with embeddings. - let record_batch = as_record_batch(embeddings, model.ndims()); let table = db .create_table( "definitions", - RecordBatchIterator::new(vec![record_batch], Arc::new(schema(model.ndims()))), + RecordBatchIterator::new( + vec![as_record_batch(embeddings, model.ndims())], + Arc::new(schema(model.ndims())), + ), ) .execute() .await?; @@ -73,7 +74,7 @@ async fn main() -> Result<(), anyhow::Error> { // Define search_params params that will be used by the vector store to perform the vector search. 
let search_params = SearchParams::default().distance_type(DistanceType::Cosine); - let vector_store = LanceDbVectorStore::new(table, model, "id", search_params).await?; + let vector_store = LanceDbVectorIndex::new(table, model, "id", search_params).await?; // Query the index let results = vector_store diff --git a/rig-lancedb/src/lib.rs b/rig-lancedb/src/lib.rs index 18d87d7f..eaaffbe3 100644 --- a/rig-lancedb/src/lib.rs +++ b/rig-lancedb/src/lib.rs @@ -20,79 +20,17 @@ fn serde_to_rig_error(e: serde_json::Error) -> VectorStoreError { VectorStoreError::JsonError(e) } +/// Type on which vector searches can be performed for a lanceDb table. /// # Example /// ``` -/// use std::{env, sync::Arc}; - -/// use arrow_array::RecordBatchIterator; -/// use fixture::{as_record_batch, fake_definitions, schema, FakeDefinition}; -/// use lancedb::index::vector::IvfPqIndexBuilder; -/// use rig::vector_store::VectorStoreIndex; -/// use rig::{ -/// embeddings::{builder::EmbeddingsBuilder, embedding::EmbeddingModel}, -/// providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, -/// }; -/// use rig_lancedb::{LanceDbVectorStore, SearchParams}; -/// -/// #[path = "../examples/fixtures/lib.rs"] -/// mod fixture; -/// -/// // Initialize OpenAI client. Use this to generate embeddings (and generate test data for RAG demo). -/// let openai_api_key = env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set"); -/// let openai_client = Client::new(&openai_api_key); -/// -/// // Select an embedding model. -/// let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); -/// -/// // Initialize LanceDB locally. -/// let db = lancedb::connect("data/lancedb-store").execute().await?; -/// -/// // Generate embeddings for the test data. -/// let embeddings = EmbeddingsBuilder::new(model.clone()) -/// .documents(fake_definitions())? -/// // Note: need at least 256 rows in order to create an index so copy the definition 256 times for testing purposes. -/// .documents( -/// (0..256) -/// .map(|i| FakeDefinition { -/// id: format!("doc{}", i), -/// definition: "Definition of *flumbuzzle (noun)*: A sudden, inexplicable urge to rearrange or reorganize small objects, such as desk items or books, for no apparent reason.".to_string() -/// }) -/// .collect(), -/// )? -/// .build() -/// .await?; +/// use rig_lancedb::{LanceDbVectorIndex, SearchParams}; +/// use rig::embeddings::EmbeddingModel; /// -/// // Create table with embeddings. -/// let record_batch = as_record_batch(embeddings, model.ndims()); -/// let table = db -/// .create_table( -/// "definitions", -/// RecordBatchIterator::new(vec![record_batch], Arc::new(schema(model.ndims()))), -/// ) -/// .execute() -/// .await?; -/// -/// // See [LanceDB indexing](https://lancedb.github.io/lancedb/concepts/index_ivfpq/#product-quantization) for more information -/// table -/// .create_index( -/// &["embedding"], -/// lancedb::index::Index::IvfPq(IvfPqIndexBuilder::default()), -/// ) -/// .execute() -/// .await?; -/// -/// // Define search_params params that will be used by the vector store to perform the vector search. 
-/// let search_params = SearchParams::default(); -/// let vector_store = LanceDbVectorStore::new(table, model, "id", search_params).await?; -/// -/// // Query the index -/// let results = vector_store -/// .top_n::<FakeDefinition>("My boss says I zindle too much, what does that mean?", 1) -/// .await?; -/// -/// println!("Results: {:?}", results); +/// fn create_index(table: lancedb::Table, model: EmbeddingModel) { +/// let vector_store_index = LanceDbVectorIndex::new(table, model, "id", SearchParams::default()).await?; +/// } /// ``` -pub struct LanceDbVectorStore<M: EmbeddingModel> { +pub struct LanceDbVectorIndex<M: EmbeddingModel> { /// Defines which model is used to generate embeddings for the vector store. model: M, /// LanceDB table containing embeddings. @@ -103,7 +41,24 @@ pub struct LanceDbVectorStore<M: EmbeddingModel> { search_params: SearchParams, } -impl<M: EmbeddingModel> LanceDbVectorStore<M> { +impl<M: EmbeddingModel> LanceDbVectorIndex<M> { + /// Create an instance of `LanceDbVectorIndex` with an existing table and model. + /// Define the id field name of the table. + /// Define search parameters that will be used to perform vector searches on the table. + pub async fn new( + table: lancedb::Table, + model: M, + id_field: &str, + search_params: SearchParams, + ) -> Result<Self, lancedb::Error> { + Ok(Self { + table, + model, + id_field: id_field.to_string(), + search_params, + }) + } + /// Apply the search_params to the vector query. /// This is a helper function used by the methods `top_n` and `top_n_ids` of the `VectorStoreIndex` trait. fn build_query(&self, mut query: VectorQuery) -> VectorQuery { @@ -155,6 +110,10 @@ pub enum SearchType { } /// Parameters used to perform a vector search on a LanceDb table. +/// # Example +/// ``` +/// let search_params = SearchParams::default().distance_type(DistanceType::Cosine); +/// ``` #[derive(Debug, Clone, Default)] pub struct SearchParams { distance_type: Option<DistanceType>, @@ -215,26 +174,22 @@ impl SearchParams { } } -impl<M: EmbeddingModel> LanceDbVectorStore<M> { - /// Create an instance of `LanceDbVectorStore` with an existing table and model. - /// Define the id field name of the table. - /// Define search parameters that will be used to perform vector searches on the table. - pub async fn new( - table: lancedb::Table, - model: M, - id_field: &str, - search_params: SearchParams, - ) -> Result<Self, lancedb::Error> { - Ok(Self { - table, - model, - id_field: id_field.to_string(), - search_params, - }) - } -} - -impl<M: EmbeddingModel + Sync + Send> VectorStoreIndex for LanceDbVectorStore<M> { +impl<M: EmbeddingModel + Sync + Send> VectorStoreIndex for LanceDbVectorIndex<M> { + /// Implement the `top_n` method of the `VectorStoreIndex` trait for `LanceDbVectorIndex`. 
+ /// # Example + /// ``` + /// use rig_lancedb::{LanceDbVectorIndex, SearchParams}; + /// use rig::embeddings::EmbeddingModel; + /// + /// fn execute_search(table: lancedb::Table, model: EmbeddingModel) { + /// let vector_store_index = LanceDbVectorIndex::new(table, model, "id", SearchParams::default()).await?; + /// + /// // Query the index + /// let result = vector_store_index + /// .top_n::<String>("My boss says I zindle too much, what does that mean?", 1) + /// .await?; + /// } + /// ``` async fn top_n<T: for<'a> Deserialize<'a> + Send>( &self, query: &str, @@ -269,6 +224,18 @@ impl<M: EmbeddingModel + Sync + Send> VectorStoreIndex for LanceDbVectorStore<M> .collect() } + /// Implement the `top_n_ids` method of the `VectorStoreIndex` trait for `LanceDbVectorIndex`. + /// # Example + /// ``` + /// fn execute_search(table: lancedb::Table, model: EmbeddingModel) { + /// let vector_store_index = LanceDbVectorIndex::new(table, model, "id", SearchParams::default()).await?; + /// + /// // Query the index + /// let result = vector_store_index + /// .top_n_ids("My boss says I zindle too much, what does that mean?", 1) + /// .await?; + /// } + /// ``` async fn top_n_ids( &self, query: &str, diff --git a/rig-mongodb/src/lib.rs b/rig-mongodb/src/lib.rs index 3201c648..a803d385 100644 --- a/rig-mongodb/src/lib.rs +++ b/rig-mongodb/src/lib.rs @@ -7,134 +7,35 @@ use rig::{ }; use serde::Deserialize; +fn mongodb_to_rig_error(e: mongodb::error::Error) -> VectorStoreError { + VectorStoreError::DatastoreError(Box::new(e)) +} + /// # Example /// ``` -/// use mongodb::{bson::doc, options::ClientOptions, Client as MongoClient, Collection}; -/// use rig::providers::openai::TEXT_EMBEDDING_ADA_002; -/// use serde::{Deserialize, Serialize}; -/// use std::env; - -/// use rig::Embeddable; -/// use rig::{ -/// embeddings::EmbeddingsBuilder, providers::openai::Client, vector_store::VectorStoreIndex, -/// }; /// use rig_mongodb::{MongoDbVectorStore, SearchParams}; - -/// // Shape of data that needs to be RAG'ed. -/// // The definition field will be used to generate embeddings. -/// #[derive(Embeddable, Clone, Deserialize, Debug)] -/// struct FakeDefinition { -/// #[serde(rename = "_id")] -/// id: String, -/// #[embed] -/// definition: String, -/// } - -/// #[derive(Clone, Deserialize, Debug, Serialize)] -/// struct Link { -/// word: String, -/// link: String, -/// } - -/// // Shape of the document to be stored in MongoDB, with embeddings. 
-/// #[derive(Serialize, Debug)] +/// use rig::embeddings::EmbeddingModel; +/// +/// #[derive(serde::Serialize, Debug)] /// struct Document { /// #[serde(rename = "_id")] /// id: String, /// definition: String, /// embedding: Vec<f64>, /// } -/// // Initialize OpenAI client -/// let openai_api_key = env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set"); -/// let openai_client = Client::new(&openai_api_key); - -/// // Initialize MongoDB client -/// let mongodb_connection_string = -/// env::var("MONGODB_CONNECTION_STRING").expect("MONGODB_CONNECTION_STRING not set"); -/// let options = ClientOptions::parse(mongodb_connection_string) -/// .await -/// .expect("MongoDB connection string should be valid"); - -/// let mongodb_client = -/// MongoClient::with_options(options).expect("MongoDB client options should be valid"); - -/// // Initialize MongoDB vector store -/// let collection: Collection<Document> = mongodb_client -/// .database("knowledgebase") -/// .collection("context"); - -/// // Select the embedding model and generate our embeddings -/// let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); - -/// let fake_definitions = vec![ -/// FakeDefinition { -/// id: "doc0".to_string(), -/// definition: "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets".to_string(), -/// }, -/// FakeDefinition { -/// id: "doc1".to_string(), -/// definition: "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string(), -/// }, -/// FakeDefinition { -/// id: "doc2".to_string(), -/// definition: "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.".to_string(), -/// } -/// ]; - -/// let embeddings = EmbeddingsBuilder::new(model.clone()) -/// .documents(fake_definitions)? -/// .build() -/// .await?; - -/// let mongo_documents = embeddings -/// .iter() -/// .map( -/// |(FakeDefinition { id, definition, .. }, embedding)| Document { -/// id: id.clone(), -/// definition: definition.clone(), -/// embedding: embedding.first().vec.clone(), -/// }, -/// ) -/// .collect::<Vec<_>>(); - -/// match collection.insert_many(mongo_documents, None).await { -/// Ok(_) => println!("Documents added successfully"), -/// Err(e) => println!("Error adding documents: {:?}", e), -/// }; - -/// // Create a vector index on our vector store. -/// // Note: a vector index called "vector_index" must exist on the MongoDB collection you are querying. -/// // IMPORTANT: Reuse the same model that was used to generate the embeddings -/// let index = MongoDbVectorStore::new(collection).index( -/// model, -/// "vector_index", -/// SearchParams::new("embedding"), -/// ); - -/// // Query the index -/// let results = index -/// .top_n::<FakeDefinition>("What is a linglingdong?", 1) -/// .await?; - -/// println!("Results: {:?}", results); - -/// let id_results = index -/// .top_n_ids("What is a linglingdong?", 1) -/// .await? -/// .into_iter() -/// .map(|(score, id)| (score, id)) -/// .collect::<Vec<_>>(); - -/// println!("ID results: {:?}", id_results); +/// +/// fn create_index(collection: mongodb::Collection<Document>, model: EmbeddingModel) { +/// let index = MongoDbVectorStore::new(collection).index( +/// model, +/// "vector_index", // <-- replace with the name of the index in your mongodb collection. +/// SearchParams::new("embedding"), // <-- field name in `Document` that contains the embeddings. 
+/// ); +/// } /// ``` pub struct MongoDbVectorStore<C> { collection: mongodb::Collection<C>, } -fn mongodb_to_rig_error(e: mongodb::error::Error) -> VectorStoreError { - VectorStoreError::DatastoreError(Box::new(e)) -} - impl<C> MongoDbVectorStore<C> { /// Create a new `MongoDbVectorStore` from a MongoDB collection. pub fn new(collection: mongodb::Collection<C>) -> Self { @@ -263,6 +164,40 @@ impl SearchParams { impl<M: EmbeddingModel + Sync + Send, C: Sync + Send> VectorStoreIndex for MongoDbVectorIndex<M, C> { + /// Implement the `top_n` method of the `VectorStoreIndex` trait for `MongoDbVectorIndex`. + /// # Example + /// ``` + /// use rig_mongodb::{MongoDbVectorStore, SearchParams}; + /// use rig::embeddings::EmbeddingModel; + /// + /// #[derive(serde::Serialize, Debug)] + /// struct Document { + /// #[serde(rename = "_id")] + /// id: String, + /// definition: String, + /// embedding: Vec<f64>, + /// } + /// + /// #[derive(serde::Deserialize, Debug)] + /// struct Definition { + /// #[serde(rename = "_id")] + /// id: String, + /// definition: String, + /// } + /// + /// fn execute_search(collection: mongodb::Collection<Document>, model: EmbeddingModel) { + /// let vector_store_index = MongoDbVectorStore::new(collection).index( + /// model, + /// "vector_index", // <-- replace with the name of the index in your mongodb collection. + /// SearchParams::new("embedding"), // <-- field name in `Document` that contains the embeddings. + /// ); + /// + /// // Query the index + /// vector_store_index + /// .top_n::<Definition>("My boss says I zindle too much, what does that mean?", 1) + /// .await?; + /// } + /// ``` async fn top_n<T: for<'a> Deserialize<'a> + Send>( &self, query: &str, @@ -303,6 +238,33 @@ impl<M: EmbeddingModel + Sync + Send, C: Sync + Send> VectorStoreIndex Ok(results) } + /// Implement the `top_n_ids` method of the `VectorStoreIndex` trait for `MongoDbVectorIndex`. + /// # Example + /// ``` + /// use rig_mongodb::{MongoDbVectorStore, SearchParams}; + /// use rig::embeddings::EmbeddingModel; + /// + /// #[derive(serde::Serialize, Debug)] + /// struct Document { + /// #[serde(rename = "_id")] + /// id: String, + /// definition: String, + /// embedding: Vec<f64>, + /// } + /// + /// fn execute_search(collection: mongodb::Collection<Document>, model: EmbeddingModel) { + /// let vector_store_index = MongoDbVectorStore::new(collection).index( + /// model, + /// "vector_index", // <-- replace with the name of the index in your mongodb collection. + /// SearchParams::new("embedding"), // <-- field name in `Document` that contains the embeddings. 
+ /// ); + /// + /// // Query the index + /// vector_store_index + /// .top_n_ids("My boss says I zindle too much, what does that mean?", 1) + /// .await?; + /// } + /// ``` async fn top_n_ids( &self, query: &str, From ed9e038876b64c76d82f4daf23b116767d483316 Mon Sep 17 00:00:00 2001 From: Garance <garance.mary@gmail.com> Date: Tue, 22 Oct 2024 16:36:07 -0400 Subject: [PATCH 74/91] docs: add more doc tests --- Cargo.lock | 36 ++++++++++++++++ rig-core/Cargo.toml | 3 +- rig-core/src/embeddings/builder.rs | 12 +++++- rig-core/src/embeddings/embeddable.rs | 61 +++++++++++++------------- rig-core/src/embeddings/mod.rs | 1 + rig-core/src/embeddings/tool.rs | 62 ++++++++++++++++++++++++++- rig-core/src/lib.rs | 2 +- 7 files changed, 143 insertions(+), 34 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 75f67709..1a0a5eed 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -343,6 +343,28 @@ dependencies = [ "syn 2.0.79", ] +[[package]] +name = "async-stream" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b5a71a6f37880a80d1d7f19efd781e4b5de42c88f0722cc13bcb6cc2cfe8476" +dependencies = [ + "async-stream-impl", + "futures-core", + "pin-project-lite", +] + +[[package]] +name = "async-stream-impl" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c7c24de15d275a1ecfd47a380fb4d5ec9bfe0933f309ed5e705b775596a3574d" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.79", +] + [[package]] name = "async-trait" version = "0.1.83" @@ -4003,6 +4025,7 @@ dependencies = [ "serde_json", "thiserror", "tokio", + "tokio-test", "tracing", "tracing-subscriber", ] @@ -5122,6 +5145,19 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-test" +version = "0.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2468baabc3311435b55dd935f702f42cd1b8abb7e754fb7dfb16bd36aa88f9f7" +dependencies = [ + "async-stream", + "bytes", + "futures-core", + "tokio", + "tokio-stream", +] + [[package]] name = "tokio-util" version = "0.7.12" diff --git a/rig-core/Cargo.toml b/rig-core/Cargo.toml index ea910406..ff2df2da 100644 --- a/rig-core/Cargo.toml +++ b/rig-core/Cargo.toml @@ -29,6 +29,7 @@ rig-derive = { path = "./rig-core-derive", optional = true } anyhow = "1.0.75" tokio = { version = "1.34.0", features = ["full"] } tracing-subscriber = "0.3.18" +tokio-test = "0.4.4" [features] derive = ["dep:rig-derive"] @@ -47,4 +48,4 @@ required-features = ["derive"] [[example]] name = "vector_search_cohere" -required-features = ["derive"] \ No newline at end of file +required-features = ["derive"] diff --git a/rig-core/src/embeddings/builder.rs b/rig-core/src/embeddings/builder.rs index 887dbf0e..fed2ef06 100644 --- a/rig-core/src/embeddings/builder.rs +++ b/rig-core/src/embeddings/builder.rs @@ -73,6 +73,7 @@ impl<M: EmbeddingModel, T: Embeddable> EmbeddingsBuilder<M, T> { /// /// let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); /// +/// # tokio_test::block_on(async { /// let embeddings = EmbeddingsBuilder::new(model.clone()) /// .documents(vec![ /// FakeDefinition { @@ -99,9 +100,16 @@ impl<M: EmbeddingModel, T: Embeddable> EmbeddingsBuilder<M, T> { /// "A rare, mystical instrument crafted by the ancient monks of the Nebulon Mountain Ranges on the planet Quarm.".to_string() /// ] /// }, -/// ])? 
+/// ]) +/// .unwrap() /// .build() -/// .await?; +/// .await +/// .unwrap(); +/// +/// assert_eq!(embeddings.iter().any(|(doc, embeddings)| doc.id == "doc0" && embeddings.len() == 2), true); +/// assert_eq!(embeddings.iter().any(|(doc, embeddings)| doc.id == "doc1" && embeddings.len() == 2), true); +/// assert_eq!(embeddings.iter().any(|(doc, embeddings)| doc.id == "doc2" && embeddings.len() == 2), true); +/// }) /// ``` impl<M: EmbeddingModel, T: Embeddable + Send + Sync + Clone> EmbeddingsBuilder<M, T> { /// Generate embeddings for all documents in the builder. diff --git a/rig-core/src/embeddings/embeddable.rs b/rig-core/src/embeddings/embeddable.rs index f5a69fd6..1d70ffcf 100644 --- a/rig-core/src/embeddings/embeddable.rs +++ b/rig-core/src/embeddings/embeddable.rs @@ -1,33 +1,4 @@ //! The module defines the [Embeddable] trait, which must be implemented for types that can be embedded. -//! # Example -//! ```rust -//! use std::env; -//! -//! use serde::{Deserialize, Serialize}; -//! use rig::OneOrMany; -//! -//! struct FakeDefinition { -//! id: String, -//! word: String, -//! definition: String, -//! } -//! -//! let fake_definition = FakeDefinition { -//! id: "doc1".to_string(), -//! word: "hello".to_string(), -//! definition: "used as a greeting or to begin a conversation".to_string() -//! }; -//! -//! impl Embeddable for FakeDefinition { -//! type Error = anyhow::Error; -//! -//! fn embeddable(&self) -> Result<OneOrMany<String>, Self::Error> { -//! // Embeddigns only need to be generated for `definition` field. -//! // Select it from the struct and return it as a single item. -//! Ok(OneOrMany::one(self.definition.clone())) -//! } -//! } -//! ``` use crate::one_or_many::OneOrMany; @@ -46,6 +17,38 @@ impl EmbeddableError { /// Trait for types that can be embedded. /// The `embeddable` method returns a `OneOrMany<String>` which contains strings for which embeddings will be generated by the embeddings builder. /// If there is an error generating the list of strings, the method should return an error that implements `std::error::Error`. +/// # Example +/// ```rust +/// use std::env; +/// +/// use serde::{Deserialize, Serialize}; +/// use rig::{OneOrMany, EmptyListError, Embeddable}; +/// +/// struct FakeDefinition { +/// id: String, +/// word: String, +/// definitions: String, +/// } +/// +/// let fake_definition = FakeDefinition { +/// id: "doc1".to_string(), +/// word: "rock".to_string(), +/// definitions: "the solid mineral material forming part of the surface of the earth, a precious stone".to_string() +/// }; +/// +/// impl Embeddable for FakeDefinition { +/// type Error = EmptyListError; +/// +/// fn embeddable(&self) -> Result<OneOrMany<String>, Self::Error> { +/// // Embeddings only need to be generated for `definition` field. +/// // Split the definitions by comma and collect them into a vector of strings. +/// // That way, different embeddings can be generated for each definition in the definitions string. 
+/// let definitions = self.definitions.split(",").collect::<Vec<_>>().into_iter().map(|s| s.to_string()).collect(); +/// +/// OneOrMany::many(definitions) +/// } +/// } +/// ``` pub trait Embeddable { type Error: std::error::Error + Sync + Send + 'static; diff --git a/rig-core/src/embeddings/mod.rs b/rig-core/src/embeddings/mod.rs index 763e0f30..b8ad9b62 100644 --- a/rig-core/src/embeddings/mod.rs +++ b/rig-core/src/embeddings/mod.rs @@ -11,3 +11,4 @@ pub mod tool; pub use builder::EmbeddingsBuilder; pub use embeddable::Embeddable; pub use embedding::{Embedding, EmbeddingError, EmbeddingModel}; +pub use tool::EmbeddableTool; diff --git a/rig-core/src/embeddings/tool.rs b/rig-core/src/embeddings/tool.rs index 139b11b8..c7c23b87 100644 --- a/rig-core/src/embeddings/tool.rs +++ b/rig-core/src/embeddings/tool.rs @@ -20,7 +20,67 @@ impl Embeddable for EmbeddableTool { } impl EmbeddableTool { - /// Convert item that implements ToolEmbedding to an EmbeddableTool. + /// Convert item that implements ToolEmbeddingDyn to an EmbeddableTool. + /// # Example + /// ```rust + /// use rig::{ + /// completion::ToolDefinition, + /// embeddings::EmbeddableTool, + /// tool::{Tool, ToolEmbedding, ToolEmbeddingDyn}, + /// }; + /// use serde_json::json; + /// + /// #[derive(Debug, thiserror::Error)] + /// #[error("Math error")] + /// struct NothingError; + /// + /// #[derive(Debug, thiserror::Error)] + /// #[error("Init error")] + /// struct InitError; + /// + /// struct Nothing; + /// impl Tool for Nothing { + /// const NAME: &'static str = "nothing"; + /// + /// type Error = NothingError; + /// type Args = (); + /// type Output = (); + /// + /// async fn definition(&self, _prompt: String) -> ToolDefinition { + /// serde_json::from_value(json!({ + /// "name": "nothing", + /// "description": "nothing", + /// "parameters": {} + /// })) + /// .expect("Tool Definition") + /// } + /// + /// async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> { + /// Ok(()) + /// } + /// } + /// + /// impl ToolEmbedding for Nothing { + /// type InitError = InitError; + /// type Context = (); + /// type State = (); + /// + /// fn init(_state: Self::State, _context: Self::Context) -> Result<Self, Self::InitError> { + /// Ok(Nothing) + /// } + /// + /// fn embedding_docs(&self) -> Vec<String> { + /// vec!["Do nothing.".into()] + /// } + /// + /// fn context(&self) -> Self::Context {} + /// } + /// + /// let tool = EmbeddableTool::try_from(&Nothing).unwrap(); + /// + /// assert_eq!(tool.name, "nothing".to_string()); + /// assert_eq!(tool.embedding_docs, vec!["Do nothing.".to_string()]); + /// ``` pub fn try_from(tool: &dyn ToolEmbeddingDyn) -> Result<Self, EmbeddableError> { Ok(EmbeddableTool { name: tool.name(), diff --git a/rig-core/src/lib.rs b/rig-core/src/lib.rs index 6b337073..2ef24051 100644 --- a/rig-core/src/lib.rs +++ b/rig-core/src/lib.rs @@ -79,7 +79,7 @@ pub mod vector_store; // Re-export commonly used types and traits pub use embeddings::Embeddable; -pub use one_or_many::OneOrMany; +pub use one_or_many::{EmptyListError, OneOrMany}; #[cfg(feature = "derive")] pub use rig_derive::Embeddable; From c502ea5a623817126fc5043051264beaed54585c Mon Sep 17 00:00:00 2001 From: Garance <garance.mary@gmail.com> Date: Wed, 23 Oct 2024 10:32:02 -0400 Subject: [PATCH 75/91] feat: rename Embeddable trait to ExtractEmbeddingFields --- rig-core/examples/rag.rs | 6 +- rig-core/examples/vector_search.rs | 4 +- rig-core/examples/vector_search_cohere.rs | 4 +- rig-core/rig-core-derive/src/basic.rs | 4 +- 
rig-core/rig-core-derive/src/embeddable.rs | 16 +- rig-core/rig-core-derive/src/lib.rs | 2 +- rig-core/src/embeddings/builder.rs | 40 ++--- rig-core/src/embeddings/embeddable.rs | 163 ------------------ .../embeddings/extract_embedding_fields.rs | 163 ++++++++++++++++++ rig-core/src/embeddings/mod.rs | 4 +- rig-core/src/embeddings/tool.rs | 16 +- rig-core/src/lib.rs | 4 +- rig-core/src/providers/cohere.rs | 4 +- rig-core/src/providers/openai.rs | 4 +- rig-core/src/tool.rs | 4 +- rig-core/tests/embeddable_macro.rs | 30 ++-- rig-lancedb/examples/fixtures/lib.rs | 4 +- rig-mongodb/examples/vector_search_mongodb.rs | 4 +- 18 files changed, 238 insertions(+), 238 deletions(-) delete mode 100644 rig-core/src/embeddings/embeddable.rs create mode 100644 rig-core/src/embeddings/extract_embedding_fields.rs diff --git a/rig-core/examples/rag.rs b/rig-core/examples/rag.rs index d03902d1..ab2f7767 100644 --- a/rig-core/examples/rag.rs +++ b/rig-core/examples/rag.rs @@ -5,14 +5,14 @@ use rig::{ embeddings::EmbeddingsBuilder, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, vector_store::in_memory_store::InMemoryVectorStore, - Embeddable, + ExtractEmbeddingFields, }; use serde::Serialize; // Shape of data that needs to be RAG'ed. -// A vector search needs to be performed on the definitions, so we derive the `Embeddable` trait for `FakeDefinition` +// A vector search needs to be performed on the definitions, so we derive the `ExtractEmbeddingFields` trait for `FakeDefinition` // and tag that field with `#[embed]`. -#[derive(Embeddable, Serialize, Clone, Debug, Eq, PartialEq, Default)] +#[derive(ExtractEmbeddingFields, Serialize, Clone, Debug, Eq, PartialEq, Default)] struct FakeDefinition { id: String, #[embed] diff --git a/rig-core/examples/vector_search.rs b/rig-core/examples/vector_search.rs index b40f271c..36bb8d7e 100644 --- a/rig-core/examples/vector_search.rs +++ b/rig-core/examples/vector_search.rs @@ -4,13 +4,13 @@ use rig::{ embeddings::EmbeddingsBuilder, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, vector_store::{in_memory_store::InMemoryVectorStore, VectorStoreIndex}, - Embeddable, + ExtractEmbeddingFields, }; use serde::{Deserialize, Serialize}; // Shape of data that needs to be RAG'ed. // The definition field will be used to generate embeddings. -#[derive(Embeddable, Clone, Deserialize, Debug, Serialize, Eq, PartialEq, Default)] +#[derive(ExtractEmbeddingFields, Clone, Deserialize, Debug, Serialize, Eq, PartialEq, Default)] struct FakeDefinition { id: String, word: String, diff --git a/rig-core/examples/vector_search_cohere.rs b/rig-core/examples/vector_search_cohere.rs index da1e474b..003d39f5 100644 --- a/rig-core/examples/vector_search_cohere.rs +++ b/rig-core/examples/vector_search_cohere.rs @@ -4,13 +4,13 @@ use rig::{ embeddings::EmbeddingsBuilder, providers::cohere::{Client, EMBED_ENGLISH_V3}, vector_store::{in_memory_store::InMemoryVectorStore, VectorStoreIndex}, - Embeddable, + ExtractEmbeddingFields, }; use serde::{Deserialize, Serialize}; // Shape of data that needs to be RAG'ed. // The definition field will be used to generate embeddings. 
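For readers following the rename: the derive shown in these examples expands to an implementation of the new trait, and a hand-written equivalent is only a few lines. A minimal sketch (not part of any patch in this series), assuming `definitions` is the only field to embed and using the module path the error type ends up with after the later file rename:

```rust
use rig::embeddings::extract_embedding_fields::ExtractEmbeddingFieldsError;
use rig::{ExtractEmbeddingFields, OneOrMany};

struct FakeDefinition {
    id: String,
    word: String,
    definitions: Vec<String>,
}

impl ExtractEmbeddingFields for FakeDefinition {
    type Error = ExtractEmbeddingFieldsError;

    fn extract_embedding_fields(&self) -> Result<OneOrMany<String>, Self::Error> {
        // Only the definitions are embedded; `id` and `word` are left out of the embedding targets.
        OneOrMany::many(self.definitions.clone()).map_err(ExtractEmbeddingFieldsError::new)
    }
}
```

Either route hands the `EmbeddingsBuilder` one string per definition, so each definition gets its own embedding.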
-#[derive(Embeddable, Clone, Deserialize, Debug, Serialize, Eq, PartialEq, Default)] +#[derive(ExtractEmbeddingFields, Clone, Deserialize, Debug, Serialize, Eq, PartialEq, Default)] struct FakeDefinition { id: String, word: String, diff --git a/rig-core/rig-core-derive/src/basic.rs b/rig-core/rig-core-derive/src/basic.rs index 86bb13ad..39b72018 100644 --- a/rig-core/rig-core-derive/src/basic.rs +++ b/rig-core/rig-core-derive/src/basic.rs @@ -15,11 +15,11 @@ pub(crate) fn basic_embed_fields(data_struct: &DataStruct) -> impl Iterator<Item }) } -/// Adds bounds to where clause that force all fields tagged with #[embed] to implement the Embeddable trait. +/// Adds bounds to where clause that force all fields tagged with #[embed] to implement the ExtractEmbeddingFields trait. pub(crate) fn add_struct_bounds(generics: &mut syn::Generics, field_type: &syn::Type) { let where_clause = generics.make_where_clause(); where_clause.predicates.push(parse_quote! { - #field_type: Embeddable + #field_type: ExtractEmbeddingFields }); } diff --git a/rig-core/rig-core-derive/src/embeddable.rs b/rig-core/rig-core-derive/src/embeddable.rs index 8336c9e0..27dac489 100644 --- a/rig-core/rig-core-derive/src/embeddable.rs +++ b/rig-core/rig-core-derive/src/embeddable.rs @@ -18,7 +18,7 @@ pub(crate) fn expand_derive_embedding(input: &mut syn::DeriveInput) -> syn::Resu let (custom_targets, custom_target_size) = data_struct.custom()?; // If there are no fields tagged with #[embed] or #[embed(embed_with = "...")], return an empty TokenStream. - // ie. do not implement Embeddable trait for the struct. + // ie. do not implement `ExtractEmbeddingFields` trait for the struct. if basic_target_size + custom_target_size == 0 { return Err(syn::Error::new_spanned( name, @@ -34,7 +34,7 @@ pub(crate) fn expand_derive_embedding(input: &mut syn::DeriveInput) -> syn::Resu _ => { return Err(syn::Error::new_spanned( input, - "Embeddable derive macro should only be used on structs", + "ExtractEmbeddingFields derive macro should only be used on structs", )) } }; @@ -42,18 +42,18 @@ pub(crate) fn expand_derive_embedding(input: &mut syn::DeriveInput) -> syn::Resu let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl(); let gen = quote! { - // Note: Embeddable trait is imported with the macro. + // Note: `ExtractEmbeddingFields` trait is imported with the macro. - impl #impl_generics Embeddable for #name #ty_generics #where_clause { - type Error = rig::embeddings::embeddable::EmbeddableError; + impl #impl_generics ExtractEmbeddingFields for #name #ty_generics #where_clause { + type Error = rig::embeddings::embeddable::ExtractEmbeddingFieldsError; - fn embeddable(&self) -> Result<rig::OneOrMany<String>, Self::Error> { + fn extract_embedding_fields(&self) -> Result<rig::OneOrMany<String>, Self::Error> { #target_stream; rig::OneOrMany::merge( embed_targets.into_iter() .collect::<Result<Vec<_>, _>>()? - ).map_err(rig::embeddings::embeddable::EmbeddableError::new) + ).map_err(rig::embeddings::embeddable::ExtractEmbeddingFieldsError::new) } } }; @@ -87,7 +87,7 @@ impl StructParser for DataStruct { if !embed_targets.is_empty() { ( quote! 
{ - vec![#(#embed_targets.embeddable()),*] + vec![#(#embed_targets.extract_embedding_fields()),*] }, embed_targets.len(), ) diff --git a/rig-core/rig-core-derive/src/lib.rs b/rig-core/rig-core-derive/src/lib.rs index d28a0d78..042f7ca9 100644 --- a/rig-core/rig-core-derive/src/lib.rs +++ b/rig-core/rig-core-derive/src/lib.rs @@ -11,7 +11,7 @@ pub(crate) const EMBED: &str = "embed"; // https://doc.rust-lang.org/book/ch19-06-macros.html#how-to-write-a-custom-derive-macro // https://doc.rust-lang.org/reference/procedural-macros.html -#[proc_macro_derive(Embeddable, attributes(embed))] +#[proc_macro_derive(ExtractEmbeddingFields, attributes(embed))] pub fn derive_embedding_trait(item: TokenStream) -> TokenStream { let mut input = parse_macro_input!(item as DeriveInput); diff --git a/rig-core/src/embeddings/builder.rs b/rig-core/src/embeddings/builder.rs index fed2ef06..b6138ef5 100644 --- a/rig-core/src/embeddings/builder.rs +++ b/rig-core/src/embeddings/builder.rs @@ -1,22 +1,22 @@ //! The module defines the [EmbeddingsBuilder] struct which accumulates objects to be embedded and generates the embeddings for each object when built. -//! Only types that implement the [Embeddable] trait can be added to the [EmbeddingsBuilder]. +//! Only types that implement the [ExtractEmbeddingFields] trait can be added to the [EmbeddingsBuilder]. use std::{cmp::max, collections::HashMap}; use futures::{stream, StreamExt, TryStreamExt}; use crate::{ - embeddings::{Embeddable, Embedding, EmbeddingError, EmbeddingModel}, + embeddings::{ExtractEmbeddingFields, Embedding, EmbeddingError, EmbeddingModel}, OneOrMany, }; /// Builder for creating a collection of embeddings. -pub struct EmbeddingsBuilder<M: EmbeddingModel, T: Embeddable> { +pub struct EmbeddingsBuilder<M: EmbeddingModel, T: ExtractEmbeddingFields> { model: M, documents: Vec<(T, OneOrMany<String>)>, } -impl<M: EmbeddingModel, T: Embeddable> EmbeddingsBuilder<M, T> { +impl<M: EmbeddingModel, T: ExtractEmbeddingFields> EmbeddingsBuilder<M, T> { /// Create a new embedding builder with the given embedding model pub fn new(model: M) -> Self { Self { @@ -25,18 +25,18 @@ impl<M: EmbeddingModel, T: Embeddable> EmbeddingsBuilder<M, T> { } } - /// Add a document that implements `Embeddable` to the builder. + /// Add a document that implements `ExtractEmbeddingFields` to the builder. pub fn document(mut self, document: T) -> Result<Self, T::Error> { - let embed_targets = document.embeddable()?; + let embed_targets = document.extract_embedding_fields()?; self.documents.push((document, embed_targets)); Ok(self) } - /// Add many documents that implement `Embeddable` to the builder. + /// Add many documents that implement `ExtractEmbeddingFields` to the builder. pub fn documents(mut self, documents: Vec<T>) -> Result<Self, T::Error> { for doc in documents.into_iter() { - let embed_targets = doc.embeddable()?; + let embed_targets = doc.extract_embedding_fields()?; self.documents.push((doc, embed_targets)); } @@ -53,13 +53,13 @@ impl<M: EmbeddingModel, T: Embeddable> EmbeddingsBuilder<M, T> { /// embeddings::EmbeddingsBuilder, /// providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, /// vector_store::{in_memory_store::InMemoryVectorStore, VectorStoreIndex}, -/// Embeddable, +/// ExtractEmbeddingFields, /// }; /// use serde::{Deserialize, Serialize}; /// /// // Shape of data that needs to be RAG'ed. /// // The definition field will be used to generate embeddings. 
-/// #[derive(Embeddable, Clone, Deserialize, Debug, Serialize, Eq, PartialEq, Default)] +/// #[derive(ExtractEmbeddingFields, Clone, Deserialize, Debug, Serialize, Eq, PartialEq, Default)] /// struct FakeDefinition { /// id: String, /// word: String, @@ -111,7 +111,7 @@ impl<M: EmbeddingModel, T: Embeddable> EmbeddingsBuilder<M, T> { /// assert_eq!(embeddings.iter().any(|(doc, embeddings)| doc.id == "doc2" && embeddings.len() == 2), true); /// }) /// ``` -impl<M: EmbeddingModel, T: Embeddable + Send + Sync + Clone> EmbeddingsBuilder<M, T> { +impl<M: EmbeddingModel, T: ExtractEmbeddingFields + Send + Sync + Clone> EmbeddingsBuilder<M, T> { /// Generate embeddings for all documents in the builder. /// Returns a vector of tuples, where the first element is the document and the second element is the embeddings (either one embedding or many). pub async fn build(self) -> Result<Vec<(T, OneOrMany<Embedding>)>, EmbeddingError> { @@ -179,8 +179,8 @@ impl<M: EmbeddingModel, T: Embeddable + Send + Sync + Clone> EmbeddingsBuilder<M #[cfg(test)] mod tests { use crate::{ - embeddings::{embeddable::EmbeddableError, Embedding, EmbeddingModel}, - Embeddable, + embeddings::{extract_embedding_fields::ExtractEmbeddingFieldsError, Embedding, EmbeddingModel}, + ExtractEmbeddingFields, }; use super::EmbeddingsBuilder; @@ -215,11 +215,11 @@ mod tests { definitions: Vec<String>, } - impl Embeddable for FakeDefinition { - type Error = EmbeddableError; + impl ExtractEmbeddingFields for FakeDefinition { + type Error = ExtractEmbeddingFieldsError; - fn embeddable(&self) -> Result<crate::OneOrMany<String>, Self::Error> { - crate::OneOrMany::many(self.definitions.clone()).map_err(EmbeddableError::new) + fn extract_embedding_fields(&self) -> Result<crate::OneOrMany<String>, Self::Error> { + crate::OneOrMany::many(self.definitions.clone()).map_err(ExtractEmbeddingFieldsError::new) } } @@ -261,10 +261,10 @@ mod tests { definition: String, } - impl Embeddable for FakeDefinitionSingle { - type Error = EmbeddableError; + impl ExtractEmbeddingFields for FakeDefinitionSingle { + type Error = ExtractEmbeddingFieldsError; - fn embeddable(&self) -> Result<crate::OneOrMany<String>, Self::Error> { + fn extract_embedding_fields(&self) -> Result<crate::OneOrMany<String>, Self::Error> { Ok(crate::OneOrMany::one(self.definition.clone())) } } diff --git a/rig-core/src/embeddings/embeddable.rs b/rig-core/src/embeddings/embeddable.rs deleted file mode 100644 index 1d70ffcf..00000000 --- a/rig-core/src/embeddings/embeddable.rs +++ /dev/null @@ -1,163 +0,0 @@ -//! The module defines the [Embeddable] trait, which must be implemented for types that can be embedded. - -use crate::one_or_many::OneOrMany; - -/// Error type used for when the `embeddable` method fails. -/// Used by default implementations of `Embeddable` for common types. -#[derive(Debug, thiserror::Error)] -#[error("{0}")] -pub struct EmbeddableError(#[from] Box<dyn std::error::Error + Send + Sync>); - -impl EmbeddableError { - pub fn new<E: std::error::Error + Send + Sync + 'static>(error: E) -> Self { - EmbeddableError(Box::new(error)) - } -} - -/// Trait for types that can be embedded. -/// The `embeddable` method returns a `OneOrMany<String>` which contains strings for which embeddings will be generated by the embeddings builder. -/// If there is an error generating the list of strings, the method should return an error that implements `std::error::Error`. 
-/// # Example -/// ```rust -/// use std::env; -/// -/// use serde::{Deserialize, Serialize}; -/// use rig::{OneOrMany, EmptyListError, Embeddable}; -/// -/// struct FakeDefinition { -/// id: String, -/// word: String, -/// definitions: String, -/// } -/// -/// let fake_definition = FakeDefinition { -/// id: "doc1".to_string(), -/// word: "rock".to_string(), -/// definitions: "the solid mineral material forming part of the surface of the earth, a precious stone".to_string() -/// }; -/// -/// impl Embeddable for FakeDefinition { -/// type Error = EmptyListError; -/// -/// fn embeddable(&self) -> Result<OneOrMany<String>, Self::Error> { -/// // Embeddings only need to be generated for `definition` field. -/// // Split the definitions by comma and collect them into a vector of strings. -/// // That way, different embeddings can be generated for each definition in the definitions string. -/// let definitions = self.definitions.split(",").collect::<Vec<_>>().into_iter().map(|s| s.to_string()).collect(); -/// -/// OneOrMany::many(definitions) -/// } -/// } -/// ``` -pub trait Embeddable { - type Error: std::error::Error + Sync + Send + 'static; - - fn embeddable(&self) -> Result<OneOrMany<String>, Self::Error>; -} - -// ================================================================ -// Implementations of Embeddable for common types -// ================================================================ -impl Embeddable for String { - type Error = EmbeddableError; - - fn embeddable(&self) -> Result<OneOrMany<String>, Self::Error> { - Ok(OneOrMany::one(self.clone())) - } -} - -impl Embeddable for i8 { - type Error = EmbeddableError; - - fn embeddable(&self) -> Result<OneOrMany<String>, Self::Error> { - Ok(OneOrMany::one(self.to_string())) - } -} - -impl Embeddable for i16 { - type Error = EmbeddableError; - - fn embeddable(&self) -> Result<OneOrMany<String>, Self::Error> { - Ok(OneOrMany::one(self.to_string())) - } -} - -impl Embeddable for i32 { - type Error = EmbeddableError; - - fn embeddable(&self) -> Result<OneOrMany<String>, Self::Error> { - Ok(OneOrMany::one(self.to_string())) - } -} - -impl Embeddable for i64 { - type Error = EmbeddableError; - - fn embeddable(&self) -> Result<OneOrMany<String>, Self::Error> { - Ok(OneOrMany::one(self.to_string())) - } -} - -impl Embeddable for i128 { - type Error = EmbeddableError; - - fn embeddable(&self) -> Result<OneOrMany<String>, Self::Error> { - Ok(OneOrMany::one(self.to_string())) - } -} - -impl Embeddable for f32 { - type Error = EmbeddableError; - - fn embeddable(&self) -> Result<OneOrMany<String>, Self::Error> { - Ok(OneOrMany::one(self.to_string())) - } -} - -impl Embeddable for f64 { - type Error = EmbeddableError; - - fn embeddable(&self) -> Result<OneOrMany<String>, Self::Error> { - Ok(OneOrMany::one(self.to_string())) - } -} - -impl Embeddable for bool { - type Error = EmbeddableError; - - fn embeddable(&self) -> Result<OneOrMany<String>, Self::Error> { - Ok(OneOrMany::one(self.to_string())) - } -} - -impl Embeddable for char { - type Error = EmbeddableError; - - fn embeddable(&self) -> Result<OneOrMany<String>, Self::Error> { - Ok(OneOrMany::one(self.to_string())) - } -} - -impl Embeddable for serde_json::Value { - type Error = EmbeddableError; - - fn embeddable(&self) -> Result<OneOrMany<String>, Self::Error> { - Ok(OneOrMany::one( - serde_json::to_string(self).map_err(EmbeddableError::new)?, - )) - } -} - -impl<T: Embeddable> Embeddable for Vec<T> { - type Error = EmbeddableError; - - fn embeddable(&self) -> Result<OneOrMany<String>, 
Self::Error> { - let items = self - .iter() - .map(|item| item.embeddable()) - .collect::<Result<Vec<_>, _>>() - .map_err(EmbeddableError::new)?; - - OneOrMany::merge(items).map_err(EmbeddableError::new) - } -} diff --git a/rig-core/src/embeddings/extract_embedding_fields.rs b/rig-core/src/embeddings/extract_embedding_fields.rs new file mode 100644 index 00000000..e62d43c6 --- /dev/null +++ b/rig-core/src/embeddings/extract_embedding_fields.rs @@ -0,0 +1,163 @@ +//! The module defines the [ExtractEmbeddingFields] trait, which must be implemented for types that can be embedded. + +use crate::one_or_many::OneOrMany; + +/// Error type used for when the `extract_embedding_fields` method fails. +/// Used by default implementations of `ExtractEmbeddingFields` for common types. +#[derive(Debug, thiserror::Error)] +#[error("{0}")] +pub struct ExtractEmbeddingFieldsError(#[from] Box<dyn std::error::Error + Send + Sync>); + +impl ExtractEmbeddingFieldsError { + pub fn new<E: std::error::Error + Send + Sync + 'static>(error: E) -> Self { + ExtractEmbeddingFieldsError(Box::new(error)) + } +} + +/// Derive this trait for structs whose fields need to be converted to vector embeddings. +/// The `extract_embedding_fields` method returns a `OneOrMany<String>`. This function extracts the fields that need to be embedded and returns them as a list of strings. +/// If there is an error generating the list of strings, the method should return an error that implements `std::error::Error`. +/// # Example +/// ```rust +/// use std::env; +/// +/// use serde::{Deserialize, Serialize}; +/// use rig::{OneOrMany, EmptyListError, ExtractEmbeddingFields}; +/// +/// struct FakeDefinition { +/// id: String, +/// word: String, +/// definitions: String, +/// } +/// +/// let fake_definition = FakeDefinition { +/// id: "doc1".to_string(), +/// word: "rock".to_string(), +/// definitions: "the solid mineral material forming part of the surface of the earth, a precious stone".to_string() +/// }; +/// +/// impl ExtractEmbeddingFields for FakeDefinition { +/// type Error = EmptyListError; +/// +/// fn extract_embedding_fields(&self) -> Result<OneOrMany<String>, Self::Error> { +/// // Embeddings only need to be generated for `definition` field. +/// // Split the definitions by comma and collect them into a vector of strings. +/// // That way, different embeddings can be generated for each definition in the definitions string. 
+/// let definitions = self.definitions.split(",").collect::<Vec<_>>().into_iter().map(|s| s.to_string()).collect(); +/// +/// OneOrMany::many(definitions) +/// } +/// } +/// ``` +pub trait ExtractEmbeddingFields { + type Error: std::error::Error + Sync + Send + 'static; + + fn extract_embedding_fields(&self) -> Result<OneOrMany<String>, Self::Error>; +} + +// ================================================================ +// Implementations of ExtractEmbeddingFields for common types +// ================================================================ +impl ExtractEmbeddingFields for String { + type Error = ExtractEmbeddingFieldsError; + + fn extract_embedding_fields(&self) -> Result<OneOrMany<String>, Self::Error> { + Ok(OneOrMany::one(self.clone())) + } +} + +impl ExtractEmbeddingFields for i8 { + type Error = ExtractEmbeddingFieldsError; + + fn extract_embedding_fields(&self) -> Result<OneOrMany<String>, Self::Error> { + Ok(OneOrMany::one(self.to_string())) + } +} + +impl ExtractEmbeddingFields for i16 { + type Error = ExtractEmbeddingFieldsError; + + fn extract_embedding_fields(&self) -> Result<OneOrMany<String>, Self::Error> { + Ok(OneOrMany::one(self.to_string())) + } +} + +impl ExtractEmbeddingFields for i32 { + type Error = ExtractEmbeddingFieldsError; + + fn extract_embedding_fields(&self) -> Result<OneOrMany<String>, Self::Error> { + Ok(OneOrMany::one(self.to_string())) + } +} + +impl ExtractEmbeddingFields for i64 { + type Error = ExtractEmbeddingFieldsError; + + fn extract_embedding_fields(&self) -> Result<OneOrMany<String>, Self::Error> { + Ok(OneOrMany::one(self.to_string())) + } +} + +impl ExtractEmbeddingFields for i128 { + type Error = ExtractEmbeddingFieldsError; + + fn extract_embedding_fields(&self) -> Result<OneOrMany<String>, Self::Error> { + Ok(OneOrMany::one(self.to_string())) + } +} + +impl ExtractEmbeddingFields for f32 { + type Error = ExtractEmbeddingFieldsError; + + fn extract_embedding_fields(&self) -> Result<OneOrMany<String>, Self::Error> { + Ok(OneOrMany::one(self.to_string())) + } +} + +impl ExtractEmbeddingFields for f64 { + type Error = ExtractEmbeddingFieldsError; + + fn extract_embedding_fields(&self) -> Result<OneOrMany<String>, Self::Error> { + Ok(OneOrMany::one(self.to_string())) + } +} + +impl ExtractEmbeddingFields for bool { + type Error = ExtractEmbeddingFieldsError; + + fn extract_embedding_fields(&self) -> Result<OneOrMany<String>, Self::Error> { + Ok(OneOrMany::one(self.to_string())) + } +} + +impl ExtractEmbeddingFields for char { + type Error = ExtractEmbeddingFieldsError; + + fn extract_embedding_fields(&self) -> Result<OneOrMany<String>, Self::Error> { + Ok(OneOrMany::one(self.to_string())) + } +} + +impl ExtractEmbeddingFields for serde_json::Value { + type Error = ExtractEmbeddingFieldsError; + + fn extract_embedding_fields(&self) -> Result<OneOrMany<String>, Self::Error> { + Ok(OneOrMany::one( + serde_json::to_string(self).map_err(ExtractEmbeddingFieldsError::new)?, + )) + } +} + +impl<T: ExtractEmbeddingFields> ExtractEmbeddingFields for Vec<T> { + type Error = ExtractEmbeddingFieldsError; + + fn extract_embedding_fields(&self) -> Result<OneOrMany<String>, Self::Error> { + let items = self + .iter() + .map(|item| item.extract_embedding_fields()) + .collect::<Result<Vec<_>, _>>() + .map_err(ExtractEmbeddingFieldsError::new)?; + + OneOrMany::merge(items).map_err(ExtractEmbeddingFieldsError::new) + } +} diff --git a/rig-core/src/embeddings/mod.rs b/rig-core/src/embeddings/mod.rs index b8ad9b62..f5e9ede5 100644 --- 
a/rig-core/src/embeddings/mod.rs +++ b/rig-core/src/embeddings/mod.rs @@ -4,11 +4,11 @@ //! and document similarity. pub mod builder; -pub mod embeddable; +pub mod extract_embedding_fields; pub mod embedding; pub mod tool; pub use builder::EmbeddingsBuilder; -pub use embeddable::Embeddable; +pub use extract_embedding_fields::ExtractEmbeddingFields; pub use embedding::{Embedding, EmbeddingError, EmbeddingModel}; pub use tool::EmbeddableTool; diff --git a/rig-core/src/embeddings/tool.rs b/rig-core/src/embeddings/tool.rs index c7c23b87..7550038f 100644 --- a/rig-core/src/embeddings/tool.rs +++ b/rig-core/src/embeddings/tool.rs @@ -1,7 +1,7 @@ -use crate::{tool::ToolEmbeddingDyn, Embeddable, OneOrMany}; +use crate::{tool::ToolEmbeddingDyn, ExtractEmbeddingFields, OneOrMany}; use serde::Serialize; -use super::embeddable::EmbeddableError; +use super::extract_embedding_fields::ExtractEmbeddingFieldsError; /// Used by EmbeddingsBuilder to embed anything that implements ToolEmbedding. #[derive(Clone, Serialize, Default, Eq, PartialEq)] @@ -11,11 +11,11 @@ pub struct EmbeddableTool { pub embedding_docs: Vec<String>, } -impl Embeddable for EmbeddableTool { - type Error = EmbeddableError; +impl ExtractEmbeddingFields for EmbeddableTool { + type Error = ExtractEmbeddingFieldsError; - fn embeddable(&self) -> Result<OneOrMany<String>, Self::Error> { - OneOrMany::many(self.embedding_docs.clone()).map_err(EmbeddableError::new) + fn extract_embedding_fields(&self) -> Result<OneOrMany<String>, Self::Error> { + OneOrMany::many(self.embedding_docs.clone()).map_err(ExtractEmbeddingFieldsError::new) } } @@ -81,10 +81,10 @@ impl EmbeddableTool { /// assert_eq!(tool.name, "nothing".to_string()); /// assert_eq!(tool.embedding_docs, vec!["Do nothing.".to_string()]); /// ``` - pub fn try_from(tool: &dyn ToolEmbeddingDyn) -> Result<Self, EmbeddableError> { + pub fn try_from(tool: &dyn ToolEmbeddingDyn) -> Result<Self, ExtractEmbeddingFieldsError> { Ok(EmbeddableTool { name: tool.name(), - context: tool.context().map_err(EmbeddableError::new)?, + context: tool.context().map_err(ExtractEmbeddingFieldsError::new)?, embedding_docs: tool.embedding_docs(), }) } diff --git a/rig-core/src/lib.rs b/rig-core/src/lib.rs index 2ef24051..5383b34e 100644 --- a/rig-core/src/lib.rs +++ b/rig-core/src/lib.rs @@ -78,8 +78,8 @@ pub mod tool; pub mod vector_store; // Re-export commonly used types and traits -pub use embeddings::Embeddable; +pub use embeddings::ExtractEmbeddingFields; pub use one_or_many::{EmptyListError, OneOrMany}; #[cfg(feature = "derive")] -pub use rig_derive::Embeddable; +pub use rig_derive::ExtractEmbeddingFields; diff --git a/rig-core/src/providers/cohere.rs b/rig-core/src/providers/cohere.rs index 8f8eefd4..a6d8f00b 100644 --- a/rig-core/src/providers/cohere.rs +++ b/rig-core/src/providers/cohere.rs @@ -15,7 +15,7 @@ use crate::{ completion::{self, CompletionError}, embeddings::{self, EmbeddingError, EmbeddingsBuilder}, extractor::ExtractorBuilder, - json_utils, Embeddable, + json_utils, ExtractEmbeddingFields, }; use schemars::JsonSchema; @@ -85,7 +85,7 @@ impl Client { EmbeddingModel::new(self.clone(), model, input_type, ndims) } - pub fn embeddings<D: Embeddable>( + pub fn embeddings<D: ExtractEmbeddingFields>( &self, model: &str, input_type: &str, diff --git a/rig-core/src/providers/openai.rs b/rig-core/src/providers/openai.rs index b20df22f..0bfeac3c 100644 --- a/rig-core/src/providers/openai.rs +++ b/rig-core/src/providers/openai.rs @@ -13,7 +13,7 @@ use crate::{ completion::{self, CompletionError, 
CompletionRequest}, embeddings::{self, EmbeddingError, EmbeddingsBuilder}, extractor::ExtractorBuilder, - json_utils, Embeddable, + json_utils, ExtractEmbeddingFields, }; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; @@ -121,7 +121,7 @@ impl Client { /// .await /// .expect("Failed to embed documents"); /// ``` - pub fn embeddings<D: Embeddable>(&self, model: &str) -> EmbeddingsBuilder<EmbeddingModel, D> { + pub fn embeddings<D: ExtractEmbeddingFields>(&self, model: &str) -> EmbeddingsBuilder<EmbeddingModel, D> { EmbeddingsBuilder::new(self.embedding_model(model)) } diff --git a/rig-core/src/tool.rs b/rig-core/src/tool.rs index e92896b8..528faba5 100644 --- a/rig-core/src/tool.rs +++ b/rig-core/src/tool.rs @@ -5,7 +5,7 @@ use serde::{Deserialize, Serialize}; use crate::{ completion::{self, ToolDefinition}, - embeddings::{embeddable::EmbeddableError, tool::EmbeddableTool}, + embeddings::{extract_embedding_fields::ExtractEmbeddingFieldsError, tool::EmbeddableTool}, }; #[derive(Debug, thiserror::Error)] @@ -330,7 +330,7 @@ impl ToolSet { /// Convert tools in self to objects of type EmbeddableTool. /// This is necessary because when adding tools to the EmbeddingBuilder because all /// documents added to the builder must all be of the same type. - pub fn embedabble_tools(&self) -> Result<Vec<EmbeddableTool>, EmbeddableError> { + pub fn embedabble_tools(&self) -> Result<Vec<EmbeddableTool>, ExtractEmbeddingFieldsError> { self.tools .values() .filter_map(|tool_type| { diff --git a/rig-core/tests/embeddable_macro.rs b/rig-core/tests/embeddable_macro.rs index cbc76c80..5f8891ff 100644 --- a/rig-core/tests/embeddable_macro.rs +++ b/rig-core/tests/embeddable_macro.rs @@ -1,14 +1,14 @@ -use rig::embeddings::embeddable::EmbeddableError; -use rig::{Embeddable, OneOrMany}; +use rig::embeddings::extract_embedding_fields::ExtractEmbeddingFieldsError; +use rig::{ExtractEmbeddingFields, OneOrMany}; use serde::Serialize; -fn serialize(definition: Definition) -> Result<OneOrMany<String>, EmbeddableError> { +fn serialize(definition: Definition) -> Result<OneOrMany<String>, ExtractEmbeddingFieldsError> { Ok(OneOrMany::one( - serde_json::to_string(&definition).map_err(EmbeddableError::new)?, + serde_json::to_string(&definition).map_err(ExtractEmbeddingFieldsError::new)?, )) } -#[derive(Embeddable)] +#[derive(ExtractEmbeddingFields)] struct FakeDefinition { id: String, word: String, @@ -41,7 +41,7 @@ fn test_custom_embed() { ); assert_eq!( - fake_definition.embeddable().unwrap(), + fake_definition.extract_embedding_fields().unwrap(), OneOrMany::one( "{\"word\":\"a building in which people live; residence for human beings.\",\"link\":\"https://www.dictionary.com/browse/house\",\"speech\":\"noun\"}".to_string() ) @@ -49,7 +49,7 @@ fn test_custom_embed() { ) } -#[derive(Embeddable)] +#[derive(ExtractEmbeddingFields)] struct FakeDefinition2 { id: String, #[embed] @@ -76,17 +76,17 @@ fn test_custom_and_basic_embed() { ); assert_eq!( - fake_definition.embeddable().unwrap().first(), + fake_definition.extract_embedding_fields().unwrap().first(), "house".to_string() ); assert_eq!( - fake_definition.embeddable().unwrap().rest(), + fake_definition.extract_embedding_fields().unwrap().rest(), vec!["{\"word\":\"a building in which people live; residence for human beings.\",\"link\":\"https://www.dictionary.com/browse/house\",\"speech\":\"noun\"}".to_string()] ) } -#[derive(Embeddable)] +#[derive(ExtractEmbeddingFields)] struct FakeDefinition3 { id: String, word: String, @@ -109,12 +109,12 @@ fn 
test_single_embed() { ); assert_eq!( - fake_definition.embeddable().unwrap(), + fake_definition.extract_embedding_fields().unwrap(), OneOrMany::one(definition) ) } -#[derive(Embeddable)] +#[derive(ExtractEmbeddingFields)] struct Company { id: String, company: String, @@ -132,7 +132,7 @@ fn test_multiple_embed_strings() { println!("Company: {}, {}", company.id, company.company); - let result = company.embeddable().unwrap(); + let result = company.extract_embedding_fields().unwrap(); assert_eq!( result, @@ -153,7 +153,7 @@ fn test_multiple_embed_strings() { ) } -#[derive(Embeddable)] +#[derive(ExtractEmbeddingFields)] struct Company2 { id: String, #[embed] @@ -173,7 +173,7 @@ fn test_multiple_embed_tags() { println!("Company: {}", company.id); assert_eq!( - company.embeddable().unwrap(), + company.extract_embedding_fields().unwrap(), OneOrMany::many(vec![ "Google".to_string(), "25".to_string(), diff --git a/rig-lancedb/examples/fixtures/lib.rs b/rig-lancedb/examples/fixtures/lib.rs index d6e02a5a..1a9089a5 100644 --- a/rig-lancedb/examples/fixtures/lib.rs +++ b/rig-lancedb/examples/fixtures/lib.rs @@ -3,10 +3,10 @@ use std::sync::Arc; use arrow_array::{types::Float64Type, ArrayRef, FixedSizeListArray, RecordBatch, StringArray}; use lancedb::arrow::arrow_schema::{DataType, Field, Fields, Schema}; use rig::embeddings::Embedding; -use rig::{Embeddable, OneOrMany}; +use rig::{ExtractEmbeddingFields, OneOrMany}; use serde::Deserialize; -#[derive(Embeddable, Clone, Deserialize, Debug)] +#[derive(ExtractEmbeddingFields, Clone, Deserialize, Debug)] pub struct FakeDefinition { pub id: String, #[embed] diff --git a/rig-mongodb/examples/vector_search_mongodb.rs b/rig-mongodb/examples/vector_search_mongodb.rs index b095c060..cc16d4bc 100644 --- a/rig-mongodb/examples/vector_search_mongodb.rs +++ b/rig-mongodb/examples/vector_search_mongodb.rs @@ -3,7 +3,7 @@ use rig::providers::openai::TEXT_EMBEDDING_ADA_002; use serde::{Deserialize, Serialize}; use std::env; -use rig::Embeddable; +use rig::ExtractEmbeddingFields; use rig::{ embeddings::EmbeddingsBuilder, providers::openai::Client, vector_store::VectorStoreIndex, }; @@ -11,7 +11,7 @@ use rig_mongodb::{MongoDbVectorStore, SearchParams}; // Shape of data that needs to be RAG'ed. // The definition field will be used to generate embeddings. 
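The assertions in these macro tests lean on `OneOrMany`'s small API surface: `one`, `many`, `first` and `rest`. A condensed sketch of that contract as it is exercised here (the empty-list failure is inferred from the `EmptyListError` re-export and is an assumption, not something the patch states):

```rust
use rig::{EmptyListError, OneOrMany};

fn one_or_many_contract() -> Result<(), EmptyListError> {
    // `one` always succeeds and holds exactly one value.
    let single = OneOrMany::one("Google".to_string());
    assert_eq!(single.first(), "Google".to_string());
    assert!(single.rest().is_empty());

    // `many` keeps the first value separate from the rest...
    let many = OneOrMany::many(vec![
        "Google".to_string(),
        "25".to_string(),
        "30".to_string(),
    ])?;
    assert_eq!(many.first(), "Google".to_string());
    assert_eq!(many.rest(), vec!["25".to_string(), "30".to_string()]);

    // ...and is assumed to reject an empty input with `EmptyListError`.
    assert!(OneOrMany::many(Vec::<String>::new()).is_err());

    Ok(())
}
```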
-#[derive(Embeddable, Clone, Deserialize, Debug)] +#[derive(ExtractEmbeddingFields, Clone, Deserialize, Debug)] struct FakeDefinition { #[serde(rename = "_id")] id: String, From ebc6b81e0190adc25a6622fa8c620a9962eaf78e Mon Sep 17 00:00:00 2001 From: Garance <garance.mary@gmail.com> Date: Wed, 23 Oct 2024 10:38:51 -0400 Subject: [PATCH 76/91] feat: rename macro files, cargo fmt --- rig-core/Cargo.toml | 2 +- .../src/{embeddable.rs => extract_embedding_fields.rs} | 4 ++-- rig-core/rig-core-derive/src/lib.rs | 4 ++-- rig-core/src/embeddings/builder.rs | 9 ++++++--- rig-core/src/embeddings/mod.rs | 4 ++-- rig-core/src/providers/openai.rs | 5 ++++- ...ddable_macro.rs => extract_embedding_fields_macro.rs} | 0 7 files changed, 17 insertions(+), 11 deletions(-) rename rig-core/rig-core-derive/src/{embeddable.rs => extract_embedding_fields.rs} (95%) rename rig-core/tests/{embeddable_macro.rs => extract_embedding_fields_macro.rs} (100%) diff --git a/rig-core/Cargo.toml b/rig-core/Cargo.toml index ff2df2da..ed666fe4 100644 --- a/rig-core/Cargo.toml +++ b/rig-core/Cargo.toml @@ -35,7 +35,7 @@ tokio-test = "0.4.4" derive = ["dep:rig-derive"] [[test]] -name = "embeddable_macro" +name = "extract_embedding_fields_macro" required-features = ["derive"] [[example]] diff --git a/rig-core/rig-core-derive/src/embeddable.rs b/rig-core/rig-core-derive/src/extract_embedding_fields.rs similarity index 95% rename from rig-core/rig-core-derive/src/embeddable.rs rename to rig-core/rig-core-derive/src/extract_embedding_fields.rs index 27dac489..5c21f6b2 100644 --- a/rig-core/rig-core-derive/src/embeddable.rs +++ b/rig-core/rig-core-derive/src/extract_embedding_fields.rs @@ -45,7 +45,7 @@ pub(crate) fn expand_derive_embedding(input: &mut syn::DeriveInput) -> syn::Resu // Note: `ExtractEmbeddingFields` trait is imported with the macro. impl #impl_generics ExtractEmbeddingFields for #name #ty_generics #where_clause { - type Error = rig::embeddings::embeddable::ExtractEmbeddingFieldsError; + type Error = rig::embeddings::extract_embedding_fields::ExtractEmbeddingFieldsError; fn extract_embedding_fields(&self) -> Result<rig::OneOrMany<String>, Self::Error> { #target_stream; @@ -53,7 +53,7 @@ pub(crate) fn expand_derive_embedding(input: &mut syn::DeriveInput) -> syn::Resu rig::OneOrMany::merge( embed_targets.into_iter() .collect::<Result<Vec<_>, _>>()? 
- ).map_err(rig::embeddings::embeddable::ExtractEmbeddingFieldsError::new) + ).map_err(rig::embeddings::extract_embedding_fields::ExtractEmbeddingFieldsError::new) } } }; diff --git a/rig-core/rig-core-derive/src/lib.rs b/rig-core/rig-core-derive/src/lib.rs index 042f7ca9..8ad69a65 100644 --- a/rig-core/rig-core-derive/src/lib.rs +++ b/rig-core/rig-core-derive/src/lib.rs @@ -4,7 +4,7 @@ use syn::{parse_macro_input, DeriveInput}; mod basic; mod custom; -mod embeddable; +mod extract_embedding_fields; pub(crate) const EMBED: &str = "embed"; @@ -15,7 +15,7 @@ pub(crate) const EMBED: &str = "embed"; pub fn derive_embedding_trait(item: TokenStream) -> TokenStream { let mut input = parse_macro_input!(item as DeriveInput); - embeddable::expand_derive_embedding(&mut input) + extract_embedding_fields::expand_derive_embedding(&mut input) .unwrap_or_else(syn::Error::into_compile_error) .into() } diff --git a/rig-core/src/embeddings/builder.rs b/rig-core/src/embeddings/builder.rs index b6138ef5..e884dc69 100644 --- a/rig-core/src/embeddings/builder.rs +++ b/rig-core/src/embeddings/builder.rs @@ -6,7 +6,7 @@ use std::{cmp::max, collections::HashMap}; use futures::{stream, StreamExt, TryStreamExt}; use crate::{ - embeddings::{ExtractEmbeddingFields, Embedding, EmbeddingError, EmbeddingModel}, + embeddings::{Embedding, EmbeddingError, EmbeddingModel, ExtractEmbeddingFields}, OneOrMany, }; @@ -179,7 +179,9 @@ impl<M: EmbeddingModel, T: ExtractEmbeddingFields + Send + Sync + Clone> Embeddi #[cfg(test)] mod tests { use crate::{ - embeddings::{extract_embedding_fields::ExtractEmbeddingFieldsError, Embedding, EmbeddingModel}, + embeddings::{ + extract_embedding_fields::ExtractEmbeddingFieldsError, Embedding, EmbeddingModel, + }, ExtractEmbeddingFields, }; @@ -219,7 +221,8 @@ mod tests { type Error = ExtractEmbeddingFieldsError; fn extract_embedding_fields(&self) -> Result<crate::OneOrMany<String>, Self::Error> { - crate::OneOrMany::many(self.definitions.clone()).map_err(ExtractEmbeddingFieldsError::new) + crate::OneOrMany::many(self.definitions.clone()) + .map_err(ExtractEmbeddingFieldsError::new) } } diff --git a/rig-core/src/embeddings/mod.rs b/rig-core/src/embeddings/mod.rs index f5e9ede5..37323cf5 100644 --- a/rig-core/src/embeddings/mod.rs +++ b/rig-core/src/embeddings/mod.rs @@ -4,11 +4,11 @@ //! and document similarity. 
pub mod builder; -pub mod extract_embedding_fields; pub mod embedding; +pub mod extract_embedding_fields; pub mod tool; pub use builder::EmbeddingsBuilder; -pub use extract_embedding_fields::ExtractEmbeddingFields; pub use embedding::{Embedding, EmbeddingError, EmbeddingModel}; +pub use extract_embedding_fields::ExtractEmbeddingFields; pub use tool::EmbeddableTool; diff --git a/rig-core/src/providers/openai.rs b/rig-core/src/providers/openai.rs index 0bfeac3c..789c5282 100644 --- a/rig-core/src/providers/openai.rs +++ b/rig-core/src/providers/openai.rs @@ -121,7 +121,10 @@ impl Client { /// .await /// .expect("Failed to embed documents"); /// ``` - pub fn embeddings<D: ExtractEmbeddingFields>(&self, model: &str) -> EmbeddingsBuilder<EmbeddingModel, D> { + pub fn embeddings<D: ExtractEmbeddingFields>( + &self, + model: &str, + ) -> EmbeddingsBuilder<EmbeddingModel, D> { EmbeddingsBuilder::new(self.embedding_model(model)) } diff --git a/rig-core/tests/embeddable_macro.rs b/rig-core/tests/extract_embedding_fields_macro.rs similarity index 100% rename from rig-core/tests/embeddable_macro.rs rename to rig-core/tests/extract_embedding_fields_macro.rs From 5c2d451c0e536670a28bd7399b5ff323f509620c Mon Sep 17 00:00:00 2001 From: Garance <garance.mary@gmail.com> Date: Thu, 24 Oct 2024 09:27:12 -0400 Subject: [PATCH 77/91] PR; update docstrings, update `add_documents_with_id` function --- rig-core/examples/calculator_chatbot.rs | 2 +- rig-core/examples/rag.rs | 2 +- rig-core/examples/rag_dynamic_tools.rs | 2 +- rig-core/examples/vector_search.rs | 2 +- rig-core/examples/vector_search_cohere.rs | 2 +- rig-core/src/embeddings/builder.rs | 12 +----- rig-core/src/vector_store/in_memory_store.rs | 20 ++------- rig-lancedb/src/lib.rs | 34 +++++++-------- rig-mongodb/src/lib.rs | 45 ++++++++++---------- 9 files changed, 51 insertions(+), 70 deletions(-) diff --git a/rig-core/examples/calculator_chatbot.rs b/rig-core/examples/calculator_chatbot.rs index 723bfada..576491d5 100644 --- a/rig-core/examples/calculator_chatbot.rs +++ b/rig-core/examples/calculator_chatbot.rs @@ -252,7 +252,7 @@ async fn main() -> Result<(), anyhow::Error> { .await?; let index = InMemoryVectorStore::default() - .add_documents_with_id(embeddings, "name")? + .add_documents_with_id(embeddings, |tool| tool.name.clone())? .index(embedding_model); // Create RAG agent with a single context prompt and a dynamic tool source diff --git a/rig-core/examples/rag.rs b/rig-core/examples/rag.rs index ab2f7767..9829089f 100644 --- a/rig-core/examples/rag.rs +++ b/rig-core/examples/rag.rs @@ -56,7 +56,7 @@ async fn main() -> Result<(), anyhow::Error> { .await?; let index = InMemoryVectorStore::default() - .add_documents_with_id(embeddings, "id")? + .add_documents_with_id(embeddings, |definition| definition.id.clone())? .index(embedding_model); let rag_agent = openai_client.agent("gpt-4") diff --git a/rig-core/examples/rag_dynamic_tools.rs b/rig-core/examples/rag_dynamic_tools.rs index c3a2c251..c140da15 100644 --- a/rig-core/examples/rag_dynamic_tools.rs +++ b/rig-core/examples/rag_dynamic_tools.rs @@ -161,7 +161,7 @@ async fn main() -> Result<(), anyhow::Error> { .await?; let index = InMemoryVectorStore::default() - .add_documents_with_id(embeddings, "name")? + .add_documents_with_id(embeddings, |tool| tool.name.clone())? 
.index(embedding_model); // Create RAG agent with a single context prompt and a dynamic tool source diff --git a/rig-core/examples/vector_search.rs b/rig-core/examples/vector_search.rs index 36bb8d7e..4777b42c 100644 --- a/rig-core/examples/vector_search.rs +++ b/rig-core/examples/vector_search.rs @@ -57,7 +57,7 @@ async fn main() -> Result<(), anyhow::Error> { .await?; let index = InMemoryVectorStore::default() - .add_documents_with_id(embeddings, "id")? + .add_documents_with_id(embeddings, |definition| definition.id.clone())? .index(model); let results = index diff --git a/rig-core/examples/vector_search_cohere.rs b/rig-core/examples/vector_search_cohere.rs index 003d39f5..6d966004 100644 --- a/rig-core/examples/vector_search_cohere.rs +++ b/rig-core/examples/vector_search_cohere.rs @@ -58,7 +58,7 @@ async fn main() -> Result<(), anyhow::Error> { .await?; let index = InMemoryVectorStore::default() - .add_documents_with_id(embeddings, "id")? + .add_documents_with_id(embeddings, |definition| definition.id.clone())? .index(search_model); let results = index diff --git a/rig-core/src/embeddings/builder.rs b/rig-core/src/embeddings/builder.rs index e884dc69..b1b310bb 100644 --- a/rig-core/src/embeddings/builder.rs +++ b/rig-core/src/embeddings/builder.rs @@ -73,7 +73,6 @@ impl<M: EmbeddingModel, T: ExtractEmbeddingFields> EmbeddingsBuilder<M, T> { /// /// let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); /// -/// # tokio_test::block_on(async { /// let embeddings = EmbeddingsBuilder::new(model.clone()) /// .documents(vec![ /// FakeDefinition { @@ -100,16 +99,9 @@ impl<M: EmbeddingModel, T: ExtractEmbeddingFields> EmbeddingsBuilder<M, T> { /// "A rare, mystical instrument crafted by the ancient monks of the Nebulon Mountain Ranges on the planet Quarm.".to_string() /// ] /// }, -/// ]) -/// .unwrap() +/// ])? /// .build() -/// .await -/// .unwrap(); -/// -/// assert_eq!(embeddings.iter().any(|(doc, embeddings)| doc.id == "doc0" && embeddings.len() == 2), true); -/// assert_eq!(embeddings.iter().any(|(doc, embeddings)| doc.id == "doc1" && embeddings.len() == 2), true); -/// assert_eq!(embeddings.iter().any(|(doc, embeddings)| doc.id == "doc2" && embeddings.len() == 2), true); -/// }) +/// .await?; /// ``` impl<M: EmbeddingModel, T: ExtractEmbeddingFields + Send + Sync + Clone> EmbeddingsBuilder<M, T> { /// Generate embeddings for all documents in the builder. diff --git a/rig-core/src/vector_store/in_memory_store.rs b/rig-core/src/vector_store/in_memory_store.rs index 5ab9d6e7..f4f067fe 100644 --- a/rig-core/src/vector_store/in_memory_store.rs +++ b/rig-core/src/vector_store/in_memory_store.rs @@ -76,28 +76,16 @@ impl<D: Serialize + Eq> InMemoryVectorStore<D> { Ok(self) } - /// Add documents to the store. Define the name of the field in the document that contains the id. + /// Add documents to the store. Define a function that takes as input the reference of the document and returns its id. /// Returns the store with the added documents. pub fn add_documents_with_id( mut self, documents: Vec<(D, OneOrMany<Embedding>)>, - id_field: &str, + id_f: fn(&D) -> String, ) -> Result<Self, VectorStoreError> { for (doc, embeddings) in documents { - if let serde_json::Value::Object(o) = - serde_json::to_value(&doc).map_err(VectorStoreError::JsonError)? 
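Because the new parameter is a plain function pointer (`fn(&D) -> String`) rather than a serde field name, any non-capturing closure that maps a document to its id can be passed, and the serialization round-trip of the old implementation goes away. A minimal usage sketch, assuming documents already embedded by `EmbeddingsBuilder::build()` as in the examples above:

```rust
use rig::embeddings::{Embedding, EmbeddingModel};
use rig::vector_store::{in_memory_store::InMemoryVectorStore, VectorStoreError};
use rig::OneOrMany;
use serde::Serialize;

#[derive(Serialize, Clone, Eq, PartialEq)]
struct FakeDefinition {
    id: String,
    word: String,
    definitions: Vec<String>,
}

// Build an in-memory index from documents that were already embedded by
// `EmbeddingsBuilder::build()`, deriving each document's id with a fn pointer.
fn build_index<M: EmbeddingModel>(
    model: M,
    embeddings: Vec<(FakeDefinition, OneOrMany<Embedding>)>,
) -> Result<(), VectorStoreError> {
    let _index = InMemoryVectorStore::default()
        .add_documents_with_id(embeddings, |definition| definition.id.clone())?
        .index(model);
    Ok(())
}
```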
- { - match o.get(id_field) { - Some(serde_json::Value::String(s)) => { - self.embeddings.insert(s.clone(), (doc, embeddings)); - } - _ => { - return Err(VectorStoreError::MissingIdError(format!( - "Document does not have a field {id_field}" - ))); - } - } - }; + let id = id_f(&doc); + self.embeddings.insert(id, (doc, embeddings)); } Ok(self) diff --git a/rig-lancedb/src/lib.rs b/rig-lancedb/src/lib.rs index eaaffbe3..46a6ef83 100644 --- a/rig-lancedb/src/lib.rs +++ b/rig-lancedb/src/lib.rs @@ -26,9 +26,9 @@ fn serde_to_rig_error(e: serde_json::Error) -> VectorStoreError { /// use rig_lancedb::{LanceDbVectorIndex, SearchParams}; /// use rig::embeddings::EmbeddingModel; /// -/// fn create_index(table: lancedb::Table, model: EmbeddingModel) { -/// let vector_store_index = LanceDbVectorIndex::new(table, model, "id", SearchParams::default()).await?; -/// } +/// let table: table: lancedb::Table = \*...\*; // <-- Replace with your lancedb table here. +/// let model: EmbeddingModel = \*...\*; // <-- Replace with your embedding model here. +/// let vector_store_index = LanceDbVectorIndex::new(table, model, "id", SearchParams::default()).await?; /// ``` pub struct LanceDbVectorIndex<M: EmbeddingModel> { /// Defines which model is used to generate embeddings for the vector store. @@ -181,14 +181,14 @@ impl<M: EmbeddingModel + Sync + Send> VectorStoreIndex for LanceDbVectorIndex<M> /// use rig_lancedb::{LanceDbVectorIndex, SearchParams}; /// use rig::embeddings::EmbeddingModel; /// - /// fn execute_search(table: lancedb::Table, model: EmbeddingModel) { - /// let vector_store_index = LanceDbVectorIndex::new(table, model, "id", SearchParams::default()).await?; + /// let table: table: lancedb::Table = \*...\*; // <-- Replace with your lancedb table here. + /// let model: EmbeddingModel = \*...\*; // <-- Replace with your embedding model here. + /// let vector_store_index = LanceDbVectorIndex::new(table, model, "id", SearchParams::default()).await?; /// - /// // Query the index - /// let result = vector_store_index - /// .top_n::<String>("My boss says I zindle too much, what does that mean?", 1) - /// .await?; - /// } + /// // Query the index + /// let result = vector_store_index + /// .top_n::<String>("My boss says I zindle too much, what does that mean?", 1) + /// .await?; /// ``` async fn top_n<T: for<'a> Deserialize<'a> + Send>( &self, @@ -227,14 +227,14 @@ impl<M: EmbeddingModel + Sync + Send> VectorStoreIndex for LanceDbVectorIndex<M> /// Implement the `top_n_ids` method of the `VectorStoreIndex` trait for `LanceDbVectorIndex`. /// # Example /// ``` - /// fn execute_search(table: lancedb::Table, model: EmbeddingModel) { - /// let vector_store_index = LanceDbVectorIndex::new(table, model, "id", SearchParams::default()).await?; + /// let table: table: lancedb::Table = \*...\*; // <-- Replace with your lancedb table here. + /// let model: EmbeddingModel = \*...\*; // <-- Replace with your embedding model here. 
+ /// let vector_store_index = LanceDbVectorIndex::new(table, model, "id", SearchParams::default()).await?; /// - /// // Query the index - /// let result = vector_store_index - /// .top_n_ids("My boss says I zindle too much, what does that mean?", 1) - /// .await?; - /// } + /// // Query the index + /// let result = vector_store_index + /// .top_n_ids("My boss says I zindle too much, what does that mean?", 1) + /// .await?; /// ``` async fn top_n_ids( &self, diff --git a/rig-mongodb/src/lib.rs b/rig-mongodb/src/lib.rs index a803d385..50f67b11 100644 --- a/rig-mongodb/src/lib.rs +++ b/rig-mongodb/src/lib.rs @@ -185,18 +185,19 @@ impl<M: EmbeddingModel + Sync + Send, C: Sync + Send> VectorStoreIndex /// definition: String, /// } /// - /// fn execute_search(collection: mongodb::Collection<Document>, model: EmbeddingModel) { - /// let vector_store_index = MongoDbVectorStore::new(collection).index( - /// model, - /// "vector_index", // <-- replace with the name of the index in your mongodb collection. - /// SearchParams::new("embedding"), // <-- field name in `Document` that contains the embeddings. - /// ); + /// let collection: collection: mongodb::Collection<Document> = \* ... \*; // <-- replace with your mongodb collection. + /// let model: model: EmbeddingModel = \* ... \*; // <-- replace with your embedding model. /// - /// // Query the index - /// vector_store_index - /// .top_n::<Definition>("My boss says I zindle too much, what does that mean?", 1) - /// .await?; - /// } + /// let vector_store_index = MongoDbVectorStore::new(collection).index( + /// model, + /// "vector_index", // <-- replace with the name of the index in your mongodb collection. + /// SearchParams::new("embedding"), // <-- field name in `Document` that contains the embeddings. + /// ); + /// + /// // Query the index + /// vector_store_index + /// .top_n::<Definition>("My boss says I zindle too much, what does that mean?", 1) + /// .await?; /// ``` async fn top_n<T: for<'a> Deserialize<'a> + Send>( &self, @@ -252,18 +253,18 @@ impl<M: EmbeddingModel + Sync + Send, C: Sync + Send> VectorStoreIndex /// embedding: Vec<f64>, /// } /// - /// fn execute_search(collection: mongodb::Collection<Document>, model: EmbeddingModel) { - /// let vector_store_index = MongoDbVectorStore::new(collection).index( - /// model, - /// "vector_index", // <-- replace with the name of the index in your mongodb collection. - /// SearchParams::new("embedding"), // <-- field name in `Document` that contains the embeddings. - /// ); + /// let collection: collection: mongodb::Collection<Document> = \* ... \*; // <-- replace with your mongodb collection. + /// let model: model: EmbeddingModel = \* ... \*; // <-- replace with your embedding model. + /// let vector_store_index = MongoDbVectorStore::new(collection).index( + /// model, + /// "vector_index", // <-- replace with the name of the index in your mongodb collection. + /// SearchParams::new("embedding"), // <-- field name in `Document` that contains the embeddings. 
+ /// ); /// - /// // Query the index - /// vector_store_index - /// .top_n_ids("My boss says I zindle too much, what does that mean?", 1) - /// .await?; - /// } + /// // Query the index + /// vector_store_index + /// .top_n_ids("My boss says I zindle too much, what does that mean?", 1) + /// .await?; /// ``` async fn top_n_ids( &self, From 55b42d86b5222a380ba378192af7ad48908fd13d Mon Sep 17 00:00:00 2001 From: Garance <garance.mary@gmail.com> Date: Thu, 24 Oct 2024 09:49:25 -0400 Subject: [PATCH 78/91] doc: fix doc linting --- rig-lancedb/src/lib.rs | 12 ++++++------ rig-mongodb/src/lib.rs | 22 +++++++++++----------- 2 files changed, 17 insertions(+), 17 deletions(-) diff --git a/rig-lancedb/src/lib.rs b/rig-lancedb/src/lib.rs index 46a6ef83..2eea2357 100644 --- a/rig-lancedb/src/lib.rs +++ b/rig-lancedb/src/lib.rs @@ -26,8 +26,8 @@ fn serde_to_rig_error(e: serde_json::Error) -> VectorStoreError { /// use rig_lancedb::{LanceDbVectorIndex, SearchParams}; /// use rig::embeddings::EmbeddingModel; /// -/// let table: table: lancedb::Table = \*...\*; // <-- Replace with your lancedb table here. -/// let model: EmbeddingModel = \*...\*; // <-- Replace with your embedding model here. +/// let table: table: lancedb::Table = db.create_table(""); // <-- Replace with your lancedb table here. +/// let model: model: EmbeddingModel = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); // <-- Replace with your embedding model here. /// let vector_store_index = LanceDbVectorIndex::new(table, model, "id", SearchParams::default()).await?; /// ``` pub struct LanceDbVectorIndex<M: EmbeddingModel> { @@ -181,8 +181,8 @@ impl<M: EmbeddingModel + Sync + Send> VectorStoreIndex for LanceDbVectorIndex<M> /// use rig_lancedb::{LanceDbVectorIndex, SearchParams}; /// use rig::embeddings::EmbeddingModel; /// - /// let table: table: lancedb::Table = \*...\*; // <-- Replace with your lancedb table here. - /// let model: EmbeddingModel = \*...\*; // <-- Replace with your embedding model here. + /// let table: table: lancedb::Table = db.create_table(""); // <-- Replace with your lancedb table here. + /// let model: model: EmbeddingModel = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); // <-- Replace with your embedding model here. /// let vector_store_index = LanceDbVectorIndex::new(table, model, "id", SearchParams::default()).await?; /// /// // Query the index @@ -227,8 +227,8 @@ impl<M: EmbeddingModel + Sync + Send> VectorStoreIndex for LanceDbVectorIndex<M> /// Implement the `top_n_ids` method of the `VectorStoreIndex` trait for `LanceDbVectorIndex`. /// # Example /// ``` - /// let table: table: lancedb::Table = \*...\*; // <-- Replace with your lancedb table here. - /// let model: EmbeddingModel = \*...\*; // <-- Replace with your embedding model here. + /// let table: table: lancedb::Table = db.create_table(""); // <-- Replace with your lancedb table here. + /// let model: model: EmbeddingModel = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); // <-- Replace with your embedding model here. 
/// let vector_store_index = LanceDbVectorIndex::new(table, model, "id", SearchParams::default()).await?; /// /// // Query the index diff --git a/rig-mongodb/src/lib.rs b/rig-mongodb/src/lib.rs index 50f67b11..655f3939 100644 --- a/rig-mongodb/src/lib.rs +++ b/rig-mongodb/src/lib.rs @@ -24,13 +24,13 @@ fn mongodb_to_rig_error(e: mongodb::error::Error) -> VectorStoreError { /// embedding: Vec<f64>, /// } /// -/// fn create_index(collection: mongodb::Collection<Document>, model: EmbeddingModel) { -/// let index = MongoDbVectorStore::new(collection).index( -/// model, -/// "vector_index", // <-- replace with the name of the index in your mongodb collection. -/// SearchParams::new("embedding"), // <-- field name in `Document` that contains the embeddings. -/// ); -/// } +/// let collection: collection: mongodb::Collection<Document> = mongodb_client.collection(""); // <-- replace with your mongodb collection. +/// let model: model: EmbeddingModel = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); // <-- replace with your embedding model. +/// let index = MongoDbVectorStore::new(collection).index( +/// model, +/// "vector_index", // <-- replace with the name of the index in your mongodb collection. +/// SearchParams::new("embedding"), // <-- field name in `Document` that contains the embeddings. +/// ); /// ``` pub struct MongoDbVectorStore<C> { collection: mongodb::Collection<C>, @@ -185,8 +185,8 @@ impl<M: EmbeddingModel + Sync + Send, C: Sync + Send> VectorStoreIndex /// definition: String, /// } /// - /// let collection: collection: mongodb::Collection<Document> = \* ... \*; // <-- replace with your mongodb collection. - /// let model: model: EmbeddingModel = \* ... \*; // <-- replace with your embedding model. + /// let collection: collection: mongodb::Collection<Document> = mongodb_client.collection(""); // <-- replace with your mongodb collection. + /// let model: model: EmbeddingModel = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); // <-- replace with your embedding model. /// /// let vector_store_index = MongoDbVectorStore::new(collection).index( /// model, @@ -253,8 +253,8 @@ impl<M: EmbeddingModel + Sync + Send, C: Sync + Send> VectorStoreIndex /// embedding: Vec<f64>, /// } /// - /// let collection: collection: mongodb::Collection<Document> = \* ... \*; // <-- replace with your mongodb collection. - /// let model: model: EmbeddingModel = \* ... \*; // <-- replace with your embedding model. + /// let collection: collection: mongodb::Collection<Document> = mongodb_client.collection(""); // <-- replace with your mongodb collection. + /// let model: model: EmbeddingModel = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); // <-- replace with your embedding model. /// let vector_store_index = MongoDbVectorStore::new(collection).index( /// model, /// "vector_index", // <-- replace with the name of the index in your mongodb collection. 
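For orientation, the corrected LanceDB doc examples in the patch above compose into a flow roughly like the sketch below. The `LanceDbVectorIndex::new(...)`, `SearchParams::default()`, and `top_n` calls come from the patched docs; the OpenAI client setup, the LanceDB connection (`lancedb::connect(...).execute()`, `open_table(...).execute()`), the API key, and the table name are illustrative assumptions and not part of the patch.

use rig::providers::openai::{Client, TEXT_EMBEDDING_ADA_002};
use rig::vector_store::VectorStoreIndex;
use rig_lancedb::{LanceDbVectorIndex, SearchParams};

#[tokio::main]
async fn main() -> Result<(), anyhow::Error> {
    // Assumed setup (not part of the patch): an OpenAI client and an existing
    // LanceDB database containing an already-embedded "definitions" table.
    let openai_client = Client::new("YOUR_OPENAI_API_KEY");
    let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002);
    let db = lancedb::connect("data/lancedb-store").execute().await?;
    let table = db.open_table("definitions").execute().await?;

    // As in the doc example: wrap the table in a vector index keyed on the "id" column.
    let index = LanceDbVectorIndex::new(table, model, "id", SearchParams::default()).await?;

    // Query the index for the single closest document, deserialized as a String.
    let results = index
        .top_n::<String>("My boss says I zindle too much, what does that mean?", 1)
        .await?;
    println!("{results:?}");

    Ok(())
}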
From b5870ce3b4bfaaf6bc986c7b5f799a9836fa6598 Mon Sep 17 00:00:00 2001 From: Christophe <cvauclair@protonmail.com> Date: Thu, 24 Oct 2024 12:15:08 -0400 Subject: [PATCH 79/91] misc: fmt --- rig-core/src/lib.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rig-core/src/lib.rs b/rig-core/src/lib.rs index 304fed94..edc5cab2 100644 --- a/rig-core/src/lib.rs +++ b/rig-core/src/lib.rs @@ -73,8 +73,8 @@ pub mod cli_chatbot; pub mod completion; pub mod embeddings; pub mod extractor; -pub mod one_or_many; pub(crate) mod json_utils; +pub mod one_or_many; pub mod providers; pub mod tool; pub mod vector_store; From 8c30b54a1b7328cce0f64d3b25c663e1738b5a49 Mon Sep 17 00:00:00 2001 From: Christophe <cvauclair@protonmail.com> Date: Thu, 24 Oct 2024 12:16:30 -0400 Subject: [PATCH 80/91] test: fix test --- rig-core/src/embeddings/builder.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/rig-core/src/embeddings/builder.rs b/rig-core/src/embeddings/builder.rs index b1b310bb..68d9e78b 100644 --- a/rig-core/src/embeddings/builder.rs +++ b/rig-core/src/embeddings/builder.rs @@ -191,10 +191,10 @@ mod tests { async fn embed_documents( &self, - documents: Vec<String>, + documents: impl IntoIterator<Item = String> + Send, ) -> Result<Vec<crate::embeddings::Embedding>, crate::embeddings::EmbeddingError> { Ok(documents - .iter() + .into_iter() .map(|doc| Embedding { document: doc.to_string(), vec: vec![0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], From 474695bc0f2e60b0ac806ba2ddecbdfa53c4c475 Mon Sep 17 00:00:00 2001 From: Garance Buricatu <60986356+garance-buricatu@users.noreply.github.com> Date: Fri, 15 Nov 2024 14:30:38 -0500 Subject: [PATCH 81/91] refactor(embeddings): embed trait definition (#89) * refactor: Big refactor * refactor: refactor Embed trait, fix all imports, rename files, fix macro * fix(embed trait): fix errors while testing * fix(lancedb): examples * docs: fix hyperlink * fmt: cargo fmt * PR; make requested changes * fix: change visibility of struct field * fix: failing tests --------- Co-authored-by: Christophe <cvauclair@protonmail.com> --- rig-core/Cargo.toml | 2 +- rig-core/examples/calculator_chatbot.rs | 2 +- rig-core/examples/rag.rs | 8 +- rig-core/examples/rag_dynamic_tools.rs | 2 +- rig-core/examples/vector_search.rs | 4 +- rig-core/examples/vector_search_cohere.rs | 4 +- rig-core/rig-core-derive/src/basic.rs | 6 +- .../{extract_embedding_fields.rs => embed.rs} | 75 +++---- rig-core/rig-core-derive/src/lib.rs | 12 +- rig-core/src/completion.rs | 2 +- rig-core/src/embeddings/builder.rs | 208 ++++++++++-------- rig-core/src/embeddings/embed.rs | 166 ++++++++++++++ rig-core/src/embeddings/embedding.rs | 30 ++- .../embeddings/extract_embedding_fields.rs | 163 -------------- rig-core/src/embeddings/mod.rs | 6 +- rig-core/src/embeddings/tool.rs | 31 +-- rig-core/src/lib.rs | 4 +- rig-core/src/providers/cohere.rs | 19 +- rig-core/src/providers/openai.rs | 9 +- rig-core/src/tool.rs | 8 +- rig-core/src/vector_store/in_memory_store.rs | 4 +- rig-core/src/vector_store/mod.rs | 2 +- ...bedding_fields_macro.rs => embed_macro.rs} | 68 +++--- rig-lancedb/examples/fixtures/lib.rs | 4 +- .../examples/vector_search_local_ann.rs | 1 - rig-lancedb/examples/vector_search_s3_ann.rs | 1 - rig-lancedb/src/lib.rs | 4 +- rig-mongodb/examples/vector_search_mongodb.rs | 10 +- rig-mongodb/src/lib.rs | 49 ++--- 29 files changed, 438 insertions(+), 466 deletions(-) rename rig-core/rig-core-derive/src/{extract_embedding_fields.rs => embed.rs} (51%) create mode 100644 
rig-core/src/embeddings/embed.rs delete mode 100644 rig-core/src/embeddings/extract_embedding_fields.rs rename rig-core/tests/{extract_embedding_fields_macro.rs => embed_macro.rs} (64%) diff --git a/rig-core/Cargo.toml b/rig-core/Cargo.toml index 514bdca5..c705ba06 100644 --- a/rig-core/Cargo.toml +++ b/rig-core/Cargo.toml @@ -35,7 +35,7 @@ tokio-test = "0.4.4" derive = ["dep:rig-derive"] [[test]] -name = "extract_embedding_fields_macro" +name = "embed_macro" required-features = ["derive"] [[example]] diff --git a/rig-core/examples/calculator_chatbot.rs b/rig-core/examples/calculator_chatbot.rs index 576491d5..149b1ce4 100644 --- a/rig-core/examples/calculator_chatbot.rs +++ b/rig-core/examples/calculator_chatbot.rs @@ -247,7 +247,7 @@ async fn main() -> Result<(), anyhow::Error> { let embedding_model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); let embeddings = EmbeddingsBuilder::new(embedding_model.clone()) - .documents(toolset.embedabble_tools()?)? + .documents(toolset.schemas()?)? .build() .await?; diff --git a/rig-core/examples/rag.rs b/rig-core/examples/rag.rs index 9829089f..cecd20ce 100644 --- a/rig-core/examples/rag.rs +++ b/rig-core/examples/rag.rs @@ -5,14 +5,14 @@ use rig::{ embeddings::EmbeddingsBuilder, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, vector_store::in_memory_store::InMemoryVectorStore, - ExtractEmbeddingFields, + Embed, }; use serde::Serialize; -// Shape of data that needs to be RAG'ed. -// A vector search needs to be performed on the definitions, so we derive the `ExtractEmbeddingFields` trait for `FakeDefinition` +// Data to be RAGged. +// A vector search needs to be performed on the `definitions` field, so we derive the `Embed` trait for `FakeDefinition` // and tag that field with `#[embed]`. -#[derive(ExtractEmbeddingFields, Serialize, Clone, Debug, Eq, PartialEq, Default)] +#[derive(Embed, Serialize, Clone, Debug, Eq, PartialEq, Default)] struct FakeDefinition { id: String, #[embed] diff --git a/rig-core/examples/rag_dynamic_tools.rs b/rig-core/examples/rag_dynamic_tools.rs index c140da15..459b017b 100644 --- a/rig-core/examples/rag_dynamic_tools.rs +++ b/rig-core/examples/rag_dynamic_tools.rs @@ -156,7 +156,7 @@ async fn main() -> Result<(), anyhow::Error> { .build(); let embeddings = EmbeddingsBuilder::new(embedding_model.clone()) - .documents(toolset.embedabble_tools()?)? + .documents(toolset.schemas()?)? .build() .await?; diff --git a/rig-core/examples/vector_search.rs b/rig-core/examples/vector_search.rs index 4777b42c..925ebca8 100644 --- a/rig-core/examples/vector_search.rs +++ b/rig-core/examples/vector_search.rs @@ -4,13 +4,13 @@ use rig::{ embeddings::EmbeddingsBuilder, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, vector_store::{in_memory_store::InMemoryVectorStore, VectorStoreIndex}, - ExtractEmbeddingFields, + Embed, }; use serde::{Deserialize, Serialize}; // Shape of data that needs to be RAG'ed. // The definition field will be used to generate embeddings. 
-#[derive(ExtractEmbeddingFields, Clone, Deserialize, Debug, Serialize, Eq, PartialEq, Default)] +#[derive(Embed, Clone, Deserialize, Debug, Serialize, Eq, PartialEq, Default)] struct FakeDefinition { id: String, word: String, diff --git a/rig-core/examples/vector_search_cohere.rs b/rig-core/examples/vector_search_cohere.rs index 6d966004..f3a97498 100644 --- a/rig-core/examples/vector_search_cohere.rs +++ b/rig-core/examples/vector_search_cohere.rs @@ -4,13 +4,13 @@ use rig::{ embeddings::EmbeddingsBuilder, providers::cohere::{Client, EMBED_ENGLISH_V3}, vector_store::{in_memory_store::InMemoryVectorStore, VectorStoreIndex}, - ExtractEmbeddingFields, + Embed, }; use serde::{Deserialize, Serialize}; // Shape of data that needs to be RAG'ed. // The definition field will be used to generate embeddings. -#[derive(ExtractEmbeddingFields, Clone, Deserialize, Debug, Serialize, Eq, PartialEq, Default)] +#[derive(Embed, Clone, Deserialize, Debug, Serialize, Eq, PartialEq, Default)] struct FakeDefinition { id: String, word: String, diff --git a/rig-core/rig-core-derive/src/basic.rs b/rig-core/rig-core-derive/src/basic.rs index 39b72018..b9c1e5c4 100644 --- a/rig-core/rig-core-derive/src/basic.rs +++ b/rig-core/rig-core-derive/src/basic.rs @@ -2,7 +2,7 @@ use syn::{parse_quote, Attribute, DataStruct, Meta}; use crate::EMBED; -/// Finds and returns fields with simple #[embed] attribute tags only. +/// Finds and returns fields with simple `#[embed]` attribute tags only. pub(crate) fn basic_embed_fields(data_struct: &DataStruct) -> impl Iterator<Item = &syn::Field> { data_struct.fields.iter().filter(|field| { field.attrs.iter().any(|attribute| match attribute { @@ -15,11 +15,11 @@ pub(crate) fn basic_embed_fields(data_struct: &DataStruct) -> impl Iterator<Item }) } -/// Adds bounds to where clause that force all fields tagged with #[embed] to implement the ExtractEmbeddingFields trait. +/// Adds bounds to where clause that force all fields tagged with `#[embed]` to implement the `Embed` trait. pub(crate) fn add_struct_bounds(generics: &mut syn::Generics, field_type: &syn::Type) { let where_clause = generics.make_where_clause(); where_clause.predicates.push(parse_quote! { - #field_type: ExtractEmbeddingFields + #field_type: Embed }); } diff --git a/rig-core/rig-core-derive/src/extract_embedding_fields.rs b/rig-core/rig-core-derive/src/embed.rs similarity index 51% rename from rig-core/rig-core-derive/src/extract_embedding_fields.rs rename to rig-core/rig-core-derive/src/embed.rs index 5c21f6b2..73b89205 100644 --- a/rig-core/rig-core-derive/src/extract_embedding_fields.rs +++ b/rig-core/rig-core-derive/src/embed.rs @@ -17,8 +17,8 @@ pub(crate) fn expand_derive_embedding(input: &mut syn::DeriveInput) -> syn::Resu let (basic_targets, basic_target_size) = data_struct.basic(generics); let (custom_targets, custom_target_size) = data_struct.custom()?; - // If there are no fields tagged with #[embed] or #[embed(embed_with = "...")], return an empty TokenStream. - // ie. do not implement `ExtractEmbeddingFields` trait for the struct. + // If there are no fields tagged with `#[embed]` or `#[embed(embed_with = "...")]`, return an empty TokenStream. + // ie. do not implement `Embed` trait for the struct. if basic_target_size + custom_target_size == 0 { return Err(syn::Error::new_spanned( name, @@ -27,14 +27,14 @@ pub(crate) fn expand_derive_embedding(input: &mut syn::DeriveInput) -> syn::Resu } quote! 
{ - let mut embed_targets = #basic_targets; - embed_targets.extend(#custom_targets) + #basic_targets; + #custom_targets; } } _ => { return Err(syn::Error::new_spanned( input, - "ExtractEmbeddingFields derive macro should only be used on structs", + "Embed derive macro should only be used on structs", )) } }; @@ -42,18 +42,13 @@ pub(crate) fn expand_derive_embedding(input: &mut syn::DeriveInput) -> syn::Resu let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl(); let gen = quote! { - // Note: `ExtractEmbeddingFields` trait is imported with the macro. + // Note: `Embed` trait is imported with the macro. - impl #impl_generics ExtractEmbeddingFields for #name #ty_generics #where_clause { - type Error = rig::embeddings::extract_embedding_fields::ExtractEmbeddingFieldsError; - - fn extract_embedding_fields(&self) -> Result<rig::OneOrMany<String>, Self::Error> { + impl #impl_generics Embed for #name #ty_generics #where_clause { + fn embed(&self, embedder: &mut rig::embeddings::embed::TextEmbedder) -> Result<(), rig::embeddings::embed::EmbedError> { #target_stream; - rig::OneOrMany::merge( - embed_targets.into_iter() - .collect::<Result<Vec<_>, _>>()? - ).map_err(rig::embeddings::extract_embedding_fields::ExtractEmbeddingFieldsError::new) + Ok(()) } } }; @@ -62,17 +57,17 @@ pub(crate) fn expand_derive_embedding(input: &mut syn::DeriveInput) -> syn::Resu } trait StructParser { - // Handles fields tagged with #[embed] + // Handles fields tagged with `#[embed]` fn basic(&self, generics: &mut syn::Generics) -> (TokenStream, usize); - // Handles fields tagged with #[embed(embed_with = "...")] + // Handles fields tagged with `#[embed(embed_with = "...")]` fn custom(&self) -> syn::Result<(TokenStream, usize)>; } impl StructParser for DataStruct { fn basic(&self, generics: &mut syn::Generics) -> (TokenStream, usize) { let embed_targets = basic_embed_fields(self) - // Iterate over every field tagged with #[embed] + // Iterate over every field tagged with `#[embed]` .map(|field| { add_struct_bounds(generics, &field.ty); @@ -84,50 +79,32 @@ impl StructParser for DataStruct { }) .collect::<Vec<_>>(); - if !embed_targets.is_empty() { - ( - quote! { - vec![#(#embed_targets.extract_embedding_fields()),*] - }, - embed_targets.len(), - ) - } else { - ( - quote! { - vec![] - }, - 0, - ) - } + ( + quote! { + #(#embed_targets.embed(embedder)?;)* + }, + embed_targets.len(), + ) } fn custom(&self) -> syn::Result<(TokenStream, usize)> { let embed_targets = custom_embed_fields(self)? - // Iterate over every field tagged with #[embed(embed_with = "...")] + // Iterate over every field tagged with `#[embed(embed_with = "...")]` .into_iter() .map(|(field, custom_func_path)| { let field_name = &field.ident; quote! { - #custom_func_path(self.#field_name.clone()) + #custom_func_path(embedder, self.#field_name.clone())?; } }) .collect::<Vec<_>>(); - Ok(if !embed_targets.is_empty() { - ( - quote! { - vec![#(#embed_targets),*] - }, - embed_targets.len(), - ) - } else { - ( - quote! { - vec![] - }, - 0, - ) - }) + Ok(( + quote! 
{ + #(#embed_targets)* + }, + embed_targets.len(), + )) } } diff --git a/rig-core/rig-core-derive/src/lib.rs b/rig-core/rig-core-derive/src/lib.rs index 8ad69a65..4ce20cfa 100644 --- a/rig-core/rig-core-derive/src/lib.rs +++ b/rig-core/rig-core-derive/src/lib.rs @@ -4,18 +4,18 @@ use syn::{parse_macro_input, DeriveInput}; mod basic; mod custom; -mod extract_embedding_fields; +mod embed; pub(crate) const EMBED: &str = "embed"; -// https://doc.rust-lang.org/book/ch19-06-macros.html#how-to-write-a-custom-derive-macro -// https://doc.rust-lang.org/reference/procedural-macros.html - -#[proc_macro_derive(ExtractEmbeddingFields, attributes(embed))] +/// References: +/// <https://doc.rust-lang.org/book/ch19-06-macros.html#how-to-write-a-custom-derive-macro> +/// <https://doc.rust-lang.org/reference/procedural-macros.html> +#[proc_macro_derive(Embed, attributes(embed))] pub fn derive_embedding_trait(item: TokenStream) -> TokenStream { let mut input = parse_macro_input!(item as DeriveInput); - extract_embedding_fields::expand_derive_embedding(&mut input) + embed::expand_derive_embedding(&mut input) .unwrap_or_else(syn::Error::into_compile_error) .into() } diff --git a/rig-core/src/completion.rs b/rig-core/src/completion.rs index e766fb27..f13f316b 100644 --- a/rig-core/src/completion.rs +++ b/rig-core/src/completion.rs @@ -82,7 +82,7 @@ pub enum CompletionError { /// Error building the completion request #[error("RequestError: {0}")] - RequestError(#[from] Box<dyn std::error::Error + Send + Sync>), + RequestError(#[from] Box<dyn std::error::Error + Send + Sync + 'static>), /// Error parsing the completion response #[error("ResponseError: {0}")] diff --git a/rig-core/src/embeddings/builder.rs b/rig-core/src/embeddings/builder.rs index 68d9e78b..8f0a5dd1 100644 --- a/rig-core/src/embeddings/builder.rs +++ b/rig-core/src/embeddings/builder.rs @@ -1,22 +1,24 @@ -//! The module defines the [EmbeddingsBuilder] struct which accumulates objects to be embedded and generates the embeddings for each object when built. -//! Only types that implement the [ExtractEmbeddingFields] trait can be added to the [EmbeddingsBuilder]. +//! The module defines the [EmbeddingsBuilder] struct which accumulates objects to be embedded +//! and batch generates the embeddings for each object when built. +//! Only types that implement the [Embed] trait can be added to the [EmbeddingsBuilder]. use std::{cmp::max, collections::HashMap}; -use futures::{stream, StreamExt, TryStreamExt}; +use futures::{stream, StreamExt}; use crate::{ - embeddings::{Embedding, EmbeddingError, EmbeddingModel, ExtractEmbeddingFields}, + embeddings::{Embed, EmbedError, Embedding, EmbeddingError, EmbeddingModel, TextEmbedder}, OneOrMany, }; -/// Builder for creating a collection of embeddings. -pub struct EmbeddingsBuilder<M: EmbeddingModel, T: ExtractEmbeddingFields> { +/// Builder for creating a collection of embeddings from a vector of documents of type `T`. +/// Accumulate documents such that they can be embedded in a single batch to limit api calls to the provider. 
+pub struct EmbeddingsBuilder<M: EmbeddingModel, T: Embed> { model: M, - documents: Vec<(T, OneOrMany<String>)>, + documents: Vec<(T, Vec<String>)>, } -impl<M: EmbeddingModel, T: ExtractEmbeddingFields> EmbeddingsBuilder<M, T> { +impl<M: EmbeddingModel, T: Embed> EmbeddingsBuilder<M, T> { /// Create a new embedding builder with the given embedding model pub fn new(model: M) -> Self { Self { @@ -25,23 +27,23 @@ impl<M: EmbeddingModel, T: ExtractEmbeddingFields> EmbeddingsBuilder<M, T> { } } - /// Add a document that implements `ExtractEmbeddingFields` to the builder. - pub fn document(mut self, document: T) -> Result<Self, T::Error> { - let embed_targets = document.extract_embedding_fields()?; + /// Add a document that implements `Embed` to the builder. + pub fn document(mut self, document: T) -> Result<Self, EmbedError> { + let mut embedder = TextEmbedder::default(); + document.embed(&mut embedder)?; + + self.documents.push((document, embedder.texts)); - self.documents.push((document, embed_targets)); Ok(self) } - /// Add many documents that implement `ExtractEmbeddingFields` to the builder. - pub fn documents(mut self, documents: Vec<T>) -> Result<Self, T::Error> { - for doc in documents.into_iter() { - let embed_targets = doc.extract_embedding_fields()?; - - self.documents.push((doc, embed_targets)); - } + /// Add many documents that implement `Embed` to the builder. + pub fn documents(self, documents: impl IntoIterator<Item = T>) -> Result<Self, EmbedError> { + let builder = documents + .into_iter() + .try_fold(self, |builder, doc| builder.document(doc))?; - Ok(self) + Ok(builder) } } @@ -53,13 +55,13 @@ impl<M: EmbeddingModel, T: ExtractEmbeddingFields> EmbeddingsBuilder<M, T> { /// embeddings::EmbeddingsBuilder, /// providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, /// vector_store::{in_memory_store::InMemoryVectorStore, VectorStoreIndex}, -/// ExtractEmbeddingFields, +/// Embed, /// }; /// use serde::{Deserialize, Serialize}; /// /// // Shape of data that needs to be RAG'ed. /// // The definition field will be used to generate embeddings. -/// #[derive(ExtractEmbeddingFields, Clone, Deserialize, Debug, Serialize, Eq, PartialEq, Default)] +/// #[derive(Embed, Clone, Deserialize, Debug, Serialize, Eq, PartialEq, Default)] /// struct FakeDefinition { /// id: String, /// word: String, @@ -103,45 +105,38 @@ impl<M: EmbeddingModel, T: ExtractEmbeddingFields> EmbeddingsBuilder<M, T> { /// .build() /// .await?; /// ``` -impl<M: EmbeddingModel, T: ExtractEmbeddingFields + Send + Sync + Clone> EmbeddingsBuilder<M, T> { +impl<M: EmbeddingModel, T: Embed + Send> EmbeddingsBuilder<M, T> { /// Generate embeddings for all documents in the builder. /// Returns a vector of tuples, where the first element is the document and the second element is the embeddings (either one embedding or many). pub async fn build(self) -> Result<Vec<(T, OneOrMany<Embedding>)>, EmbeddingError> { - // Use this for reference later to merge a document back with its embeddings. - let documents_map = self - .documents - .clone() - .into_iter() - .enumerate() - .map(|(id, (document, _))| (id, document)) - .collect::<HashMap<_, _>>(); - - let embeddings = stream::iter(self.documents.iter().enumerate()) - // Merge the embedding targets of each document into a single list of embedding targets. - .flat_map(|(i, (_, embed_targets))| { - stream::iter( - embed_targets - .clone() - .into_iter() - .map(move |target| (i, target)), - ) - }) - // Chunk them into N (the emebdding API limit per request). 
+ use stream::TryStreamExt; + + // Store the documents and their texts in a HashMap for easy access. + let mut docs = HashMap::new(); + let mut texts = HashMap::new(); + + // Iterate over all documents in the builder and insert their docs and texts into the lookup stores. + for (i, (doc, doc_texts)) in self.documents.into_iter().enumerate() { + docs.insert(i, doc); + texts.insert(i, doc_texts); + } + + // Compute the embeddings. + let mut embeddings = stream::iter(texts.into_iter()) + // Merge the texts of each document into a single list of texts. + .flat_map(|(i, texts)| stream::iter(texts.into_iter().map(move |text| (i, text)))) + // Chunk them into batches. Each batch size is at most the embedding API limit per request. .chunks(M::MAX_DOCUMENTS) - // Generate the embeddings for a chunk at a time. - .map(|docs| async { - let (document_indices, embed_targets): (Vec<_>, Vec<_>) = docs.into_iter().unzip(); - - Ok::<_, EmbeddingError>( - document_indices - .into_iter() - .zip(self.model.embed_documents(embed_targets).await?.into_iter()) - .collect::<Vec<_>>(), - ) + // Generate the embeddings for each batch. + .map(|text| async { + let (ids, docs): (Vec<_>, Vec<_>) = text.into_iter().unzip(); + + let embeddings = self.model.embed_texts(docs).await?; + Ok::<_, EmbeddingError>(ids.into_iter().zip(embeddings).collect::<Vec<_>>()) }) - .boxed() // Parallelize the embeddings generation over 10 concurrent requests .buffer_unordered(max(1, 1024 / M::MAX_DOCUMENTS)) + // Collect the embeddings into a HashMap. .try_fold( HashMap::new(), |mut acc: HashMap<_, OneOrMany<Embedding>>, embeddings| async move { @@ -154,27 +149,26 @@ impl<M: EmbeddingModel, T: ExtractEmbeddingFields + Send + Sync + Clone> Embeddi Ok(acc) }, ) - .await? - .iter() - .fold(vec![], |mut acc, (i, embeddings_vec)| { - acc.push(( - documents_map.get(i).cloned().unwrap(), - embeddings_vec.clone(), - )); - acc - }); - - Ok(embeddings) + .await?; + + // Merge the embeddings with their respective documents + Ok(docs + .into_iter() + .map(|(i, doc)| { + ( + doc, + embeddings.remove(&i).expect("Document should be present"), + ) + }) + .collect()) } } #[cfg(test)] mod tests { use crate::{ - embeddings::{ - extract_embedding_fields::ExtractEmbeddingFieldsError, Embedding, EmbeddingModel, - }, - ExtractEmbeddingFields, + embeddings::{embed::EmbedError, Embedding, EmbeddingModel, TextEmbedder}, + Embed, }; use super::EmbeddingsBuilder; @@ -189,7 +183,7 @@ mod tests { 10 } - async fn embed_documents( + async fn embed_texts( &self, documents: impl IntoIterator<Item = String> + Send, ) -> Result<Vec<crate::embeddings::Embedding>, crate::embeddings::EmbeddingError> { @@ -209,16 +203,16 @@ mod tests { definitions: Vec<String>, } - impl ExtractEmbeddingFields for FakeDefinition { - type Error = ExtractEmbeddingFieldsError; - - fn extract_embedding_fields(&self) -> Result<crate::OneOrMany<String>, Self::Error> { - crate::OneOrMany::many(self.definitions.clone()) - .map_err(ExtractEmbeddingFieldsError::new) + impl Embed for FakeDefinition { + fn embed(&self, embedder: &mut TextEmbedder) -> Result<(), EmbedError> { + for definition in &self.definitions { + embedder.embed(definition.clone()); + } + Ok(()) } } - fn fake_definitions() -> Vec<FakeDefinition> { + fn fake_definitions_multiple_text() -> Vec<FakeDefinition> { vec![ FakeDefinition { id: "doc0".to_string(), @@ -237,7 +231,7 @@ mod tests { ] } - fn fake_definitions_2() -> Vec<FakeDefinition> { + fn fake_definitions_multiple_text_2() -> Vec<FakeDefinition> { vec![ FakeDefinition { id: 
"doc2".to_string(), @@ -256,15 +250,14 @@ mod tests { definition: String, } - impl ExtractEmbeddingFields for FakeDefinitionSingle { - type Error = ExtractEmbeddingFieldsError; - - fn extract_embedding_fields(&self) -> Result<crate::OneOrMany<String>, Self::Error> { - Ok(crate::OneOrMany::one(self.definition.clone())) + impl Embed for FakeDefinitionSingle { + fn embed(&self, embedder: &mut TextEmbedder) -> Result<(), EmbedError> { + embedder.embed(self.definition.clone()); + Ok(()) } } - fn fake_definitions_single() -> Vec<FakeDefinitionSingle> { + fn fake_definitions_single_text() -> Vec<FakeDefinitionSingle> { vec![ FakeDefinitionSingle { id: "doc0".to_string(), @@ -278,8 +271,8 @@ mod tests { } #[tokio::test] - async fn test_build_many() { - let fake_definitions = fake_definitions(); + async fn test_build_multiple_text() { + let fake_definitions = fake_definitions_multiple_text(); let fake_model = FakeModel; let mut result = EmbeddingsBuilder::new(fake_model) @@ -312,8 +305,8 @@ mod tests { } #[tokio::test] - async fn test_build_single() { - let fake_definitions = fake_definitions_single(); + async fn test_build_single_text() { + let fake_definitions = fake_definitions_single_text(); let fake_model = FakeModel; let mut result = EmbeddingsBuilder::new(fake_model) @@ -346,9 +339,9 @@ mod tests { } #[tokio::test] - async fn test_build_many_and_single() { - let fake_definitions = fake_definitions(); - let fake_definitions_single = fake_definitions_2(); + async fn test_build_multiple_and_single_text() { + let fake_definitions = fake_definitions_multiple_text(); + let fake_definitions_single = fake_definitions_multiple_text_2(); let fake_model = FakeModel; let mut result = EmbeddingsBuilder::new(fake_model) @@ -381,4 +374,37 @@ mod tests { "Another fake definitions".to_string() ) } + + #[tokio::test] + async fn test_build_string() { + let bindings = fake_definitions_multiple_text(); + let fake_definitions = bindings.iter().map(|def| def.definitions.clone()); + + let fake_model = FakeModel; + let mut result = EmbeddingsBuilder::new(fake_model) + .documents(fake_definitions) + .unwrap() + .build() + .await + .unwrap(); + + result.sort_by(|(fake_definition_1, _), (fake_definition_2, _)| { + fake_definition_1.cmp(&fake_definition_2) + }); + + assert_eq!(result.len(), 2); + + let first_definition = &result[0]; + assert_eq!(first_definition.1.len(), 2); + assert_eq!( + first_definition.1.first().document, + "A green alien that lives on cold planets.".to_string() + ); + + let second_definition = &result[1]; + assert_eq!(second_definition.1.len(), 2); + assert_eq!( + second_definition.1.rest()[0].document, "A fictional creature found in the distant, swampy marshlands of the planet Glibbo in the Andromeda galaxy.".to_string() + ) + } } diff --git a/rig-core/src/embeddings/embed.rs b/rig-core/src/embeddings/embed.rs new file mode 100644 index 00000000..2fd26c2a --- /dev/null +++ b/rig-core/src/embeddings/embed.rs @@ -0,0 +1,166 @@ +//! The module defines the [Embed] trait, which must be implemented for types that can be embedded by the `EmbeddingsBuilder`. + +/// Error type used for when the `embed` method fo the `Embed` trait fails. +/// Used by default implementations of `Embed` for common types. 
+#[derive(Debug, thiserror::Error)] +#[error("{0}")] +pub struct EmbedError(#[from] Box<dyn std::error::Error + Send + Sync>); + +impl EmbedError { + pub fn new<E: std::error::Error + Send + Sync + 'static>(error: E) -> Self { + EmbedError(Box::new(error)) + } +} + +/// Derive this trait for objects that need to be converted to vector embeddings. +/// The `embed` method accumulates string values that need to be embedded by adding them to the `TextEmbedder`. +/// If an error occurs, the method should return `EmbedError`. +/// # Example +/// ```rust +/// use std::env; +/// +/// use serde::{Deserialize, Serialize}; +/// use rig::Embed; +/// +/// struct FakeDefinition { +/// id: String, +/// word: String, +/// definitions: String, +/// } +/// +/// impl Embed for FakeDefinition { +/// fn embed(&self, embedder: &mut TextEmbedder) -> Result<(), EmbedError> { +/// // Embeddings only need to be generated for `definition` field. +/// // Split the definitions by comma and collect them into a vector of strings. +/// // That way, different embeddings can be generated for each definition in the definitions string. +/// self.definitions +/// .split(",") +/// .for_each(|s| { +/// embedder.embed(s.to_string()); +/// }); +/// +/// Ok(()) +/// } +/// } +/// ``` +pub trait Embed { + fn embed(&self, embedder: &mut TextEmbedder) -> Result<(), EmbedError>; +} + +/// Accumulates string values that need to be embedded. +/// Used by the `Embed` trait. +#[derive(Default)] +pub struct TextEmbedder { + pub(crate) texts: Vec<String>, +} + +impl TextEmbedder { + pub fn embed(&mut self, text: String) { + self.texts.push(text); + } +} + +/// Client-side function to convert an object that implements the `Embed` trait to a vector of strings. +/// Similar to `serde`'s `serde_json::to_string()` function +pub fn to_texts(item: impl Embed) -> Result<Vec<String>, EmbedError> { + let mut embedder = TextEmbedder::default(); + item.embed(&mut embedder)?; + Ok(embedder.texts) +} + +// ================================================================ +// Implementations of Embed for common types +// ================================================================ + +impl Embed for String { + fn embed(&self, embedder: &mut TextEmbedder) -> Result<(), EmbedError> { + embedder.embed(self.clone()); + Ok(()) + } +} + +impl Embed for &str { + fn embed(&self, embedder: &mut TextEmbedder) -> Result<(), EmbedError> { + embedder.embed(self.to_string()); + Ok(()) + } +} + +impl Embed for i8 { + fn embed(&self, embedder: &mut TextEmbedder) -> Result<(), EmbedError> { + embedder.embed(self.to_string()); + Ok(()) + } +} + +impl Embed for i16 { + fn embed(&self, embedder: &mut TextEmbedder) -> Result<(), EmbedError> { + embedder.embed(self.to_string()); + Ok(()) + } +} + +impl Embed for i32 { + fn embed(&self, embedder: &mut TextEmbedder) -> Result<(), EmbedError> { + embedder.embed(self.to_string()); + Ok(()) + } +} + +impl Embed for i64 { + fn embed(&self, embedder: &mut TextEmbedder) -> Result<(), EmbedError> { + embedder.embed(self.to_string()); + Ok(()) + } +} + +impl Embed for i128 { + fn embed(&self, embedder: &mut TextEmbedder) -> Result<(), EmbedError> { + embedder.embed(self.to_string()); + Ok(()) + } +} + +impl Embed for f32 { + fn embed(&self, embedder: &mut TextEmbedder) -> Result<(), EmbedError> { + embedder.embed(self.to_string()); + Ok(()) + } +} + +impl Embed for f64 { + fn embed(&self, embedder: &mut TextEmbedder) -> Result<(), EmbedError> { + embedder.embed(self.to_string()); + Ok(()) + } +} + +impl Embed for bool { + fn 
embed(&self, embedder: &mut TextEmbedder) -> Result<(), EmbedError> { + embedder.embed(self.to_string()); + Ok(()) + } +} + +impl Embed for char { + fn embed(&self, embedder: &mut TextEmbedder) -> Result<(), EmbedError> { + embedder.embed(self.to_string()); + Ok(()) + } +} + +impl Embed for serde_json::Value { + fn embed(&self, embedder: &mut TextEmbedder) -> Result<(), EmbedError> { + embedder.embed(serde_json::to_string(self).map_err(EmbedError::new)?); + Ok(()) + } +} + +impl<T: Embed> Embed for Vec<T> { + fn embed(&self, embedder: &mut TextEmbedder) -> Result<(), EmbedError> { + for item in self { + item.embed(embedder).map_err(EmbedError::new)?; + } + Ok(()) + } +} diff --git a/rig-core/src/embeddings/embedding.rs b/rig-core/src/embeddings/embedding.rs index 47820a81..d033f57e 100644 --- a/rig-core/src/embeddings/embedding.rs +++ b/rig-core/src/embeddings/embedding.rs @@ -22,7 +22,7 @@ pub enum EmbeddingError { /// Error processing the document for embedding #[error("DocumentError: {0}")] - DocumentError(String), + DocumentError(Box<dyn std::error::Error + Send + Sync + 'static>), /// Error parsing the completion response #[error("ResponseError: {0}")] @@ -41,29 +41,25 @@ pub trait EmbeddingModel: Clone + Sync + Send { /// The number of dimensions in the embedding vector. fn ndims(&self) -> usize; - /// Embed a single document - fn embed_document( + /// Embed multiple text documents in a single request + fn embed_texts( + &self, + documents: impl IntoIterator<Item = String> + Send, + ) -> impl std::future::Future<Output = Result<Vec<Embedding>, EmbeddingError>> + Send; + + /// Embed a single text document. + fn embed_text( &self, document: &str, - ) -> impl std::future::Future<Output = Result<Embedding, EmbeddingError>> + Send - where - Self: Sync, - { + ) -> impl std::future::Future<Output = Result<Embedding, EmbeddingError>> + Send { async { Ok(self - .embed_documents(vec![document.to_string()]) + .embed_texts(vec![document.to_string()]) .await? - .first() - .cloned() - .expect("One embedding should be present")) + .pop() + .expect("There should be at least one embedding")) } } - - /// Embed multiple documents in a single request - fn embed_documents( - &self, - documents: impl IntoIterator<Item = String> + Send, - ) -> impl std::future::Future<Output = Result<Vec<Embedding>, EmbeddingError>> + Send; } /// Struct that holds a single document and its embedding. diff --git a/rig-core/src/embeddings/extract_embedding_fields.rs b/rig-core/src/embeddings/extract_embedding_fields.rs deleted file mode 100644 index e62d43c6..00000000 --- a/rig-core/src/embeddings/extract_embedding_fields.rs +++ /dev/null @@ -1,163 +0,0 @@ -//! The module defines the [ExtractEmbeddingFields] trait, which must be implemented for types that can be embedded. - -use crate::one_or_many::OneOrMany; - -/// Error type used for when the `extract_embedding_fields` method fails. -/// Used by default implementations of `ExtractEmbeddingFields` for common types. -#[derive(Debug, thiserror::Error)] -#[error("{0}")] -pub struct ExtractEmbeddingFieldsError(#[from] Box<dyn std::error::Error + Send + Sync>); - -impl ExtractEmbeddingFieldsError { - pub fn new<E: std::error::Error + Send + Sync + 'static>(error: E) -> Self { - ExtractEmbeddingFieldsError(Box::new(error)) - } -} - -/// Derive this trait for structs whose fields need to be converted to vector embeddings. -/// The `extract_embedding_fields` method returns a `OneOrMany<String>`. 
This function extracts the fields that need to be embedded and returns them as a list of strings. -/// If there is an error generating the list of strings, the method should return an error that implements `std::error::Error`. -/// # Example -/// ```rust -/// use std::env; -/// -/// use serde::{Deserialize, Serialize}; -/// use rig::{OneOrMany, EmptyListError, ExtractEmbeddingFields}; -/// -/// struct FakeDefinition { -/// id: String, -/// word: String, -/// definitions: String, -/// } -/// -/// let fake_definition = FakeDefinition { -/// id: "doc1".to_string(), -/// word: "rock".to_string(), -/// definitions: "the solid mineral material forming part of the surface of the earth, a precious stone".to_string() -/// }; -/// -/// impl ExtractEmbeddingFields for FakeDefinition { -/// type Error = EmptyListError; -/// -/// fn extract_embedding_fields(&self) -> Result<OneOrMany<String>, Self::Error> { -/// // Embeddings only need to be generated for `definition` field. -/// // Split the definitions by comma and collect them into a vector of strings. -/// // That way, different embeddings can be generated for each definition in the definitions string. -/// let definitions = self.definitions.split(",").collect::<Vec<_>>().into_iter().map(|s| s.to_string()).collect(); -/// -/// OneOrMany::many(definitions) -/// } -/// } -/// ``` -pub trait ExtractEmbeddingFields { - type Error: std::error::Error + Sync + Send + 'static; - - fn extract_embedding_fields(&self) -> Result<OneOrMany<String>, Self::Error>; -} - -// ================================================================ -// Implementations of ExtractEmbeddingFields for common types -// ================================================================ -impl ExtractEmbeddingFields for String { - type Error = ExtractEmbeddingFieldsError; - - fn extract_embedding_fields(&self) -> Result<OneOrMany<String>, Self::Error> { - Ok(OneOrMany::one(self.clone())) - } -} - -impl ExtractEmbeddingFields for i8 { - type Error = ExtractEmbeddingFieldsError; - - fn extract_embedding_fields(&self) -> Result<OneOrMany<String>, Self::Error> { - Ok(OneOrMany::one(self.to_string())) - } -} - -impl ExtractEmbeddingFields for i16 { - type Error = ExtractEmbeddingFieldsError; - - fn extract_embedding_fields(&self) -> Result<OneOrMany<String>, Self::Error> { - Ok(OneOrMany::one(self.to_string())) - } -} - -impl ExtractEmbeddingFields for i32 { - type Error = ExtractEmbeddingFieldsError; - - fn extract_embedding_fields(&self) -> Result<OneOrMany<String>, Self::Error> { - Ok(OneOrMany::one(self.to_string())) - } -} - -impl ExtractEmbeddingFields for i64 { - type Error = ExtractEmbeddingFieldsError; - - fn extract_embedding_fields(&self) -> Result<OneOrMany<String>, Self::Error> { - Ok(OneOrMany::one(self.to_string())) - } -} - -impl ExtractEmbeddingFields for i128 { - type Error = ExtractEmbeddingFieldsError; - - fn extract_embedding_fields(&self) -> Result<OneOrMany<String>, Self::Error> { - Ok(OneOrMany::one(self.to_string())) - } -} - -impl ExtractEmbeddingFields for f32 { - type Error = ExtractEmbeddingFieldsError; - - fn extract_embedding_fields(&self) -> Result<OneOrMany<String>, Self::Error> { - Ok(OneOrMany::one(self.to_string())) - } -} - -impl ExtractEmbeddingFields for f64 { - type Error = ExtractEmbeddingFieldsError; - - fn extract_embedding_fields(&self) -> Result<OneOrMany<String>, Self::Error> { - Ok(OneOrMany::one(self.to_string())) - } -} - -impl ExtractEmbeddingFields for bool { - type Error = ExtractEmbeddingFieldsError; - - fn 
extract_embedding_fields(&self) -> Result<OneOrMany<String>, Self::Error> { - Ok(OneOrMany::one(self.to_string())) - } -} - -impl ExtractEmbeddingFields for char { - type Error = ExtractEmbeddingFieldsError; - - fn extract_embedding_fields(&self) -> Result<OneOrMany<String>, Self::Error> { - Ok(OneOrMany::one(self.to_string())) - } -} - -impl ExtractEmbeddingFields for serde_json::Value { - type Error = ExtractEmbeddingFieldsError; - - fn extract_embedding_fields(&self) -> Result<OneOrMany<String>, Self::Error> { - Ok(OneOrMany::one( - serde_json::to_string(self).map_err(ExtractEmbeddingFieldsError::new)?, - )) - } -} - -impl<T: ExtractEmbeddingFields> ExtractEmbeddingFields for Vec<T> { - type Error = ExtractEmbeddingFieldsError; - - fn extract_embedding_fields(&self) -> Result<OneOrMany<String>, Self::Error> { - let items = self - .iter() - .map(|item| item.extract_embedding_fields()) - .collect::<Result<Vec<_>, _>>() - .map_err(ExtractEmbeddingFieldsError::new)?; - - OneOrMany::merge(items).map_err(ExtractEmbeddingFieldsError::new) - } -} diff --git a/rig-core/src/embeddings/mod.rs b/rig-core/src/embeddings/mod.rs index 37323cf5..1ae16436 100644 --- a/rig-core/src/embeddings/mod.rs +++ b/rig-core/src/embeddings/mod.rs @@ -4,11 +4,11 @@ //! and document similarity. pub mod builder; +pub mod embed; pub mod embedding; -pub mod extract_embedding_fields; pub mod tool; pub use builder::EmbeddingsBuilder; +pub use embed::{to_texts, Embed, EmbedError, TextEmbedder}; pub use embedding::{Embedding, EmbeddingError, EmbeddingModel}; -pub use extract_embedding_fields::ExtractEmbeddingFields; -pub use tool::EmbeddableTool; +pub use tool::ToolSchema; diff --git a/rig-core/src/embeddings/tool.rs b/rig-core/src/embeddings/tool.rs index 7550038f..bcea7c7d 100644 --- a/rig-core/src/embeddings/tool.rs +++ b/rig-core/src/embeddings/tool.rs @@ -1,31 +1,32 @@ -use crate::{tool::ToolEmbeddingDyn, ExtractEmbeddingFields, OneOrMany}; +use crate::{tool::ToolEmbeddingDyn, Embed}; use serde::Serialize; -use super::extract_embedding_fields::ExtractEmbeddingFieldsError; +use super::embed::EmbedError; /// Used by EmbeddingsBuilder to embed anything that implements ToolEmbedding. #[derive(Clone, Serialize, Default, Eq, PartialEq)] -pub struct EmbeddableTool { +pub struct ToolSchema { pub name: String, pub context: serde_json::Value, pub embedding_docs: Vec<String>, } -impl ExtractEmbeddingFields for EmbeddableTool { - type Error = ExtractEmbeddingFieldsError; - - fn extract_embedding_fields(&self) -> Result<OneOrMany<String>, Self::Error> { - OneOrMany::many(self.embedding_docs.clone()).map_err(ExtractEmbeddingFieldsError::new) +impl Embed for ToolSchema { + fn embed(&self, embedder: &mut super::embed::TextEmbedder) -> Result<(), EmbedError> { + for doc in &self.embedding_docs { + embedder.embed(doc.clone()); + } + Ok(()) } } -impl EmbeddableTool { - /// Convert item that implements ToolEmbeddingDyn to an EmbeddableTool. +impl ToolSchema { + /// Convert item that implements ToolEmbeddingDyn to an ToolSchema. 
/// # Example /// ```rust /// use rig::{ /// completion::ToolDefinition, - /// embeddings::EmbeddableTool, + /// embeddings::ToolSchema, /// tool::{Tool, ToolEmbedding, ToolEmbeddingDyn}, /// }; /// use serde_json::json; @@ -76,15 +77,15 @@ impl EmbeddableTool { /// fn context(&self) -> Self::Context {} /// } /// - /// let tool = EmbeddableTool::try_from(&Nothing).unwrap(); + /// let tool = ToolSchema::try_from(&Nothing).unwrap(); /// /// assert_eq!(tool.name, "nothing".to_string()); /// assert_eq!(tool.embedding_docs, vec!["Do nothing.".to_string()]); /// ``` - pub fn try_from(tool: &dyn ToolEmbeddingDyn) -> Result<Self, ExtractEmbeddingFieldsError> { - Ok(EmbeddableTool { + pub fn try_from(tool: &dyn ToolEmbeddingDyn) -> Result<Self, EmbedError> { + Ok(ToolSchema { name: tool.name(), - context: tool.context().map_err(ExtractEmbeddingFieldsError::new)?, + context: tool.context().map_err(EmbedError::new)?, embedding_docs: tool.embedding_docs(), }) } diff --git a/rig-core/src/lib.rs b/rig-core/src/lib.rs index edc5cab2..6c5db7ab 100644 --- a/rig-core/src/lib.rs +++ b/rig-core/src/lib.rs @@ -80,8 +80,8 @@ pub mod tool; pub mod vector_store; // Re-export commonly used types and traits -pub use embeddings::ExtractEmbeddingFields; +pub use embeddings::{to_texts, Embed}; pub use one_or_many::{EmptyListError, OneOrMany}; #[cfg(feature = "derive")] -pub use rig_derive::ExtractEmbeddingFields; +pub use rig_derive::Embed; diff --git a/rig-core/src/providers/cohere.rs b/rig-core/src/providers/cohere.rs index ef449dd6..4299684e 100644 --- a/rig-core/src/providers/cohere.rs +++ b/rig-core/src/providers/cohere.rs @@ -15,7 +15,7 @@ use crate::{ completion::{self, CompletionError}, embeddings::{self, EmbeddingError, EmbeddingsBuilder}, extractor::ExtractorBuilder, - json_utils, ExtractEmbeddingFields, + json_utils, Embed, }; use schemars::JsonSchema; @@ -92,7 +92,7 @@ impl Client { EmbeddingModel::new(self.clone(), model, input_type, ndims) } - pub fn embeddings<D: ExtractEmbeddingFields>( + pub fn embeddings<D: Embed>( &self, model: &str, input_type: &str, @@ -201,7 +201,7 @@ impl embeddings::EmbeddingModel for EmbeddingModel { self.ndims } - async fn embed_documents( + async fn embed_texts( &self, documents: impl IntoIterator<Item = String>, ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> { @@ -222,11 +222,14 @@ impl embeddings::EmbeddingModel for EmbeddingModel { match response.json::<ApiResponse<EmbeddingResponse>>().await? 
{ ApiResponse::Ok(response) => { if response.embeddings.len() != documents.len() { - return Err(EmbeddingError::DocumentError(format!( - "Expected {} embeddings, got {}", - documents.len(), - response.embeddings.len() - ))); + return Err(EmbeddingError::DocumentError( + format!( + "Expected {} embeddings, got {}", + documents.len(), + response.embeddings.len() + ) + .into(), + )); } Ok(response diff --git a/rig-core/src/providers/openai.rs b/rig-core/src/providers/openai.rs index ab1517d8..828f27bc 100644 --- a/rig-core/src/providers/openai.rs +++ b/rig-core/src/providers/openai.rs @@ -13,7 +13,7 @@ use crate::{ completion::{self, CompletionError, CompletionRequest}, embeddings::{self, EmbeddingError, EmbeddingsBuilder}, extractor::ExtractorBuilder, - json_utils, ExtractEmbeddingFields, + json_utils, Embed, }; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; @@ -121,10 +121,7 @@ impl Client { /// .await /// .expect("Failed to embed documents"); /// ``` - pub fn embeddings<D: ExtractEmbeddingFields>( - &self, - model: &str, - ) -> EmbeddingsBuilder<EmbeddingModel, D> { + pub fn embeddings<D: Embed>(&self, model: &str) -> EmbeddingsBuilder<EmbeddingModel, D> { EmbeddingsBuilder::new(self.embedding_model(model)) } @@ -242,7 +239,7 @@ impl embeddings::EmbeddingModel for EmbeddingModel { self.ndims } - async fn embed_documents( + async fn embed_texts( &self, documents: impl IntoIterator<Item = String>, ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> { diff --git a/rig-core/src/tool.rs b/rig-core/src/tool.rs index 528faba5..198be04c 100644 --- a/rig-core/src/tool.rs +++ b/rig-core/src/tool.rs @@ -5,7 +5,7 @@ use serde::{Deserialize, Serialize}; use crate::{ completion::{self, ToolDefinition}, - embeddings::{extract_embedding_fields::ExtractEmbeddingFieldsError, tool::EmbeddableTool}, + embeddings::{embed::EmbedError, tool::ToolSchema}, }; #[derive(Debug, thiserror::Error)] @@ -327,15 +327,15 @@ impl ToolSet { Ok(docs) } - /// Convert tools in self to objects of type EmbeddableTool. + /// Convert tools in self to objects of type ToolSchema. /// This is necessary because when adding tools to the EmbeddingBuilder because all /// documents added to the builder must all be of the same type. 
- pub fn embedabble_tools(&self) -> Result<Vec<EmbeddableTool>, ExtractEmbeddingFieldsError> { + pub fn schemas(&self) -> Result<Vec<ToolSchema>, EmbedError> { self.tools .values() .filter_map(|tool_type| { if let ToolType::Embedding(tool) = tool_type { - Some(EmbeddableTool::try_from(&**tool)) + Some(ToolSchema::try_from(&**tool)) } else { None } diff --git a/rig-core/src/vector_store/in_memory_store.rs b/rig-core/src/vector_store/in_memory_store.rs index f4f067fe..4519f45a 100644 --- a/rig-core/src/vector_store/in_memory_store.rs +++ b/rig-core/src/vector_store/in_memory_store.rs @@ -171,7 +171,7 @@ impl<M: EmbeddingModel + Sync, D: Serialize + Sync + Send + Eq> VectorStoreIndex query: &str, n: usize, ) -> Result<Vec<(f64, String, T)>, VectorStoreError> { - let prompt_embedding = &self.model.embed_document(query).await?; + let prompt_embedding = &self.model.embed_text(query).await?; let docs = self.store.vector_search(prompt_embedding, n); @@ -192,7 +192,7 @@ impl<M: EmbeddingModel + Sync, D: Serialize + Sync + Send + Eq> VectorStoreIndex query: &str, n: usize, ) -> Result<Vec<(f64, String)>, VectorStoreError> { - let prompt_embedding = &self.model.embed_document(query).await?; + let prompt_embedding = &self.model.embed_text(query).await?; let docs = self.store.vector_search(prompt_embedding, n); diff --git a/rig-core/src/vector_store/mod.rs b/rig-core/src/vector_store/mod.rs index 044d8c2a..b2b8c93e 100644 --- a/rig-core/src/vector_store/mod.rs +++ b/rig-core/src/vector_store/mod.rs @@ -16,7 +16,7 @@ pub enum VectorStoreError { JsonError(#[from] serde_json::Error), #[error("Datastore error: {0}")] - DatastoreError(#[from] Box<dyn std::error::Error + Send + Sync>), + DatastoreError(#[from] Box<dyn std::error::Error + Send + Sync + 'static>), #[error("Missing Id: {0}")] MissingIdError(String), diff --git a/rig-core/tests/extract_embedding_fields_macro.rs b/rig-core/tests/embed_macro.rs similarity index 64% rename from rig-core/tests/extract_embedding_fields_macro.rs rename to rig-core/tests/embed_macro.rs index 5f8891ff..778b70bd 100644 --- a/rig-core/tests/extract_embedding_fields_macro.rs +++ b/rig-core/tests/embed_macro.rs @@ -1,14 +1,16 @@ -use rig::embeddings::extract_embedding_fields::ExtractEmbeddingFieldsError; -use rig::{ExtractEmbeddingFields, OneOrMany}; +use rig::{ + embeddings::{embed::EmbedError, TextEmbedder}, + to_texts, Embed, +}; use serde::Serialize; -fn serialize(definition: Definition) -> Result<OneOrMany<String>, ExtractEmbeddingFieldsError> { - Ok(OneOrMany::one( - serde_json::to_string(&definition).map_err(ExtractEmbeddingFieldsError::new)?, - )) +fn serialize(embedder: &mut TextEmbedder, definition: Definition) -> Result<(), EmbedError> { + embedder.embed(serde_json::to_string(&definition).map_err(EmbedError::new)?); + + Ok(()) } -#[derive(ExtractEmbeddingFields)] +#[derive(Embed)] struct FakeDefinition { id: String, word: String, @@ -41,15 +43,13 @@ fn test_custom_embed() { ); assert_eq!( - fake_definition.extract_embedding_fields().unwrap(), - OneOrMany::one( - "{\"word\":\"a building in which people live; residence for human beings.\",\"link\":\"https://www.dictionary.com/browse/house\",\"speech\":\"noun\"}".to_string() - ) + to_texts(fake_definition).unwrap().first().unwrap().clone(), + "{\"word\":\"a building in which people live; residence for human beings.\",\"link\":\"https://www.dictionary.com/browse/house\",\"speech\":\"noun\"}".to_string() ) } -#[derive(ExtractEmbeddingFields)] +#[derive(Embed)] struct FakeDefinition2 { id: String, #[embed] @@ -75,18 
+75,17 @@ fn test_custom_and_basic_embed() { fake_definition.id, fake_definition.word ); - assert_eq!( - fake_definition.extract_embedding_fields().unwrap().first(), - "house".to_string() - ); + let texts = to_texts(fake_definition).unwrap(); + + assert_eq!(texts.first().unwrap().clone(), "house".to_string()); assert_eq!( - fake_definition.extract_embedding_fields().unwrap().rest(), - vec!["{\"word\":\"a building in which people live; residence for human beings.\",\"link\":\"https://www.dictionary.com/browse/house\",\"speech\":\"noun\"}".to_string()] + texts.last().unwrap().clone(), + "{\"word\":\"a building in which people live; residence for human beings.\",\"link\":\"https://www.dictionary.com/browse/house\",\"speech\":\"noun\"}".to_string() ) } -#[derive(ExtractEmbeddingFields)] +#[derive(Embed)] struct FakeDefinition3 { id: String, word: String, @@ -109,12 +108,12 @@ fn test_single_embed() { ); assert_eq!( - fake_definition.extract_embedding_fields().unwrap(), - OneOrMany::one(definition) + to_texts(fake_definition).unwrap().first().unwrap().clone(), + definition ) } -#[derive(ExtractEmbeddingFields)] +#[derive(Embed)] struct Company { id: String, company: String, @@ -132,28 +131,18 @@ fn test_multiple_embed_strings() { println!("Company: {}, {}", company.id, company.company); - let result = company.extract_embedding_fields().unwrap(); - assert_eq!( - result, - OneOrMany::many(vec![ + to_texts(company).unwrap(), + vec![ "25".to_string(), "30".to_string(), "35".to_string(), "40".to_string() - ]) - .unwrap() + ] ); - - assert_eq!(result.first(), "25".to_string()); - - assert_eq!( - result.rest(), - vec!["30".to_string(), "35".to_string(), "40".to_string()] - ) } -#[derive(ExtractEmbeddingFields)] +#[derive(Embed)] struct Company2 { id: String, #[embed] @@ -173,14 +162,13 @@ fn test_multiple_embed_tags() { println!("Company: {}", company.id); assert_eq!( - company.extract_embedding_fields().unwrap(), - OneOrMany::many(vec![ + to_texts(company).unwrap(), + vec![ "Google".to_string(), "25".to_string(), "30".to_string(), "35".to_string(), "40".to_string() - ]) - .unwrap() + ] ); } diff --git a/rig-lancedb/examples/fixtures/lib.rs b/rig-lancedb/examples/fixtures/lib.rs index 1a9089a5..954494e5 100644 --- a/rig-lancedb/examples/fixtures/lib.rs +++ b/rig-lancedb/examples/fixtures/lib.rs @@ -3,10 +3,10 @@ use std::sync::Arc; use arrow_array::{types::Float64Type, ArrayRef, FixedSizeListArray, RecordBatch, StringArray}; use lancedb::arrow::arrow_schema::{DataType, Field, Fields, Schema}; use rig::embeddings::Embedding; -use rig::{ExtractEmbeddingFields, OneOrMany}; +use rig::{Embed, OneOrMany}; use serde::Deserialize; -#[derive(ExtractEmbeddingFields, Clone, Deserialize, Debug)] +#[derive(Embed, Clone, Deserialize, Debug)] pub struct FakeDefinition { pub id: String, #[embed] diff --git a/rig-lancedb/examples/vector_search_local_ann.rs b/rig-lancedb/examples/vector_search_local_ann.rs index 7ffd6b12..84679e3f 100644 --- a/rig-lancedb/examples/vector_search_local_ann.rs +++ b/rig-lancedb/examples/vector_search_local_ann.rs @@ -35,7 +35,6 @@ async fn main() -> Result<(), anyhow::Error> { id: format!("doc{}", i), definition: "Definition of *flumbuzzle (noun)*: A sudden, inexplicable urge to rearrange or reorganize small objects, such as desk items or books, for no apparent reason.".to_string() }) - .collect(), )? 
.build() .await?; diff --git a/rig-lancedb/examples/vector_search_s3_ann.rs b/rig-lancedb/examples/vector_search_s3_ann.rs index 824deda0..160dfa10 100644 --- a/rig-lancedb/examples/vector_search_s3_ann.rs +++ b/rig-lancedb/examples/vector_search_s3_ann.rs @@ -41,7 +41,6 @@ async fn main() -> Result<(), anyhow::Error> { id: format!("doc{}", i), definition: "Definition of *flumbuzzle (noun)*: A sudden, inexplicable urge to rearrange or reorganize small objects, such as desk items or books, for no apparent reason.".to_string() }) - .collect(), )? .build() .await?; diff --git a/rig-lancedb/src/lib.rs b/rig-lancedb/src/lib.rs index 2eea2357..0d06b495 100644 --- a/rig-lancedb/src/lib.rs +++ b/rig-lancedb/src/lib.rs @@ -195,7 +195,7 @@ impl<M: EmbeddingModel + Sync + Send> VectorStoreIndex for LanceDbVectorIndex<M> query: &str, n: usize, ) -> Result<Vec<(f64, String, T)>, VectorStoreError> { - let prompt_embedding = self.model.embed_document(query).await?; + let prompt_embedding = self.model.embed_text(query).await?; let query = self .table @@ -241,7 +241,7 @@ impl<M: EmbeddingModel + Sync + Send> VectorStoreIndex for LanceDbVectorIndex<M> query: &str, n: usize, ) -> Result<Vec<(f64, String)>, VectorStoreError> { - let prompt_embedding = self.model.embed_document(query).await?; + let prompt_embedding = self.model.embed_text(query).await?; let query = self .table diff --git a/rig-mongodb/examples/vector_search_mongodb.rs b/rig-mongodb/examples/vector_search_mongodb.rs index cc16d4bc..b0867064 100644 --- a/rig-mongodb/examples/vector_search_mongodb.rs +++ b/rig-mongodb/examples/vector_search_mongodb.rs @@ -3,15 +3,14 @@ use rig::providers::openai::TEXT_EMBEDDING_ADA_002; use serde::{Deserialize, Serialize}; use std::env; -use rig::ExtractEmbeddingFields; use rig::{ - embeddings::EmbeddingsBuilder, providers::openai::Client, vector_store::VectorStoreIndex, + embeddings::EmbeddingsBuilder, providers::openai::Client, vector_store::VectorStoreIndex, Embed, }; -use rig_mongodb::{MongoDbVectorStore, SearchParams}; +use rig_mongodb::{MongoDbVectorIndex, SearchParams}; // Shape of data that needs to be RAG'ed. // The definition field will be used to generate embeddings. -#[derive(ExtractEmbeddingFields, Clone, Deserialize, Debug)] +#[derive(Embed, Clone, Deserialize, Debug)] struct FakeDefinition { #[serde(rename = "_id")] id: String, @@ -97,7 +96,8 @@ async fn main() -> Result<(), anyhow::Error> { // Create a vector index on our vector store. // Note: a vector index called "vector_index" must exist on the MongoDB collection you are querying. // IMPORTANT: Reuse the same model that was used to generate the embeddings - let index = MongoDbVectorStore::new(collection).index( + let index = MongoDbVectorIndex::new( + collection, model, "vector_index", SearchParams::new("embedding"), diff --git a/rig-mongodb/src/lib.rs b/rig-mongodb/src/lib.rs index 655f3939..3631cc8d 100644 --- a/rig-mongodb/src/lib.rs +++ b/rig-mongodb/src/lib.rs @@ -11,9 +11,10 @@ fn mongodb_to_rig_error(e: mongodb::error::Error) -> VectorStoreError { VectorStoreError::DatastoreError(Box::new(e)) } +/// A vector index for a MongoDB collection. 
/// # Example /// ``` -/// use rig_mongodb::{MongoDbVectorStore, SearchParams}; +/// use rig_mongodb::{MongoDbVectorIndex, SearchParams}; /// use rig::embeddings::EmbeddingModel; /// /// #[derive(serde::Serialize, Debug)] @@ -26,37 +27,13 @@ fn mongodb_to_rig_error(e: mongodb::error::Error) -> VectorStoreError { /// /// let collection: collection: mongodb::Collection<Document> = mongodb_client.collection(""); // <-- replace with your mongodb collection. /// let model: model: EmbeddingModel = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); // <-- replace with your embedding model. -/// let index = MongoDbVectorStore::new(collection).index( +/// let index = MongoDbVectorIndex::new( +/// collection, /// model, /// "vector_index", // <-- replace with the name of the index in your mongodb collection. /// SearchParams::new("embedding"), // <-- field name in `Document` that contains the embeddings. /// ); /// ``` -pub struct MongoDbVectorStore<C> { - collection: mongodb::Collection<C>, -} - -impl<C> MongoDbVectorStore<C> { - /// Create a new `MongoDbVectorStore` from a MongoDB collection. - pub fn new(collection: mongodb::Collection<C>) -> Self { - Self { collection } - } - - /// Create a new `MongoDbVectorIndex` from an existing `MongoDbVectorStore`. - /// - /// The index (of type "vector") must already exist for the MongoDB collection. - /// See the MongoDB [documentation](https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-type/) for more information on creating indexes. - pub fn index<M: EmbeddingModel>( - &self, - model: M, - index_name: &str, - search_params: SearchParams, - ) -> MongoDbVectorIndex<M, C> { - MongoDbVectorIndex::new(self.collection.clone(), model, index_name, search_params) - } -} - -/// A vector index for a MongoDB collection. pub struct MongoDbVectorIndex<M: EmbeddingModel, C> { collection: mongodb::Collection<C>, model: M, @@ -100,6 +77,10 @@ impl<M: EmbeddingModel, C> MongoDbVectorIndex<M, C> { } impl<M: EmbeddingModel, C> MongoDbVectorIndex<M, C> { + /// Create a new `MongoDbVectorIndex`. + /// + /// The index (of type "vector") must already exist for the MongoDB collection. + /// See the MongoDB [documentation](https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-type/) for more information on creating indexes. pub fn new( collection: mongodb::Collection<C>, model: M, @@ -167,7 +148,7 @@ impl<M: EmbeddingModel + Sync + Send, C: Sync + Send> VectorStoreIndex /// Implement the `top_n` method of the `VectorStoreIndex` trait for `MongoDbVectorIndex`. /// # Example /// ``` - /// use rig_mongodb::{MongoDbVectorStore, SearchParams}; + /// use rig_mongodb::{MongoDbVectorIndex, SearchParams}; /// use rig::embeddings::EmbeddingModel; /// /// #[derive(serde::Serialize, Debug)] @@ -188,7 +169,8 @@ impl<M: EmbeddingModel + Sync + Send, C: Sync + Send> VectorStoreIndex /// let collection: collection: mongodb::Collection<Document> = mongodb_client.collection(""); // <-- replace with your mongodb collection. /// let model: model: EmbeddingModel = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); // <-- replace with your embedding model. /// - /// let vector_store_index = MongoDbVectorStore::new(collection).index( + /// let vector_store_index = MongoDbVectorIndex::new( + /// collection, /// model, /// "vector_index", // <-- replace with the name of the index in your mongodb collection. /// SearchParams::new("embedding"), // <-- field name in `Document` that contains the embeddings. 
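For reference, a minimal sketch of how the single-step constructor introduced here is meant to be used, assembled from the docstring above and the `vector_search_mongodb` example; the database name, collection name, environment-variable names, and query string are placeholders, and the MongoDB client setup simply follows the standard driver pattern rather than anything added by this patch:

```rust
use mongodb::{options::ClientOptions, Client as MongoClient, Collection};
use rig::providers::openai::{Client, TEXT_EMBEDDING_ADA_002};
use rig::vector_store::VectorStoreIndex;
use rig_mongodb::{MongoDbVectorIndex, SearchParams};
use serde::Deserialize;

// Shape of the documents returned by the search (mirrors the example's struct).
#[derive(Deserialize, Debug)]
struct Definition {
    #[serde(rename = "_id")]
    id: String,
    definition: String,
}

#[tokio::main]
async fn main() -> Result<(), anyhow::Error> {
    // Illustrative environment variables, following the conventions of the examples.
    let openai_client = Client::new(&std::env::var("OPENAI_API_KEY")?);
    let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002);

    let mongo_client =
        MongoClient::with_options(ClientOptions::parse(std::env::var("MONGODB_CONNECTION_STRING")?).await?)?;
    let collection: Collection<Definition> =
        mongo_client.database("knowledgebase").collection("context"); // placeholder names

    // Single-step construction introduced by this patch: the collection is now the
    // first argument of `MongoDbVectorIndex::new` instead of going through
    // `MongoDbVectorStore::new(collection).index(...)`.
    let index = MongoDbVectorIndex::new(
        collection,
        model,
        "vector_index",                 // name of the existing Atlas vector index
        SearchParams::new("embedding"), // field in the document that holds the embeddings
    );

    // `top_n` embeds the query with `embed_text` and runs the vector search.
    let results = index.top_n::<Definition>("What is a flumbuzzle?", 1).await?;
    println!("{:?}", results);

    Ok(())
}
```

Folding the old `MongoDbVectorStore::new(collection).index(...)` pair into one constructor removes a type that existed only to hold the collection handle.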
@@ -204,7 +186,7 @@ impl<M: EmbeddingModel + Sync + Send, C: Sync + Send> VectorStoreIndex query: &str, n: usize, ) -> Result<Vec<(f64, String, T)>, VectorStoreError> { - let prompt_embedding = self.model.embed_document(query).await?; + let prompt_embedding = self.model.embed_text(query).await?; let mut cursor = self .collection @@ -242,7 +224,7 @@ impl<M: EmbeddingModel + Sync + Send, C: Sync + Send> VectorStoreIndex /// Implement the `top_n_ids` method of the `VectorStoreIndex` trait for `MongoDbVectorIndex`. /// # Example /// ``` - /// use rig_mongodb::{MongoDbVectorStore, SearchParams}; + /// use rig_mongodb::{MongoDbVectorIndex, SearchParams}; /// use rig::embeddings::EmbeddingModel; /// /// #[derive(serde::Serialize, Debug)] @@ -255,7 +237,8 @@ impl<M: EmbeddingModel + Sync + Send, C: Sync + Send> VectorStoreIndex /// /// let collection: collection: mongodb::Collection<Document> = mongodb_client.collection(""); // <-- replace with your mongodb collection. /// let model: model: EmbeddingModel = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); // <-- replace with your embedding model. - /// let vector_store_index = MongoDbVectorStore::new(collection).index( + /// let vector_store_index = MongoDbVectorIndex::new( + /// collection, /// model, /// "vector_index", // <-- replace with the name of the index in your mongodb collection. /// SearchParams::new("embedding"), // <-- field name in `Document` that contains the embeddings. @@ -271,7 +254,7 @@ impl<M: EmbeddingModel + Sync + Send, C: Sync + Send> VectorStoreIndex query: &str, n: usize, ) -> Result<Vec<(f64, String)>, VectorStoreError> { - let prompt_embedding = self.model.embed_document(query).await?; + let prompt_embedding = self.model.embed_text(query).await?; let mut cursor = self .collection From 1336633f36f6691d664f1e1129c374caeb316903 Mon Sep 17 00:00:00 2001 From: Garance Buricatu <garance.mary@gmail.com> Date: Fri, 22 Nov 2024 22:47:04 -0500 Subject: [PATCH 82/91] fix/docs: fix erros from merge, cleanup embeddings docstrings --- Cargo.lock | 934 +++++++++++------- rig-core/Cargo.toml | 8 + rig-core/examples/gemini_embeddings.rs | 18 +- rig-core/examples/xai_embeddings.rs | 18 +- rig-core/src/embeddings/embed.rs | 22 +- rig-core/src/embeddings/embedding.rs | 4 +- rig-core/src/embeddings/tool.rs | 2 + rig-core/src/lib.rs | 2 +- rig-core/src/providers/gemini/client.rs | 6 +- rig-core/src/providers/gemini/embedding.rs | 2 +- rig-core/src/providers/xai/client.rs | 6 +- rig-core/src/providers/xai/embedding.rs | 2 +- rig-core/src/vector_store/in_memory_store.rs | 6 +- rig-lancedb/src/lib.rs | 4 +- rig-mongodb/examples/vector_search_mongodb.rs | 8 +- rig-mongodb/src/lib.rs | 15 +- rig-neo4j/Cargo.toml | 4 + rig-neo4j/examples/vector_search_simple.rs | 29 +- rig-neo4j/src/vector_index.rs | 4 +- rig-qdrant/Cargo.toml | 4 + rig-qdrant/examples/qdrant_vector_search.rs | 29 +- rig-qdrant/src/lib.rs | 2 +- 22 files changed, 715 insertions(+), 414 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 4fc6d678..2d778dde 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1,6 +1,6 @@ # This file is automatically @generated by Cargo. # It is not intended for manual editing. 
-version = 4 +version = 3 [[package]] name = "addr2line" @@ -42,9 +42,9 @@ dependencies = [ [[package]] name = "allocator-api2" -version = "0.2.18" +version = "0.2.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5c6cb57a04249c6480766f7f7cef5467412af1490f8d1e243141daddada3264f" +checksum = "45862d1c77f2228b9e10bc609d5bc203d86ebc9b87ad8d5d5167a6c9abf739d9" [[package]] name = "android-tzdata" @@ -63,15 +63,15 @@ dependencies = [ [[package]] name = "anstyle" -version = "1.0.8" +version = "1.0.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1bec1de6f59aedf83baf9ff929c98f2ad654b97c9510f4e70cf6f661d49fd5b1" +checksum = "55cc3b69f167a1ef2e161439aa98aed94e6028e5f9a59be9a6ffb47aef1651f9" [[package]] name = "anyhow" -version = "1.0.89" +version = "1.0.93" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "86fdf8605db99b54d3cd748a44c6d04df638eb5dafb219b135d0149bd0db01f6" +checksum = "4c95c10ba0b00a02636238b814946408b1322d5ac4760326e6fb8ec956d85775" [[package]] name = "arc-swap" @@ -361,7 +361,7 @@ checksum = "3b43422f69d8ff38f95f1b2bb76517c91589a924d1559a0e935d7c8ce0274c11" dependencies = [ "proc-macro2", "quote", - "syn 2.0.79", + "syn 2.0.89", ] [[package]] @@ -383,7 +383,7 @@ checksum = "c7c24de15d275a1ecfd47a380fb4d5ec9bfe0933f309ed5e705b775596a3574d" dependencies = [ "proc-macro2", "quote", - "syn 2.0.79", + "syn 2.0.89", ] [[package]] @@ -394,7 +394,7 @@ checksum = "721cae7de5c34fbb2acd27e21e6d2cf7b886dce0c27388d46c4e6c47ea4318dd" dependencies = [ "proc-macro2", "quote", - "syn 2.0.79", + "syn 2.0.89", ] [[package]] @@ -426,9 +426,9 @@ checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26" [[package]] name = "aws-config" -version = "1.5.8" +version = "1.5.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7198e6f03240fdceba36656d8be440297b6b82270325908c7381f37d826a74f6" +checksum = "9b49afaa341e8dd8577e1a2200468f98956d6eda50bcf4a53246cc00174ba924" dependencies = [ "aws-credential-types", "aws-runtime", @@ -443,7 +443,7 @@ dependencies = [ "aws-smithy-types", "aws-types", "bytes", - "fastrand 2.1.1", + "fastrand 2.2.0", "hex", "http 0.2.12", "ring", @@ -481,7 +481,7 @@ dependencies = [ "aws-smithy-types", "aws-types", "bytes", - "fastrand 2.1.1", + "fastrand 2.2.0", "http 0.2.12", "http-body 0.4.6", "once_cell", @@ -493,9 +493,9 @@ dependencies = [ [[package]] name = "aws-sdk-dynamodb" -version = "1.49.0" +version = "1.54.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ab0ade000608877169533a54326badd6b5a707d2faf876cfc3976a7f9d7e5329" +checksum = "8efdda6a491bb4640d35b99b0a4b93f75ce7d6e3a1937c3e902d3cb23d0a179c" dependencies = [ "aws-credential-types", "aws-runtime", @@ -507,7 +507,7 @@ dependencies = [ "aws-smithy-types", "aws-types", "bytes", - "fastrand 2.1.1", + "fastrand 2.2.0", "http 0.2.12", "once_cell", "regex-lite", @@ -516,9 +516,9 @@ dependencies = [ [[package]] name = "aws-sdk-sso" -version = "1.45.0" +version = "1.49.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e33ae899566f3d395cbf42858e433930682cc9c1889fa89318896082fef45efb" +checksum = "09677244a9da92172c8dc60109b4a9658597d4d298b188dd0018b6a66b410ca4" dependencies = [ "aws-credential-types", "aws-runtime", @@ -538,9 +538,9 @@ dependencies = [ [[package]] name = "aws-sdk-ssooidc" -version = "1.46.0" +version = "1.50.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"f39c09e199ebd96b9f860b0fce4b6625f211e064ad7c8693b72ecf7ef03881e0" +checksum = "81fea2f3a8bb3bd10932ae7ad59cc59f65f270fc9183a7e91f501dc5efbef7ee" dependencies = [ "aws-credential-types", "aws-runtime", @@ -560,9 +560,9 @@ dependencies = [ [[package]] name = "aws-sdk-sts" -version = "1.45.0" +version = "1.50.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3d95f93a98130389eb6233b9d615249e543f6c24a68ca1f109af9ca5164a8765" +checksum = "6ada54e5f26ac246dc79727def52f7f8ed38915cb47781e2a72213957dc3a7d5" dependencies = [ "aws-credential-types", "aws-runtime", @@ -583,9 +583,9 @@ dependencies = [ [[package]] name = "aws-sigv4" -version = "1.2.4" +version = "1.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cc8db6904450bafe7473c6ca9123f88cc11089e41a025408f992db4e22d3be68" +checksum = "5619742a0d8f253be760bfbb8e8e8368c69e3587e4637af5754e488a611499b1" dependencies = [ "aws-credential-types", "aws-smithy-http", @@ -656,22 +656,22 @@ dependencies = [ [[package]] name = "aws-smithy-runtime" -version = "1.7.2" +version = "1.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a065c0fe6fdbdf9f11817eb68582b2ab4aff9e9c39e986ae48f7ec576c6322db" +checksum = "be28bd063fa91fd871d131fc8b68d7cd4c5fa0869bea68daca50dcb1cbd76be2" dependencies = [ "aws-smithy-async", "aws-smithy-http", "aws-smithy-runtime-api", "aws-smithy-types", "bytes", - "fastrand 2.1.1", + "fastrand 2.2.0", "h2 0.3.26", "http 0.2.12", "http-body 0.4.6", "http-body 1.0.1", "httparse", - "hyper 0.14.30", + "hyper 0.14.31", "hyper-rustls 0.24.2", "once_cell", "pin-project-lite", @@ -683,9 +683,9 @@ dependencies = [ [[package]] name = "aws-smithy-runtime-api" -version = "1.7.2" +version = "1.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e086682a53d3aa241192aa110fa8dfce98f2f5ac2ead0de84d41582c7e8fdb96" +checksum = "92165296a47a812b267b4f41032ff8069ab7ff783696d217f0994a0d7ab585cd" dependencies = [ "aws-smithy-async", "aws-smithy-types", @@ -700,9 +700,9 @@ dependencies = [ [[package]] name = "aws-smithy-types" -version = "1.2.7" +version = "1.2.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "147100a7bea70fa20ef224a6bad700358305f5dc0f84649c53769761395b355b" +checksum = "4fbd94a32b3a7d55d3806fe27d98d3ad393050439dd05eb53ece36ec5e3d3510" dependencies = [ "base64-simd", "bytes", @@ -748,15 +748,10 @@ dependencies = [ ] [[package]] -<<<<<<< HEAD -name = "backtrace" -version = "0.3.74" -source = "registry+https://github.com/rust-lang/crates.io-index" -======= name = "axum" -version = "0.7.7" +version = "0.7.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "504e3947307ac8326a5437504c517c4b56716c9d98fac0028c2acc7ca47d70ae" +checksum = "edca88bc138befd0323b20752846e6587272d3b03b0343c8ea28a6f819e6e71f" dependencies = [ "async-trait", "axum-core", @@ -773,7 +768,7 @@ dependencies = [ "pin-project-lite", "rustversion", "serde", - "sync_wrapper 1.0.1", + "sync_wrapper 1.0.2", "tower 0.5.1", "tower-layer", "tower-service", @@ -794,7 +789,7 @@ dependencies = [ "mime", "pin-project-lite", "rustversion", - "sync_wrapper 1.0.1", + "sync_wrapper 1.0.2", "tower-layer", "tower-service", ] @@ -817,7 +812,6 @@ dependencies = [ name = "backtrace" version = "0.3.74" source = "registry+https://github.com/rust-lang/crates.io-index" ->>>>>>> main checksum = "8d82cb332cdfaed17ae235a638438ac4d4839913cc2af585c3c6746e8f8bee1a" dependencies = [ "addr2line", @@ -922,9 +916,9 @@ dependencies 
= [ [[package]] name = "bstr" -version = "1.10.0" +version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "40723b8fb387abc38f4f4a37c09073622e41dd12327033091ef8950659e6dc0c" +checksum = "1a68f1f47cdf0ec8ee4b941b2eee2a80cb796db73118c0dd09ac63fbe405be22" dependencies = [ "memchr", "serde", @@ -944,9 +938,9 @@ checksum = "5ce89b21cab1437276d2650d57e971f9d548a2d9037cc231abdc0562b97498ce" [[package]] name = "bytemuck" -version = "1.18.0" +version = "1.20.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "94bbb0ad554ad961ddc5da507a12a29b14e4ae5bda06b19f575a3e6079d2e2ae" +checksum = "8b37c88a63ffd85d15b406896cc343916d7cf57838a847b3a6f2ca5d39a5695a" [[package]] name = "byteorder" @@ -956,15 +950,12 @@ checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" [[package]] name = "bytes" -version = "1.7.2" +version = "1.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "428d9aa8fbc0670b7b8d6030a7fadd0f86151cae55e4dbbece15f3780a3dfaf3" -<<<<<<< HEAD -======= +checksum = "9ac0150caa2ae65ca5bd83f25c7de183dea78d4d366469f148435e2acfbad0da" dependencies = [ "serde", ] ->>>>>>> main [[package]] name = "bytes-utils" @@ -1009,9 +1000,9 @@ dependencies = [ [[package]] name = "cc" -version = "1.1.28" +version = "1.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2e80e3b6a3ab07840e1cae9b0666a63970dc28e8ed5ffbcdacbfc760c281bfc1" +checksum = "fd9de9f2205d5ef3fd67e685b0df337994ddd4495e2a28d185500d0e1edfea47" dependencies = [ "jobserver", "libc", @@ -1030,6 +1021,12 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" +[[package]] +name = "cfg_aliases" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724" + [[package]] name = "chrono" version = "0.4.38" @@ -1043,8 +1040,6 @@ dependencies = [ "serde", "wasm-bindgen", "windows-targets 0.52.6", -<<<<<<< HEAD -======= ] [[package]] @@ -1056,7 +1051,6 @@ dependencies = [ "chrono", "chrono-tz-build 0.2.1", "phf", ->>>>>>> main ] [[package]] @@ -1094,13 +1088,13 @@ dependencies = [ [[package]] name = "comfy-table" -version = "7.1.1" +version = "7.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b34115915337defe99b2aff5c2ce6771e5fbc4079f4b506301f5cf394c8452f7" +checksum = "24f165e7b643266ea80cb858aed492ad9280e3e05ce24d4a99d7d7b889b6a4d9" dependencies = [ "strum", "strum_macros", - "unicode-width", + "unicode-width 0.2.0", ] [[package]] @@ -1148,6 +1142,16 @@ dependencies = [ "libc", ] +[[package]] +name = "core-foundation" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b55271e5c8c478ad3f38ad24ef34923091e0548492a266d19b3c0b4d82574c63" +dependencies = [ + "core-foundation-sys", + "libc", +] + [[package]] name = "core-foundation-sys" version = "0.8.7" @@ -1156,9 +1160,9 @@ checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" [[package]] name = "cpufeatures" -version = "0.2.14" +version = "0.2.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "608697df725056feaccfa42cffdaeeec3fccc4ffc38358ecd19b243e716a78e0" +checksum = "16b80225097f2e5ae4e7179dd2266824648f3e2f49d9134d584b76389d31c4c3" dependencies = [ "libc", ] @@ -1233,9 +1237,9 @@ dependencies = [ [[package]] name = 
"csv" -version = "1.3.0" +version = "1.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ac574ff4d437a7b5ad237ef331c17ccca63c46479e5b5453eb8e10bb99a759fe" +checksum = "acdc4883a9c96732e4733212c01447ebd805833b7275a73ca3ee080fd77afdaf" dependencies = [ "csv-core", "itoa", @@ -1297,7 +1301,7 @@ dependencies = [ "proc-macro2", "quote", "strsim 0.11.1", - "syn 2.0.79", + "syn 2.0.89", ] [[package]] @@ -1319,7 +1323,7 @@ checksum = "d336a2a514f6ccccaa3e09b02d41d35330c07ddf03a62165fcec10bb561c7806" dependencies = [ "darling_core 0.20.10", "quote", - "syn 2.0.79", + "syn 2.0.89", ] [[package]] @@ -1706,11 +1710,6 @@ dependencies = [ ] [[package]] -<<<<<<< HEAD -name = "derive_more" -version = "0.99.18" -source = "registry+https://github.com/rust-lang/crates.io-index" -======= name = "derive_builder" version = "0.20.2" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -1728,7 +1727,7 @@ dependencies = [ "darling 0.20.10", "proc-macro2", "quote", - "syn 2.0.79", + "syn 2.0.89", ] [[package]] @@ -1738,21 +1737,20 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ab63b0e2bf4d5928aff72e83a7dace85d7bba5fe12dcc3c5a572d78caffd3f3c" dependencies = [ "derive_builder_core", - "syn 2.0.79", + "syn 2.0.89", ] [[package]] name = "derive_more" version = "0.99.18" source = "registry+https://github.com/rust-lang/crates.io-index" ->>>>>>> main checksum = "5f33878137e4dafd7fa914ad4e259e18a4e8e532b9617a2d0150262bf53abfce" dependencies = [ "convert_case", "proc-macro2", "quote", "rustc_version 0.4.1", - "syn 2.0.79", + "syn 2.0.89", ] [[package]] @@ -1793,6 +1791,17 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "displaydoc" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.89", +] + [[package]] name = "doc-comment" version = "0.3.3" @@ -1819,9 +1828,9 @@ checksum = "60b1af1c220855b6ceac025d3f6ecdd2b7c4894bfe9cd9bda4fbb4bc7c0d4cf0" [[package]] name = "encoding_rs" -version = "0.8.34" +version = "0.8.35" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b45de904aa0b010bce2ab45264d0631681847fa7b6f2eaa7dab7619943bc4f59" +checksum = "75030f3c4f45dafd7586dd6780965a8c7e8e285a5ecb86713e63a79c5b2766f3" dependencies = [ "cfg-if", ] @@ -1882,9 +1891,9 @@ dependencies = [ [[package]] name = "fastdivide" -version = "0.4.1" +version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "59668941c55e5c186b8b58c391629af56774ec768f73c08bbcd56f09348eb00b" +checksum = "9afc2bd4d5a73106dd53d10d73d3401c2f32730ba2c0b93ddb888a8983680471" [[package]] name = "fastrand" @@ -1897,9 +1906,9 @@ dependencies = [ [[package]] name = "fastrand" -version = "2.1.1" +version = "2.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e8c02a5121d4ea3eb16a80748c74f5549a5665e4c21333c6098f283870fbdea6" +checksum = "486f806e73c5707928240ddc295403b1b93c96a02038563881c4a2fd84b81ac4" [[package]] name = "fixedbitset" @@ -1915,19 +1924,16 @@ checksum = "8add37afff2d4ffa83bc748a70b4b1370984f6980768554182424ef71447c35f" dependencies = [ "bitflags 1.3.2", "rustc_version 0.4.1", -<<<<<<< HEAD -======= ] [[package]] name = "flate2" -version = "1.0.34" +version = "1.0.35" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a1b589b4dc103969ad3cf85c950899926ec64300a1a46d76c03a6072957036f0" 
+checksum = "c936bfdafb507ebbf50b8074c54fa31c5be9a1e7e5f467dd659697041407d07c" dependencies = [ "crc32fast", "miniz_oxide", ->>>>>>> main ] [[package]] @@ -1972,7 +1978,7 @@ version = "0.8.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f7e180ac76c23b45e767bd7ae9579bc0bb458618c4bc71835926e098e61d15f8" dependencies = [ - "rustix 0.38.37", + "rustix 0.38.41", "windows-sys 0.52.0", ] @@ -2062,7 +2068,7 @@ checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650" dependencies = [ "proc-macro2", "quote", - "syn 2.0.79", + "syn 2.0.89", ] [[package]] @@ -2175,9 +2181,9 @@ dependencies = [ [[package]] name = "h2" -version = "0.4.6" +version = "0.4.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "524e8ac6999421f49a846c2d4411f337e53497d8ec55d67753beffa43c5d9205" +checksum = "ccae279728d634d083c00f6099cb58f01cc99c145b84b8be2f6c74618d79922e" dependencies = [ "atomic-waker", "bytes", @@ -2221,9 +2227,9 @@ dependencies = [ [[package]] name = "hashbrown" -version = "0.15.0" +version = "0.15.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e087f84d4f86bf4b218b927129862374b72199ae7d8657835f1e89000eea4fb" +checksum = "3a9bfc1af68b1726ea47d3d5109de126281def866b33970e10fbab11b5dafab3" dependencies = [ "allocator-api2", "equivalent", @@ -2356,9 +2362,9 @@ checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4" [[package]] name = "hyper" -version = "0.14.30" +version = "0.14.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a152ddd61dfaec7273fe8419ab357f33aee0d914c5f4efbf0d96fa749eea5ec9" +checksum = "8c08302e8fa335b151b788c775ff56e7a03ae64ff85c548ee820fecb70356e85" dependencies = [ "bytes", "futures-channel", @@ -2380,14 +2386,14 @@ dependencies = [ [[package]] name = "hyper" -version = "1.4.1" +version = "1.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "50dfd22e0e76d0f662d429a5f80fcaf3855009297eab6a0a9f8543834744ba05" +checksum = "97818827ef4f364230e16705d4706e2897df2bb60617d6ca15d598025a3c481f" dependencies = [ "bytes", "futures-channel", "futures-util", - "h2 0.4.6", + "h2 0.4.7", "http 1.1.0", "http-body 1.0.1", "httparse", @@ -2407,7 +2413,7 @@ checksum = "ec3efd23720e2049821a693cbc7e65ea87c72f1c58ff2f9522ff332b1491e590" dependencies = [ "futures-util", "http 0.2.12", - "hyper 0.14.30", + "hyper 0.14.31", "log", "rustls 0.21.12", "rustls-native-certs 0.6.3", @@ -2423,15 +2429,15 @@ checksum = "08afdbb5c31130e3034af566421053ab03787c640246a446327f550d11bcb333" dependencies = [ "futures-util", "http 1.1.0", - "hyper 1.4.1", + "hyper 1.5.1", "hyper-util", - "rustls 0.23.14", - "rustls-native-certs 0.8.0", + "rustls 0.23.18", + "rustls-native-certs 0.8.1", "rustls-pki-types", "tokio", "tokio-rustls 0.26.0", "tower-service", - "webpki-roots 0.26.6", + "webpki-roots 0.26.7", ] [[package]] @@ -2440,7 +2446,7 @@ version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2b90d566bffbce6a75bd8b09a05aa8c2cb1fabb6cb348f8840c9e4c90a0d83b0" dependencies = [ - "hyper 1.4.1", + "hyper 1.5.1", "hyper-util", "pin-project-lite", "tokio", @@ -2454,7 +2460,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d6183ddfa99b85da61a140bea0efc93fdf56ceaa041b37d553518030827f9905" dependencies = [ "bytes", - "hyper 0.14.30", + "hyper 0.14.31", "native-tls", "tokio", "tokio-native-tls", @@ -2462,22 +2468,16 @@ dependencies = [ [[package]] name = "hyper-util" -<<<<<<< 
HEAD -version = "0.1.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "41296eb09f183ac68eec06e03cdbea2e759633d4067b2f6552fc2e009bcad08b" -======= version = "0.1.10" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "df2dcfbe0677734ab2f3ffa7fa7bfd4706bfdc1ef393f2ee30184aed67e631b4" ->>>>>>> main dependencies = [ "bytes", "futures-channel", "futures-util", "http 1.1.0", "http-body 1.0.1", - "hyper 1.4.1", + "hyper 1.5.1", "pin-project-lite", "socket2 0.5.7", "tokio", @@ -2517,6 +2517,124 @@ dependencies = [ "cc", ] +[[package]] +name = "icu_collections" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "db2fa452206ebee18c4b5c2274dbf1de17008e874b4dc4f0aea9d01ca79e4526" +dependencies = [ + "displaydoc", + "yoke", + "zerofrom", + "zerovec", +] + +[[package]] +name = "icu_locid" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "13acbb8371917fc971be86fc8057c41a64b521c184808a698c02acc242dbf637" +dependencies = [ + "displaydoc", + "litemap", + "tinystr", + "writeable", + "zerovec", +] + +[[package]] +name = "icu_locid_transform" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "01d11ac35de8e40fdeda00d9e1e9d92525f3f9d887cdd7aa81d727596788b54e" +dependencies = [ + "displaydoc", + "icu_locid", + "icu_locid_transform_data", + "icu_provider", + "tinystr", + "zerovec", +] + +[[package]] +name = "icu_locid_transform_data" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fdc8ff3388f852bede6b579ad4e978ab004f139284d7b28715f773507b946f6e" + +[[package]] +name = "icu_normalizer" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "19ce3e0da2ec68599d193c93d088142efd7f9c5d6fc9b803774855747dc6a84f" +dependencies = [ + "displaydoc", + "icu_collections", + "icu_normalizer_data", + "icu_properties", + "icu_provider", + "smallvec", + "utf16_iter", + "utf8_iter", + "write16", + "zerovec", +] + +[[package]] +name = "icu_normalizer_data" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8cafbf7aa791e9b22bec55a167906f9e1215fd475cd22adfcf660e03e989516" + +[[package]] +name = "icu_properties" +version = "1.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93d6020766cfc6302c15dbbc9c8778c37e62c14427cb7f6e601d849e092aeef5" +dependencies = [ + "displaydoc", + "icu_collections", + "icu_locid_transform", + "icu_properties_data", + "icu_provider", + "tinystr", + "zerovec", +] + +[[package]] +name = "icu_properties_data" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "67a8effbc3dd3e4ba1afa8ad918d5684b8868b3b26500753effea8d2eed19569" + +[[package]] +name = "icu_provider" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ed421c8a8ef78d3e2dbc98a973be2f3770cb42b606e3ab18d6237c4dfde68d9" +dependencies = [ + "displaydoc", + "icu_locid", + "icu_provider_macros", + "stable_deref_trait", + "tinystr", + "writeable", + "yoke", + "zerofrom", + "zerovec", +] + +[[package]] +name = "icu_provider_macros" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ec89e9337638ecdc08744df490b221a7399bf8d164eb52a665454e60e075ad6" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.89", +] + [[package]] name = "ident_case" version = 
"1.0.1" @@ -2536,12 +2654,23 @@ dependencies = [ [[package]] name = "idna" -version = "0.5.0" +version = "1.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "634d9b1461af396cad843f47fdba5597a4f9e6ddd4bfb6ff5d85028c25cb12f6" +checksum = "686f825264d630750a544639377bae737628043f20d38bbc029e8f29ea968a7e" dependencies = [ - "unicode-bidi", - "unicode-normalization", + "idna_adapter", + "smallvec", + "utf8_iter", +] + +[[package]] +name = "idna_adapter" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "daca1df1c957320b2cf139ac61e7bd64fed304c5040df000a745aa1de3b4ef71" +dependencies = [ + "icu_normalizer", + "icu_properties", ] [[package]] @@ -2578,7 +2707,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "707907fe3c25f5424cce2cb7e1cbcafee6bdbe735ca90ef77c29e84591e5b9da" dependencies = [ "equivalent", - "hashbrown 0.15.0", + "hashbrown 0.15.1", "serde", ] @@ -2649,9 +2778,9 @@ dependencies = [ [[package]] name = "itoa" -version = "1.0.11" +version = "1.0.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "49f1f14873335454500d59611f1cf4a4b0f786f9ac11f4312a78e4cf2566695b" +checksum = "540654e97a3f4470a492cd30ff187bc95d89557a903a2bbf112e2fae98104ef2" [[package]] name = "jobserver" @@ -2664,15 +2793,9 @@ dependencies = [ [[package]] name = "js-sys" -<<<<<<< HEAD -version = "0.3.70" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1868808506b929d7b0cfa8f75951347aa71bb21144b7791bae35d9bccfcfe37a" -======= -version = "0.3.71" +version = "0.3.72" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0cb94a0ffd3f3ee755c20f7d8752f45cac88605a4dcf808abcff72873296ec7b" ->>>>>>> main +checksum = "6a88f1bda2bd75b0452a14784937d796722fdebfe50df998aeb3f0b7603019a9" dependencies = [ "wasm-bindgen", ] @@ -3180,15 +3303,15 @@ dependencies = [ [[package]] name = "libc" -version = "0.2.159" +version = "0.2.164" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "561d97a539a36e26a9a5fad1ea11a3039a67714694aaa379433e580854bc3dc5" +checksum = "433bfe06b8c75da9b2e3fbea6e5329ff87748f0b144ef75306e674c3f6f7c13f" [[package]] name = "libm" -version = "0.2.8" +version = "0.2.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058" +checksum = "8355be11b20d696c8f18f6cc018c4e372165b1fa8126cef092399c9951984ffa" [[package]] name = "libredox" @@ -3218,6 +3341,12 @@ version = "0.4.14" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "78b3ae25bc7c8c38cec158d1f2757ee79e9b3740fbc7ccf0e59e4b08d793fa89" +[[package]] +name = "litemap" +version = "0.7.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ee93343901ab17bd981295f2cf0026d4ad018c7c31ba84549a4ddbb47a45104" + [[package]] name = "lock_api" version = "0.4.12" @@ -3233,8 +3362,6 @@ name = "log" version = "0.4.22" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24" -<<<<<<< HEAD -======= [[package]] name = "lopdf" @@ -3255,7 +3382,6 @@ dependencies = [ "time", "weezl", ] ->>>>>>> main [[package]] name = "lru" @@ -3263,7 +3389,7 @@ version = "0.12.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "234cf4f4a04dc1f57e24b96cc0cd600cf2af460d4161ac5ecdd0af8e1f3b2a38" dependencies = [ - "hashbrown 0.15.0", + 
"hashbrown 0.15.1", ] [[package]] @@ -3408,7 +3534,7 @@ dependencies = [ "skeptic", "smallvec", "tagptr", - "thiserror", + "thiserror 1.0.69", "triomphe", "uuid", ] @@ -3449,7 +3575,7 @@ dependencies = [ "stringprep", "strsim 0.10.0", "take_mut", - "thiserror", + "thiserror 1.0.69", "tokio", "tokio-rustls 0.24.1", "tokio-util", @@ -3484,7 +3610,7 @@ dependencies = [ "openssl-probe", "openssl-sys", "schannel", - "security-framework", + "security-framework 2.11.1", "security-framework-sys", "tempfile", ] @@ -3510,11 +3636,11 @@ dependencies = [ "rustls-native-certs 0.7.3", "rustls-pemfile 2.2.0", "serde", - "thiserror", + "thiserror 1.0.69", "tokio", "tokio-rustls 0.26.0", "url", - "webpki-roots 0.26.6", + "webpki-roots 0.26.7", ] [[package]] @@ -3524,7 +3650,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "53a0d57c55d2d1dc62a2b1d16a0a1079eb78d67c36bdf468d582ab4482ec7002" dependencies = [ "quote", - "syn 2.0.79", + "syn 2.0.89", ] [[package]] @@ -3658,14 +3784,14 @@ dependencies = [ "chrono", "futures", "humantime", - "hyper 1.4.1", + "hyper 1.5.1", "itertools 0.13.0", "md-5", "parking_lot", "percent-encoding", "quick-xml", "rand", - "reqwest 0.12.8", + "reqwest 0.12.9", "ring", "rustls-pemfile 2.2.0", "serde", @@ -3691,9 +3817,9 @@ checksum = "e296cf87e61c9cfc1a61c3c63a0f7f286ed4554e0e22be84e8a38e1d264a2a29" [[package]] name = "openssl" -version = "0.10.66" +version = "0.10.68" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9529f4786b70a3e8c61e11179af17ab6188ad8d0ded78c5529441ed39d4bd9c1" +checksum = "6174bc48f102d208783c2c84bf931bb75927a617866870de8a4ea85597f871f5" dependencies = [ "bitflags 2.6.0", "cfg-if", @@ -3712,7 +3838,7 @@ checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" dependencies = [ "proc-macro2", "quote", - "syn 2.0.79", + "syn 2.0.89", ] [[package]] @@ -3723,9 +3849,9 @@ checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf" [[package]] name = "openssl-sys" -version = "0.9.103" +version = "0.9.104" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7f9e8deee91df40a943c71b917e5874b951d32a802526c85721ce3b776c929d6" +checksum = "45abf306cbf99debc8195b66b7346498d7b10c210de50418b5ccd7ceba08c741" dependencies = [ "cc", "libc", @@ -3741,9 +3867,9 @@ checksum = "04744f49eae99ab78e0d5c0b603ab218f515ea8cfe5a456d7629ad883a3b6e7d" [[package]] name = "ordered-float" -version = "4.3.0" +version = "4.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "44d501f1a72f71d3c063a6bbc8f7271fa73aa09fe5d6283b6571e2ed176a2537" +checksum = "c65ee1f9701bf938026630b455d5315f490640234259037edb259798b3bcf85e" dependencies = [ "num-traits", ] @@ -3890,29 +4016,29 @@ dependencies = [ [[package]] name = "pin-project" -version = "1.1.6" +version = "1.1.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "baf123a161dde1e524adf36f90bc5d8d3462824a9c43553ad07a8183161189ec" +checksum = "be57f64e946e500c8ee36ef6331845d40a93055567ec57e8fae13efd33759b95" dependencies = [ "pin-project-internal", ] [[package]] name = "pin-project-internal" -version = "1.1.6" +version = "1.1.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a4502d8515ca9f32f1fb543d987f63d95a14934883db45bdb48060b6b69257f8" +checksum = "3c0f5fad0874fc7abcd4d750e76917eaebbecaa2c20bde22e1dbeeba8beb758c" dependencies = [ "proc-macro2", "quote", - "syn 2.0.79", + "syn 2.0.89", ] [[package]] name = "pin-project-lite" -version = 
"0.2.14" +version = "0.2.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bda66fc9667c18cb2758a2ac84d1167245054bcf85d5d1aaa6923f45801bdd02" +checksum = "915a1e146535de9163f3987b8944ed8cf49a18bb0056bcebcdcece385cece4ff" [[package]] name = "pin-utils" @@ -3956,8 +4082,6 @@ checksum = "77957b295656769bb8ad2b6a6b09d897d94f05c41b069aede1fcdaa675eaea04" dependencies = [ "zerocopy", ] -<<<<<<< HEAD -======= [[package]] name = "predicates" @@ -3985,23 +4109,22 @@ dependencies = [ "predicates-core", "termtree", ] ->>>>>>> main [[package]] name = "prettyplease" -version = "0.2.22" +version = "0.2.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "479cf940fbbb3426c32c5d5176f62ad57549a0bb84773423ba8be9d089f5faba" +checksum = "64d1ec885c64d0457d564db4ec299b2dae3f9c02808b8ad9c3a089c591b18033" dependencies = [ "proc-macro2", - "syn 2.0.79", + "syn 2.0.89", ] [[package]] name = "proc-macro2" -version = "1.0.87" +version = "1.0.92" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b3e4daa0dcf6feba26f985457cdf104d4b4256fc5a09547140f3631bb076b19a" +checksum = "37d3544b3f2748c54e147655edb5025752e2303145b5aefb3c3ea2c78b973bb0" dependencies = [ "unicode-ident", ] @@ -4043,7 +4166,7 @@ dependencies = [ "prost 0.12.6", "prost-types 0.12.6", "regex", - "syn 2.0.79", + "syn 2.0.89", "tempfile", ] @@ -4057,9 +4180,7 @@ dependencies = [ "itertools 0.12.1", "proc-macro2", "quote", - "syn 2.0.79", -<<<<<<< HEAD -======= + "syn 2.0.89", ] [[package]] @@ -4072,8 +4193,7 @@ dependencies = [ "itertools 0.13.0", "proc-macro2", "quote", - "syn 2.0.79", ->>>>>>> main + "syn 2.0.89", ] [[package]] @@ -4116,10 +4236,10 @@ dependencies = [ "futures-util", "prost 0.13.3", "prost-types 0.13.3", - "reqwest 0.12.8", + "reqwest 0.12.9", "serde", "serde_json", - "thiserror", + "thiserror 1.0.69", "tonic", ] @@ -4157,45 +4277,49 @@ dependencies = [ [[package]] name = "quinn" -version = "0.11.5" +version = "0.11.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8c7c5fdde3cdae7203427dc4f0a68fe0ed09833edc525a03456b153b79828684" +checksum = "62e96808277ec6f97351a2380e6c25114bc9e67037775464979f3037c92d05ef" dependencies = [ "bytes", "pin-project-lite", "quinn-proto", "quinn-udp", "rustc-hash 2.0.0", - "rustls 0.23.14", + "rustls 0.23.18", "socket2 0.5.7", - "thiserror", + "thiserror 2.0.3", "tokio", "tracing", ] [[package]] name = "quinn-proto" -version = "0.11.8" +version = "0.11.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fadfaed2cd7f389d0161bb73eeb07b7b78f8691047a6f3e73caaeae55310a4a6" +checksum = "a2fe5ef3495d7d2e377ff17b1a8ce2ee2ec2a18cde8b6ad6619d65d0701c135d" dependencies = [ "bytes", + "getrandom", "rand", "ring", "rustc-hash 2.0.0", - "rustls 0.23.14", + "rustls 0.23.18", + "rustls-pki-types", "slab", - "thiserror", + "thiserror 2.0.3", "tinyvec", "tracing", + "web-time", ] [[package]] name = "quinn-udp" -version = "0.5.5" +version = "0.5.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4fe68c2e9e1a1234e218683dbdf9f9dfcb094113c5ac2b938dfcb9bab4c4140b" +checksum = "7d5a626c6807713b15cac82a6acaccd6043c9a5408c24baae07611fec3f243da" dependencies = [ + "cfg_aliases", "libc", "once_cell", "socket2 0.5.7", @@ -4310,14 +4434,14 @@ checksum = "ba009ff324d1fc1b900bd1fdb31564febe58a8ccc8a6fdbb93b543d33b13ca43" dependencies = [ "getrandom", "libredox", - "thiserror", + "thiserror 1.0.69", ] [[package]] name = "regex" -version = "1.11.0" +version = 
"1.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38200e5ee88914975b69f657f0801b6f6dccafd44fd9326302a4aaeecfacb1d8" +checksum = "b544ef1b4eac5dc2db33ea63606ae9ffcfac26c1416a2806ae0bf5f56b201191" dependencies = [ "aho-corasick", "memchr", @@ -4327,9 +4451,9 @@ dependencies = [ [[package]] name = "regex-automata" -version = "0.4.8" +version = "0.4.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "368758f23274712b504848e9d5a6f010445cc8b87a7cdb4d7cbee666c1288da3" +checksum = "809e8dc61f6de73b46c85f4c96486310fe304c434cfa43669d7b40f711150908" dependencies = [ "aho-corasick", "memchr", @@ -4362,7 +4486,7 @@ dependencies = [ "h2 0.3.26", "http 0.2.12", "http-body 0.4.6", - "hyper 0.14.30", + "hyper 0.14.31", "hyper-tls", "ipnet", "js-sys", @@ -4390,19 +4514,19 @@ dependencies = [ [[package]] name = "reqwest" -version = "0.12.8" +version = "0.12.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f713147fbe92361e52392c73b8c9e48c04c6625bce969ef54dc901e58e042a7b" +checksum = "a77c62af46e79de0a562e1a9849205ffcb7fc1238876e9bd743357570e04046f" dependencies = [ "base64 0.22.1", "bytes", "futures-core", "futures-util", - "h2 0.4.6", + "h2 0.4.7", "http 1.1.0", "http-body 1.0.1", "http-body-util", - "hyper 1.4.1", + "hyper 1.5.1", "hyper-rustls 0.27.3", "hyper-util", "ipnet", @@ -4413,14 +4537,14 @@ dependencies = [ "percent-encoding", "pin-project-lite", "quinn", - "rustls 0.23.14", - "rustls-native-certs 0.8.0", + "rustls 0.23.18", + "rustls-native-certs 0.8.1", "rustls-pemfile 2.2.0", "rustls-pki-types", "serde", "serde_json", "serde_urlencoded", - "sync_wrapper 1.0.1", + "sync_wrapper 1.0.2", "tokio", "tokio-rustls 0.26.0", "tokio-util", @@ -4430,10 +4554,7 @@ dependencies = [ "wasm-bindgen-futures", "wasm-streams", "web-sys", -<<<<<<< HEAD -======= - "webpki-roots 0.26.6", ->>>>>>> main + "webpki-roots 0.26.7", "windows-registry", ] @@ -4468,7 +4589,7 @@ dependencies = [ "schemars", "serde", "serde_json", - "thiserror", + "thiserror 1.0.69", "tokio", "tokio-test", "tracing", @@ -4482,7 +4603,7 @@ dependencies = [ "indoc", "proc-macro2", "quote", - "syn 2.0.79", + "syn 2.0.89", ] [[package]] @@ -4638,9 +4759,9 @@ dependencies = [ [[package]] name = "rustix" -version = "0.38.37" +version = "0.38.41" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8acb788b847c24f28525660c4d7758620a7210875711f79e7f663cc152726811" +checksum = "d7f649912bc1495e167a6edee79151c84b1bad49748cb4f1f1167f459f6224f6" dependencies = [ "bitflags 2.6.0", "errno", @@ -4663,9 +4784,9 @@ dependencies = [ [[package]] name = "rustls" -version = "0.23.14" +version = "0.23.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "415d9944693cb90382053259f89fbb077ea730ad7273047ec63b19bc9b160ba8" +checksum = "9c9cc1d47e243d655ace55ed38201c19ae02c148ae56412ab8750e8f0166ab7f" dependencies = [ "log", "once_cell", @@ -4685,13 +4806,11 @@ dependencies = [ "openssl-probe", "rustls-pemfile 1.0.4", "schannel", - "security-framework", + "security-framework 2.11.1", ] [[package]] name = "rustls-native-certs" -<<<<<<< HEAD -======= version = "0.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e5bfb394eeed242e909609f56089eecfe5fda225042e8b171791b9c95f5931e5" @@ -4700,21 +4819,19 @@ dependencies = [ "rustls-pemfile 2.2.0", "rustls-pki-types", "schannel", - "security-framework", + "security-framework 2.11.1", ] [[package]] name = "rustls-native-certs" ->>>>>>> main -version = 
"0.8.0" +version = "0.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fcaf18a4f2be7326cd874a5fa579fae794320a0f388d365dca7e480e55f83f8a" +checksum = "7fcff2dd52b58a8d98a70243663a0d234c4e2b79235637849d15913394a247d3" dependencies = [ "openssl-probe", - "rustls-pemfile 2.2.0", "rustls-pki-types", "schannel", - "security-framework", + "security-framework 3.0.1", ] [[package]] @@ -4737,9 +4854,12 @@ dependencies = [ [[package]] name = "rustls-pki-types" -version = "1.9.0" +version = "1.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0e696e35370c65c9c541198af4543ccd580cf17fc25d8e05c5a242b202488c55" +checksum = "16f1201b3c9a7ee8039bcadc17b7e605e2945b27eee7631788c1bd2b0643674b" +dependencies = [ + "web-time", +] [[package]] name = "rustls-webpki" @@ -4764,9 +4884,9 @@ dependencies = [ [[package]] name = "rustversion" -version = "1.0.17" +version = "1.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "955d28af4278de8121b7ebeb796b6a45735dc01436d898801014aced2773a3d6" +checksum = "0e819f2bc632f285be6d7cd36e25940d45b2391dd6d9b939e79de557f7014248" [[package]] name = "ryu" @@ -4785,9 +4905,9 @@ dependencies = [ [[package]] name = "schannel" -version = "0.1.26" +version = "0.1.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "01227be5826fa0690321a2ba6c5cd57a19cf3f6a09e76973b58e61de6ab9d1c1" +checksum = "1f29ebaa345f945cec9fbbc532eb307f0fdad8161f281b6369539c8d84876b3d" dependencies = [ "windows-sys 0.59.0", ] @@ -4822,7 +4942,7 @@ dependencies = [ "proc-macro2", "quote", "serde_derive_internals", - "syn 2.0.79", + "syn 2.0.89", ] [[package]] @@ -4848,7 +4968,20 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "897b2245f0b511c87893af39b033e5ca9cce68824c4d7e7630b5a1d339658d02" dependencies = [ "bitflags 2.6.0", - "core-foundation", + "core-foundation 0.9.4", + "core-foundation-sys", + "libc", + "security-framework-sys", +] + +[[package]] +name = "security-framework" +version = "3.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e1415a607e92bec364ea2cf9264646dcce0f91e6d65281bd6f2819cca3bf39c8" +dependencies = [ + "bitflags 2.6.0", + "core-foundation 0.10.0", "core-foundation-sys", "libc", "security-framework-sys", @@ -4856,9 +4989,9 @@ dependencies = [ [[package]] name = "security-framework-sys" -version = "2.12.0" +version = "2.12.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ea4a292869320c0272d7bc55a5a6aafaff59b4f63404a003887b679a2e05b4b6" +checksum = "fa39c7303dc58b5543c94d22c1766b0d31f2ee58306363ea622b10bbc075eaa2" dependencies = [ "core-foundation-sys", "libc", @@ -4890,9 +5023,9 @@ checksum = "388a1df253eca08550bef6c72392cfe7c30914bf41df5269b68cbd6ff8f570a3" [[package]] name = "serde" -version = "1.0.210" +version = "1.0.215" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c8e3592472072e6e22e0a54d5904d9febf8508f65fb8552499a1abc7d1078c3a" +checksum = "6513c1ad0b11a9376da888e3e0baa0077f1aed55c17f50e7b2397136129fb88f" dependencies = [ "serde_derive", ] @@ -4908,13 +5041,13 @@ dependencies = [ [[package]] name = "serde_derive" -version = "1.0.210" +version = "1.0.215" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "243902eda00fad750862fc144cea25caca5e20d615af0a81bee94ca738f1df1f" +checksum = "ad1e866f866923f252f05c889987993144fb74e722403468a4ebd70c3cd756c0" dependencies = [ "proc-macro2", "quote", - "syn 2.0.79", + 
"syn 2.0.89", ] [[package]] @@ -4925,14 +5058,14 @@ checksum = "18d26a20a969b9e3fdf2fc2d9f21eda6c40e2de84c9408bb5d3b05d499aae711" dependencies = [ "proc-macro2", "quote", - "syn 2.0.79", + "syn 2.0.89", ] [[package]] name = "serde_json" -version = "1.0.128" +version = "1.0.133" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6ff5456707a1de34e7e37f2a6fd3d3f808c318259cbd01ab6377795054b483d8" +checksum = "c7fceb2473b9166b2294ef05efcb65a3db80803f0b03ef86a5fc88a2b85ee377" dependencies = [ "indexmap 2.6.0", "itoa", @@ -5002,7 +5135,7 @@ dependencies = [ "darling 0.20.10", "proc-macro2", "quote", - "syn 2.0.79", + "syn 2.0.89", ] [[package]] @@ -5177,7 +5310,7 @@ checksum = "01b2e185515564f15375f593fb966b5718bc624ba77fe49fa4616ad619690554" dependencies = [ "proc-macro2", "quote", - "syn 2.0.79", + "syn 2.0.89", ] [[package]] @@ -5246,7 +5379,7 @@ dependencies = [ "proc-macro2", "quote", "rustversion", - "syn 2.0.79", + "syn 2.0.89", ] [[package]] @@ -5268,9 +5401,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.79" +version = "2.0.89" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "89132cd0bf050864e1d38dc3bbc07a0eb8e7530af26344d3d2bbbef83499f590" +checksum = "44d46482f1c1c87acd84dea20c1bf5ebff4c757009ed6bf19cfd36fb10e92c4e" dependencies = [ "proc-macro2", "quote", @@ -5285,13 +5418,24 @@ checksum = "2047c6ded9c721764247e62cd3b03c09ffc529b2ba5b10ec482ae507a4a70160" [[package]] name = "sync_wrapper" -version = "1.0.1" +version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a7065abeca94b6a8a577f9bd45aa0867a2238b74e8eb67cf10d492bc39351394" +checksum = "0bf256ce5efdfa370213c1dabab5935a12e49f2c58d15e9eac2870d3b4f27263" dependencies = [ "futures-core", ] +[[package]] +name = "synstructure" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c8af7666ab7b6390ab78131fb5b0fce11d6b7a6951602017c35fa82800708971" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.89", +] + [[package]] name = "system-configuration" version = "0.5.1" @@ -5299,7 +5443,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ba3a3adc5c275d719af8cb4272ea1c4a6d668a777f37e115f6d11ddbc1c8e0e7" dependencies = [ "bitflags 1.3.2", - "core-foundation", + "core-foundation 0.9.4", "system-configuration-sys", ] @@ -5370,7 +5514,7 @@ dependencies = [ "tantivy-stacker", "tantivy-tokenizer-api", "tempfile", - "thiserror", + "thiserror 1.0.69", "time", "uuid", "winapi", @@ -5474,17 +5618,15 @@ checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369" [[package]] name = "tempfile" -version = "3.13.0" +version = "3.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f0f2c9fc62d0beef6951ccffd757e241266a2c833136efbe35af6cd2567dca5b" +checksum = "28cce251fcbc87fac86a866eeb0d6c2d536fc16d06f184bb61aeae11aa4cee0c" dependencies = [ "cfg-if", - "fastrand 2.1.1", + "fastrand 2.2.0", "once_cell", - "rustix 0.38.37", + "rustix 0.38.41", "windows-sys 0.59.0", -<<<<<<< HEAD -======= ] [[package]] @@ -5511,40 +5653,47 @@ checksum = "23d434d3f8967a09480fb04132ebe0a3e088c173e6d0ee7897abbdf4eab0f8b9" dependencies = [ "smawk", "unicode-linebreak", - "unicode-width", ->>>>>>> main + "unicode-width 0.1.14", ] [[package]] name = "thiserror" -<<<<<<< HEAD -version = "1.0.64" +version = "1.0.69" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"d50af8abc119fb8bb6dbabcfa89656f46f84aa0ac7688088608076ad2b459a84" -======= -version = "1.0.65" +checksum = "b6aaf5339b578ea85b50e080feb250a3e8ae8cfcdff9a461c9ec2904bc923f52" +dependencies = [ + "thiserror-impl 1.0.69", +] + +[[package]] +name = "thiserror" +version = "2.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5d11abd9594d9b38965ef50805c5e469ca9cc6f197f883f717e0269a3057b3d5" ->>>>>>> main +checksum = "c006c85c7651b3cf2ada4584faa36773bd07bac24acfb39f3c431b36d7e667aa" dependencies = [ - "thiserror-impl", + "thiserror-impl 2.0.3", ] [[package]] name = "thiserror-impl" -<<<<<<< HEAD -version = "1.0.64" +version = "1.0.69" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "08904e7672f5eb876eaaf87e0ce17857500934f4981c4a0ab2b4aa98baac7fc3" -======= -version = "1.0.65" +checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.89", +] + +[[package]] +name = "thiserror-impl" +version = "2.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ae71770322cbd277e69d762a16c444af02aa0575ac0d174f0b9562d3b37f8602" ->>>>>>> main +checksum = "f077553d607adc1caf65430528a576c757a71ed73944b66ebb58ef2bbd243568" dependencies = [ "proc-macro2", "quote", - "syn 2.0.79", + "syn 2.0.89", ] [[package]] @@ -5597,6 +5746,16 @@ dependencies = [ "crunchy", ] +[[package]] +name = "tinystr" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9117f5d4db391c1cf6927e7bea3db74b9a1c1add8f7eda9ffd5364f40f57b82f" +dependencies = [ + "displaydoc", + "zerovec", +] + [[package]] name = "tinyvec" version = "1.8.0" @@ -5614,9 +5773,9 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokio" -version = "1.40.0" +version = "1.41.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e2b070231665d27ad9ec9b8df639893f46727666c6767db40317fbe920a5d998" +checksum = "22cfb5bee7a6a52939ca9224d6ac897bb669134078daa8735560897f69de4d33" dependencies = [ "backtrace", "bytes", @@ -5638,7 +5797,7 @@ checksum = "693d596312e88961bc67d7f1f97af8a70227d9f90c31bba5806eec004978d752" dependencies = [ "proc-macro2", "quote", - "syn 2.0.79", + "syn 2.0.89", ] [[package]] @@ -5667,7 +5826,7 @@ version = "0.26.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0c7bc40d0e5a97695bb96e27995cd3a08538541b0a846f65bba7a359f36700d4" dependencies = [ - "rustls 0.23.14", + "rustls 0.23.18", "rustls-pki-types", "tokio", ] @@ -5684,7 +5843,6 @@ dependencies = [ ] [[package]] -<<<<<<< HEAD name = "tokio-test" version = "0.4.4" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -5701,11 +5859,6 @@ dependencies = [ name = "tokio-util" version = "0.7.12" source = "registry+https://github.com/rust-lang/crates.io-index" -======= -name = "tokio-util" -version = "0.7.12" -source = "registry+https://github.com/rust-lang/crates.io-index" ->>>>>>> main checksum = "61e7c3654c13bcd040d4a03abee2c75b1d14a37b423cf5a813ceae1cc903ec6a" dependencies = [ "bytes", @@ -5717,11 +5870,6 @@ dependencies = [ ] [[package]] -<<<<<<< HEAD -name = "tower-service" -version = "0.3.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -======= name = "tonic" version = "0.12.3" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -5733,17 +5881,17 @@ dependencies = [ "base64 0.22.1", "bytes", "flate2", - "h2 0.4.6", + "h2 0.4.7", "http 
1.1.0", "http-body 1.0.1", "http-body-util", - "hyper 1.4.1", + "hyper 1.5.1", "hyper-timeout", "hyper-util", "percent-encoding", "pin-project", "prost 0.13.3", - "rustls-native-certs 0.8.0", + "rustls-native-certs 0.8.1", "rustls-pemfile 2.2.0", "socket2 0.5.7", "tokio", @@ -5799,7 +5947,6 @@ checksum = "121c2a6cda46980bb0fcd1647ffaf6cd3fc79a013de288782836f6df9c48780e" name = "tower-service" version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" ->>>>>>> main checksum = "8df9b6e13f2d32c91b9bd719c00d1958837bc7dec474d94952798cc8e69eeec3" [[package]] @@ -5821,7 +5968,7 @@ checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.79", + "syn 2.0.89", ] [[package]] @@ -5884,7 +6031,7 @@ dependencies = [ "log", "rand", "smallvec", - "thiserror", + "thiserror 1.0.69", "tinyvec", "tokio", "url", @@ -5905,7 +6052,7 @@ dependencies = [ "parking_lot", "resolv-conf", "smallvec", - "thiserror", + "thiserror 1.0.69", "tokio", "trust-dns-proto", ] @@ -5945,12 +6092,9 @@ checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825" [[package]] name = "unicase" -version = "2.7.0" +version = "2.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f7d2d4dafb69621809a81864c9c1b864479e1235c0dd4e199924b9742439ed89" -dependencies = [ - "version_check", -] +checksum = "7e51b68083f157f853b6379db119d1c1be0e6e4dec98101079dec41f6f5cf6df" [[package]] name = "unicode-bidi" @@ -5960,18 +6104,15 @@ checksum = "5ab17db44d7388991a428b2ee655ce0c212e862eff1768a455c58f9aad6e7893" [[package]] name = "unicode-ident" -version = "1.0.13" +version = "1.0.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e91b56cd4cadaeb79bbf1a5645f6b4f8dc5bde8834ad5894a8db35fda9efa1fe" -<<<<<<< HEAD -======= +checksum = "adb9e6ca4f869e1180728b7950e35922a7fc6397f7b641499e8f3ef06e50dc83" [[package]] name = "unicode-linebreak" version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3b09c83c3c29d37506a3e260c08c03743a6bb66a9cd432c6934ab501a190571f" ->>>>>>> main [[package]] name = "unicode-normalization" @@ -6000,6 +6141,12 @@ version = "0.1.14" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7dd6e30e90baa6f72411720665d41d89b9a3d039dc45b8faea1ddd07f617f6af" +[[package]] +name = "unicode-width" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fc81956842c57dac11422a97c3b8195a1ff727f06e85c84ed2e8aa277c9a0fd" + [[package]] name = "untrusted" version = "0.9.0" @@ -6008,12 +6155,12 @@ checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" [[package]] name = "url" -version = "2.5.2" +version = "2.5.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "22784dbdf76fdde8af1aeda5622b546b422b6fc585325248a2bf9f5e41e94d6c" +checksum = "32f8b686cadd1473f4bd0117a5d28d36b1ade384ea9b5069a1c40aefed7fda60" dependencies = [ "form_urlencoded", - "idna 0.5.0", + "idna 1.0.3", "percent-encoding", ] @@ -6023,17 +6170,29 @@ version = "2.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "daf8dba3b7eb870caf1ddeed7bc9d2a049f3cfdfae7cb521b087cc33ae4c49da" +[[package]] +name = "utf16_iter" +version = "1.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c8232dd3cdaed5356e0f716d285e4b40b932ac434100fe9b7e0e8e935b9e6246" + [[package]] name = "utf8-ranges" version = "1.0.5" source = 
"registry+https://github.com/rust-lang/crates.io-index" checksum = "7fcfc827f90e53a02eaef5e535ee14266c1d569214c6aa70133a624d8a3164ba" +[[package]] +name = "utf8_iter" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be" + [[package]] name = "uuid" -version = "1.10.0" +version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "81dfa00651efa65069b0b6b651f4aaa31ba9e3c3ce0137aaad053604ee7e0314" +checksum = "f8c5f0a0af699448548ad1a2fbf920fb4bee257eae39953ba95cb84891a0446a" dependencies = [ "getrandom", "serde", @@ -6096,15 +6255,9 @@ checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" [[package]] name = "wasm-bindgen" -<<<<<<< HEAD -version = "0.2.93" +version = "0.2.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a82edfc16a6c469f5f44dc7b571814045d60404b55a0ee849f9bcfa2e63dd9b5" -======= -version = "0.2.94" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ef073ced962d62984fb38a36e5fdc1a2b23c9e0e1fa0689bb97afa4202ef6887" ->>>>>>> main +checksum = "128d1e363af62632b8eb57219c8fd7877144af57558fb2ef0368d0087bddeb2e" dependencies = [ "cfg-if", "once_cell", @@ -6113,36 +6266,24 @@ dependencies = [ [[package]] name = "wasm-bindgen-backend" -<<<<<<< HEAD -version = "0.2.93" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9de396da306523044d3302746f1208fa71d7532227f15e347e2d93e4145dd77b" -======= -version = "0.2.94" +version = "0.2.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c4bfab14ef75323f4eb75fa52ee0a3fb59611977fd3240da19b2cf36ff85030e" ->>>>>>> main +checksum = "cb6dd4d3ca0ddffd1dd1c9c04f94b868c37ff5fac97c30b97cff2d74fce3a358" dependencies = [ "bumpalo", "log", "once_cell", "proc-macro2", "quote", - "syn 2.0.79", + "syn 2.0.89", "wasm-bindgen-shared", ] [[package]] name = "wasm-bindgen-futures" -<<<<<<< HEAD -version = "0.4.43" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "61e9300f63a621e96ed275155c108eb6f843b6a26d053f122ab69724559dc8ed" -======= -version = "0.4.44" +version = "0.4.45" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "65471f79c1022ffa5291d33520cbbb53b7687b01c2f8e83b57d102eed7ed479d" ->>>>>>> main +checksum = "cc7ec4f8827a71586374db3e87abdb5a2bb3a15afed140221307c3ec06b1f63b" dependencies = [ "cfg-if", "js-sys", @@ -6152,15 +6293,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro" -<<<<<<< HEAD -version = "0.2.93" +version = "0.2.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "585c4c91a46b072c92e908d99cb1dcdf95c5218eeb6f3bf1efa991ee7a68cccf" -======= -version = "0.2.94" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a7bec9830f60924d9ceb3ef99d55c155be8afa76954edffbb5936ff4509474e7" ->>>>>>> main +checksum = "e79384be7f8f5a9dd5d7167216f022090cf1f9ec128e6e6a482a2cb5c5422c56" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -6168,40 +6303,28 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -<<<<<<< HEAD -version = "0.2.93" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "afc340c74d9005395cf9dd098506f7f44e38f2b4a21c6aaacf9a105ea5e1e836" -======= -version = "0.2.94" +version = "0.2.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"4c74f6e152a76a2ad448e223b0fc0b6b5747649c3d769cc6bf45737bf97d0ed6" ->>>>>>> main +checksum = "26c6ab57572f7a24a4985830b120de1594465e5d500f24afe89e16b4e833ef68" dependencies = [ "proc-macro2", "quote", - "syn 2.0.79", + "syn 2.0.89", "wasm-bindgen-backend", "wasm-bindgen-shared", ] [[package]] name = "wasm-bindgen-shared" -<<<<<<< HEAD -version = "0.2.93" +version = "0.2.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c62a0a307cb4a311d3a07867860911ca130c3494e8c2719593806c08bc5d0484" -======= -version = "0.2.94" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a42f6c679374623f295a8623adfe63d9284091245c3504bde47c17a3ce2777d9" ->>>>>>> main +checksum = "65fc09f10666a9f147042251e0dda9c18f166ff7de300607007e96bdebc1068d" [[package]] name = "wasm-streams" -version = "0.4.1" +version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4e072d4e72f700fb3443d8fe94a39315df013eef1104903cdb0a2abd322bbecd" +checksum = "15053d8d85c7eccdbefef60f06769760a563c7f0a9d6902a13d35c7800b0ad65" dependencies = [ "futures-util", "js-sys", @@ -6212,15 +6335,19 @@ dependencies = [ [[package]] name = "web-sys" -<<<<<<< HEAD -version = "0.3.70" +version = "0.3.72" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "26fdeaafd9bd129f65e7c031593c24d62186301e0c72c8978fa1678be7d532c0" -======= -version = "0.3.71" +checksum = "f6488b90108c040df0fe62fa815cbdee25124641df01814dd7282749234c6112" +dependencies = [ + "js-sys", + "wasm-bindgen", +] + +[[package]] +name = "web-time" +version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "44188d185b5bdcae1052d08bcbcf9091a5524038d4572cc4f4f2bb9d5554ddd9" ->>>>>>> main +checksum = "5a6580f308b1fad9207618087a65c04e7a10bc77e02c8e84e9b00dd4b12fa0bb" dependencies = [ "js-sys", "wasm-bindgen", @@ -6234,9 +6361,9 @@ checksum = "5f20c57d8d7db6d3b86154206ae5d8fba62dd39573114de97c2cb0578251f8e1" [[package]] name = "webpki-roots" -version = "0.26.6" +version = "0.26.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "841c67bff177718f1d4dfefde8d8f0e78f9b6589319ba88312f567fc5841a958" +checksum = "5d642ff16b7e79272ae451b7322067cdc17cadf68c23264be9d94a32319efe7e" dependencies = [ "rustls-pki-types", ] @@ -6481,6 +6608,18 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "write16" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d1890f4022759daae28ed4fe62859b1236caebfc61ede2f63ed4e695f3f6d936" + +[[package]] +name = "writeable" +version = "0.5.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e9df38ee2d2c3c5948ea468a8406ff0db0b29ae1ffde1bcf20ef305bcc95c51" + [[package]] name = "wyz" version = "0.5.1" @@ -6496,6 +6635,30 @@ version = "0.13.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "66fee0b777b0f5ac1c69bb06d361268faafa61cd4682ae064a171c16c433e9e4" +[[package]] +name = "yoke" +version = "0.7.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "120e6aef9aa629e3d4f52dc8cc43a015c7724194c97dfaf45180d2daf2b77f40" +dependencies = [ + "serde", + "stable_deref_trait", + "yoke-derive", + "zerofrom", +] + +[[package]] +name = "yoke-derive" +version = "0.7.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2380878cad4ac9aac1e2435f3eb4020e8374b5f13c296cb75b4620ff8e229154" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.89", + 
"synstructure", +] + [[package]] name = "zerocopy" version = "0.7.35" @@ -6514,7 +6677,28 @@ checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e" dependencies = [ "proc-macro2", "quote", - "syn 2.0.79", + "syn 2.0.89", +] + +[[package]] +name = "zerofrom" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cff3ee08c995dee1859d998dea82f7374f2826091dd9cd47def953cae446cd2e" +dependencies = [ + "zerofrom-derive", +] + +[[package]] +name = "zerofrom-derive" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "595eed982f7d355beb85837f651fa22e90b3c044842dc7f2c2842c086f295808" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.89", + "synstructure", ] [[package]] @@ -6523,6 +6707,28 @@ version = "1.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ced3678a2879b30306d323f4542626697a464a97c0a07c9aebf7ebca65cd4dde" +[[package]] +name = "zerovec" +version = "0.10.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aa2b893d79df23bfb12d5461018d408ea19dfafe76c2c7ef6d4eba614f8ff079" +dependencies = [ + "yoke", + "zerofrom", + "zerovec-derive", +] + +[[package]] +name = "zerovec-derive" +version = "0.10.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6eafa6dfb17584ea3e2bd6e76e0cc15ad7af12b09abdd1ca55961bed9b1063c6" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.89", +] + [[package]] name = "zstd" version = "0.13.2" diff --git a/rig-core/Cargo.toml b/rig-core/Cargo.toml index aa1fda65..290e50ac 100644 --- a/rig-core/Cargo.toml +++ b/rig-core/Cargo.toml @@ -54,3 +54,11 @@ required-features = ["derive"] [[example]] name = "vector_search_cohere" required-features = ["derive"] + +[[example]] +name = "gemini_embeddings" +required-features = ["derive"] + +[[example]] +name = "xai_embeddings" +required-features = ["derive"] diff --git a/rig-core/examples/gemini_embeddings.rs b/rig-core/examples/gemini_embeddings.rs index 4ce24636..a8d9f3a6 100644 --- a/rig-core/examples/gemini_embeddings.rs +++ b/rig-core/examples/gemini_embeddings.rs @@ -1,4 +1,12 @@ use rig::providers::gemini; +use rig::Embed; + +#[derive(Embed, Debug)] +struct Greetings { + id: String, + #[embed] + message: String, +} #[tokio::main] async fn main() -> Result<(), anyhow::Error> { @@ -8,8 +16,14 @@ async fn main() -> Result<(), anyhow::Error> { let embeddings = client .embeddings(gemini::embedding::EMBEDDING_001) - .simple_document("doc0", "Hello, world!") - .simple_document("doc1", "Goodbye, world!") + .document(Greetings { + id: "doc0".to_string(), + message: "Hello, world!".to_string(), + })? + .document(Greetings { + id: "doc1".to_string(), + message: "Goodbye, world!".to_string(), + })? 
.build() .await .expect("Failed to embed documents"); diff --git a/rig-core/examples/xai_embeddings.rs b/rig-core/examples/xai_embeddings.rs index ba24a9b0..09c39796 100644 --- a/rig-core/examples/xai_embeddings.rs +++ b/rig-core/examples/xai_embeddings.rs @@ -1,4 +1,12 @@ use rig::providers::xai; +use rig::Embed; + +#[derive(Embed, Debug)] +struct Greetings { + id: String, + #[embed] + message: String, +} #[tokio::main] async fn main() -> Result<(), anyhow::Error> { @@ -7,8 +15,14 @@ async fn main() -> Result<(), anyhow::Error> { let embeddings = client .embeddings(xai::embedding::EMBEDDING_V1) - .simple_document("doc0", "Hello, world!") - .simple_document("doc1", "Goodbye, world!") + .document(Greetings { + id: "doc0".to_string(), + message: "Hello, world!".to_string(), + })? + .document(Greetings { + id: "doc1".to_string(), + message: "Goodbye, world!".to_string(), + })? .build() .await .expect("Failed to embed documents"); diff --git a/rig-core/src/embeddings/embed.rs b/rig-core/src/embeddings/embed.rs index 2fd26c2a..6b30bbc0 100644 --- a/rig-core/src/embeddings/embed.rs +++ b/rig-core/src/embeddings/embed.rs @@ -1,4 +1,13 @@ -//! The module defines the [Embed] trait, which must be implemented for types that can be embedded by the `EmbeddingsBuilder`. +//! The module defines the [Embed] trait, which must be implemented for types +//! that can be embedded by the [crate::embeddings::EmbeddingsBuilder]. +//! +//! The module also defines the [EmbedError] struct which is used for when the `embed` +//! method of the `Embed` trait fails. +//! +//! The module also defines the [TextEmbedder] struct which accumulates string values that need to be embedded. +//! It is used directly with the `Embed` trait. +//! +//! Finally, the module implements [Embed] for many common primitive types. /// Error type used for when the `embed` method fo the `Embed` trait fails. /// Used by default implementations of `Embed` for common types. @@ -20,7 +29,7 @@ impl EmbedError { /// use std::env; /// /// use serde::{Deserialize, Serialize}; -/// use rig::Embed; +/// use rig::{Embed, embeddings::{TextEmbedder, EmbedError, to_texts}}; /// /// struct FakeDefinition { /// id: String, @@ -42,6 +51,14 @@ impl EmbedError { /// Ok(()) /// } /// } +/// +/// let fake_definition = FakeDefinition { +/// id: "1".to_string(), +/// word: "apple".to_string(), +/// definitions: "a fruit, a tech company".to_string(), +/// }; +/// +/// assert_eq!(to_texts(fake_definition).unwrap(), vec!["a fruit", " a tech company"]); /// ``` pub trait Embed { fn embed(&self, embedder: &mut TextEmbedder) -> Result<(), EmbedError>; @@ -55,6 +72,7 @@ pub struct TextEmbedder { } impl TextEmbedder { + /// Adds input `text` string to the list of texts in the `TextEmbedder` that need to be embedded. pub fn embed(&mut self, text: String) { self.texts.push(text); } diff --git a/rig-core/src/embeddings/embedding.rs b/rig-core/src/embeddings/embedding.rs index d033f57e..a7d5f89f 100644 --- a/rig-core/src/embeddings/embedding.rs +++ b/rig-core/src/embeddings/embedding.rs @@ -1,7 +1,5 @@ //! The module defines the [EmbeddingModel] trait, which represents an embedding model that can -//! generate embeddings for documents. It also provides an implementation of the [crate::embeddings::EmbeddingsBuilder] -//! struct, which allows users to build collections of document embeddings using different embedding -//! models and document sources. +//! generate embeddings for documents. //! //! 
The module also defines the [Embedding] struct, which represents a single document embedding. //! diff --git a/rig-core/src/embeddings/tool.rs b/rig-core/src/embeddings/tool.rs index bcea7c7d..0b45c66e 100644 --- a/rig-core/src/embeddings/tool.rs +++ b/rig-core/src/embeddings/tool.rs @@ -1,3 +1,5 @@ +//! The module defines the [ToolSchema] struct, which is used to embed an object that implements [crate::tool::Tool] + use crate::{tool::ToolEmbeddingDyn, Embed}; use serde::Serialize; diff --git a/rig-core/src/lib.rs b/rig-core/src/lib.rs index 29412d25..4fe84a6b 100644 --- a/rig-core/src/lib.rs +++ b/rig-core/src/lib.rs @@ -84,8 +84,8 @@ pub mod completion; pub mod embeddings; pub mod extractor; pub(crate) mod json_utils; -pub mod one_or_many; pub mod loaders; +pub mod one_or_many; pub mod providers; pub mod tool; pub mod vector_store; diff --git a/rig-core/src/providers/gemini/client.rs b/rig-core/src/providers/gemini/client.rs index c22e6996..04316dfe 100644 --- a/rig-core/src/providers/gemini/client.rs +++ b/rig-core/src/providers/gemini/client.rs @@ -2,6 +2,7 @@ use crate::{ agent::AgentBuilder, embeddings::{self}, extractor::ExtractorBuilder, + Embed, }; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; @@ -104,7 +105,10 @@ impl Client { /// .await /// .expect("Failed to embed documents"); /// ``` - pub fn embeddings(&self, model: &str) -> embeddings::EmbeddingsBuilder<EmbeddingModel> { + pub fn embeddings<D: Embed>( + &self, + model: &str, + ) -> embeddings::EmbeddingsBuilder<EmbeddingModel, D> { embeddings::EmbeddingsBuilder::new(self.embedding_model(model)) } diff --git a/rig-core/src/providers/gemini/embedding.rs b/rig-core/src/providers/gemini/embedding.rs index 1249387a..c2b76e02 100644 --- a/rig-core/src/providers/gemini/embedding.rs +++ b/rig-core/src/providers/gemini/embedding.rs @@ -41,7 +41,7 @@ impl embeddings::EmbeddingModel for EmbeddingModel { } } - async fn embed_documents( + async fn embed_texts( &self, documents: impl IntoIterator<Item = String> + Send, ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> { diff --git a/rig-core/src/providers/xai/client.rs b/rig-core/src/providers/xai/client.rs index e03c6978..6af7cd31 100644 --- a/rig-core/src/providers/xai/client.rs +++ b/rig-core/src/providers/xai/client.rs @@ -2,6 +2,7 @@ use crate::{ agent::AgentBuilder, embeddings::{self}, extractor::ExtractorBuilder, + Embed, }; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; @@ -113,7 +114,10 @@ impl Client { /// .await /// .expect("Failed to embed documents"); /// ``` - pub fn embeddings(&self, model: &str) -> embeddings::EmbeddingsBuilder<EmbeddingModel> { + pub fn embeddings<D: Embed>( + &self, + model: &str, + ) -> embeddings::EmbeddingsBuilder<EmbeddingModel, D> { embeddings::EmbeddingsBuilder::new(self.embedding_model(model)) } diff --git a/rig-core/src/providers/xai/embedding.rs b/rig-core/src/providers/xai/embedding.rs index 1c588071..f0ad1c92 100644 --- a/rig-core/src/providers/xai/embedding.rs +++ b/rig-core/src/providers/xai/embedding.rs @@ -69,7 +69,7 @@ impl embeddings::EmbeddingModel for EmbeddingModel { self.ndims } - async fn embed_documents( + async fn embed_texts( &self, documents: impl IntoIterator<Item = String>, ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> { diff --git a/rig-core/src/vector_store/in_memory_store.rs b/rig-core/src/vector_store/in_memory_store.rs index f2112118..931946eb 100644 --- a/rig-core/src/vector_store/in_memory_store.rs +++ b/rig-core/src/vector_store/in_memory_store.rs @@ -181,8 +181,10 @@ 
impl<M: EmbeddingModel + Sync, D: Serialize + Sync + Send + Eq> VectorStoreIndex Ok(( distance.0, id.clone(), - serde_json::from_value(doc.clone()) - .map_err(VectorStoreError::JsonError)?, + serde_json::from_str( + &serde_json::to_string(doc).map_err(VectorStoreError::JsonError)?, + ) + .map_err(VectorStoreError::JsonError)?, )) }) .collect::<Result<Vec<_>, _>>() diff --git a/rig-lancedb/src/lib.rs b/rig-lancedb/src/lib.rs index bc0fc7fc..7567bc60 100644 --- a/rig-lancedb/src/lib.rs +++ b/rig-lancedb/src/lib.rs @@ -181,8 +181,8 @@ impl<M: EmbeddingModel + Sync + Send> VectorStoreIndex for LanceDbVectorIndex<M> /// use rig_lancedb::{LanceDbVectorIndex, SearchParams}; /// use rig::embeddings::EmbeddingModel; /// - /// let table: table: lancedb::Table = db.create_table(""); // <-- Replace with your lancedb table here. - /// let model: model: EmbeddingModel = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); // <-- Replace with your embedding model here. + /// let table: lancedb::Table = db.create_table("fake_definitions"); // <-- Replace with your lancedb table here. + /// let model: EmbeddingModel = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); // <-- Replace with your embedding model here. /// let vector_store_index = LanceDbVectorIndex::new(table, model, "id", SearchParams::default()).await?; /// /// // Query the index diff --git a/rig-mongodb/examples/vector_search_mongodb.rs b/rig-mongodb/examples/vector_search_mongodb.rs index cb6c6fcf..ecdb7a9f 100644 --- a/rig-mongodb/examples/vector_search_mongodb.rs +++ b/rig-mongodb/examples/vector_search_mongodb.rs @@ -96,12 +96,8 @@ async fn main() -> Result<(), anyhow::Error> { // Create a vector index on our vector store. // Note: a vector index called "vector_index" must exist on the MongoDB collection you are querying. // IMPORTANT: Reuse the same model that was used to generate the embeddings - let index = MongoDbVectorIndex::new( - collection, - model, - "vector_index", - SearchParams::new("embedding"), - ); + let index = + MongoDbVectorIndex::new(collection, model, "vector_index", SearchParams::new()).await?; // Query the index let results = index diff --git a/rig-mongodb/src/lib.rs b/rig-mongodb/src/lib.rs index 5cc52a45..813818e9 100644 --- a/rig-mongodb/src/lib.rs +++ b/rig-mongodb/src/lib.rs @@ -7,11 +7,6 @@ use rig::{ }; use serde::{Deserialize, Serialize}; -/// A MongoDB vector store. -pub struct MongoDbVectorStore { - collection: mongodb::Collection<bson::Document>, -} - #[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] struct SearchIndex { @@ -25,8 +20,8 @@ struct SearchIndex { } impl SearchIndex { - async fn get_search_index( - collection: mongodb::Collection<bson::Document>, + async fn get_search_index<C>( + collection: mongodb::Collection<C>, index_name: &str, ) -> Result<SearchIndex, VectorStoreError> { collection @@ -100,7 +95,6 @@ impl<M: EmbeddingModel, C> MongoDbVectorIndex<M, C> { filter, exact, num_candidates, - path, } = &self.search_params; doc! { @@ -169,21 +163,20 @@ impl<M: EmbeddingModel, C> MongoDbVectorIndex<M, C> { /// See [MongoDB Vector Search](`https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage/`) for more information /// on each of the fields +#[derive(Default)] pub struct SearchParams { filter: mongodb::bson::Document, - path: String, exact: Option<bool>, num_candidates: Option<u32>, } impl SearchParams { /// Initializes a new `SearchParams` with default values. - pub fn new(path: &str) -> Self { + pub fn new() -> Self { Self { filter: doc! 
{}, exact: None, num_candidates: None, - path: path.to_string(), } } diff --git a/rig-neo4j/Cargo.toml b/rig-neo4j/Cargo.toml index a0a94633..aae275c7 100644 --- a/rig-neo4j/Cargo.toml +++ b/rig-neo4j/Cargo.toml @@ -22,3 +22,7 @@ anyhow = "1.0.86" tokio = { version = "1.38.0", features = ["macros"] } textwrap = { version = "0.16.1"} term_size = { version = "0.3.2"} + +[[example]] +name = "vector_search_simple" +required-features = ["rig-core/derive"] \ No newline at end of file diff --git a/rig-neo4j/examples/vector_search_simple.rs b/rig-neo4j/examples/vector_search_simple.rs index 2cb0030d..0d3acf81 100644 --- a/rig-neo4j/examples/vector_search_simple.rs +++ b/rig-neo4j/examples/vector_search_simple.rs @@ -13,12 +13,20 @@ use rig::{ embeddings::EmbeddingsBuilder, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, vector_store::VectorStoreIndex as _, + Embed, }; use rig_neo4j::{ vector_index::{IndexConfig, SearchParams}, Neo4jClient, ToBoltType, }; +#[derive(Embed, Clone, Debug)] +pub struct FakeDefinition { + pub id: String, + #[embed] + pub definition: String, +} + #[tokio::main] async fn main() -> Result<(), anyhow::Error> { // Initialize OpenAI client @@ -36,9 +44,18 @@ async fn main() -> Result<(), anyhow::Error> { let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); let embeddings = EmbeddingsBuilder::new(model.clone()) - .simple_document("doc0", "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets") - .simple_document("doc1", "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.") - .simple_document("doc2", "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.") + .document(FakeDefinition { + id: "doc0".to_string(), + definition: "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets".to_string(), + })? + .document(FakeDefinition { + id: "doc1".to_string(), + definition: "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string(), + })? + .document(FakeDefinition { + id: "doc2".to_string(), + definition: "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.".to_string(), + })? .build() .await?; @@ -54,7 +71,7 @@ async fn main() -> Result<(), anyhow::Error> { } let create_nodes = futures::stream::iter(embeddings) - .map(|doc| { + .map(|(doc, embeddings)| { neo4j_client.graph.run( neo4rs::query( " @@ -68,8 +85,8 @@ async fn main() -> Result<(), anyhow::Error> { .param("id", doc.id) // Here we use the first embedding but we could use any of them. // Neo4j only takes primitive types or arrays as properties. 
- .param("embedding", doc.embeddings[0].vec.clone()) - .param("document", doc.document.to_bolt_type()), + .param("embedding", embeddings.first().vec.clone()) + .param("document", doc.definition.to_bolt_type()), ) }) .buffer_unordered(3) diff --git a/rig-neo4j/src/vector_index.rs b/rig-neo4j/src/vector_index.rs index db6fd3df..bf39644d 100644 --- a/rig-neo4j/src/vector_index.rs +++ b/rig-neo4j/src/vector_index.rs @@ -259,7 +259,7 @@ impl<M: EmbeddingModel + std::marker::Sync + Send> VectorStoreIndex for Neo4jVec query: &str, n: usize, ) -> Result<Vec<(f64, String, T)>, VectorStoreError> { - let prompt_embedding = self.embedding_model.embed_document(query).await?; + let prompt_embedding = self.embedding_model.embed_text(query).await?; let query = self.build_vector_search_query(prompt_embedding, true, n); let rows = self.execute_and_collect::<RowResultNode<T>>(query).await?; @@ -279,7 +279,7 @@ impl<M: EmbeddingModel + std::marker::Sync + Send> VectorStoreIndex for Neo4jVec query: &str, n: usize, ) -> Result<Vec<(f64, String)>, VectorStoreError> { - let prompt_embedding = self.embedding_model.embed_document(query).await?; + let prompt_embedding = self.embedding_model.embed_text(query).await?; let query = self.build_vector_search_query(prompt_embedding, false, n); diff --git a/rig-qdrant/Cargo.toml b/rig-qdrant/Cargo.toml index 4a7360a9..35c4b96e 100644 --- a/rig-qdrant/Cargo.toml +++ b/rig-qdrant/Cargo.toml @@ -16,3 +16,7 @@ qdrant-client = "1.12.1" [dev-dependencies] tokio = { version = "1.40.0", features = ["rt-multi-thread"] } anyhow = "1.0.89" + +[[example]] +name = "qdrant_vector_search" +required-features = ["rig-core/derive"] \ No newline at end of file diff --git a/rig-qdrant/examples/qdrant_vector_search.rs b/rig-qdrant/examples/qdrant_vector_search.rs index c9148d69..b1dc2d96 100644 --- a/rig-qdrant/examples/qdrant_vector_search.rs +++ b/rig-qdrant/examples/qdrant_vector_search.rs @@ -19,10 +19,18 @@ use rig::{ embeddings::EmbeddingsBuilder, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, vector_store::VectorStoreIndex, + Embed, }; use rig_qdrant::QdrantVectorStore; use serde_json::json; +#[derive(Embed)] +struct FakeDefinition { + id: String, + #[embed] + definition: String, +} + #[tokio::main] async fn main() -> Result<(), anyhow::Error> { const COLLECTION_NAME: &str = "rig-collection"; @@ -49,21 +57,30 @@ async fn main() -> Result<(), anyhow::Error> { let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); let documents = EmbeddingsBuilder::new(model.clone()) - .simple_document("0981d983-a5f8-49eb-89ea-f7d3b2196d2e", "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets") - .simple_document("62a36d43-80b6-4fd6-990c-f75bb02287d1", "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.") - .simple_document("f9e17d59-32e5-440c-be02-b2759a654824", "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.") + .document(FakeDefinition { + id: "0981d983-a5f8-49eb-89ea-f7d3b2196d2e".to_string(), + definition: "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets".to_string(), + })? + .document(FakeDefinition { + id: "62a36d43-80b6-4fd6-990c-f75bb02287d1".to_string(), + definition: "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string(), + })? 
+ .document(FakeDefinition { + id: "f9e17d59-32e5-440c-be02-b2759a654824".to_string(), + definition: "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.".to_string(), + })? .build() .await?; let points: Vec<PointStruct> = documents .into_iter() - .map(|d| { - let vec: Vec<f32> = d.embeddings[0].vec.iter().map(|&x| x as f32).collect(); + .map(|(d, embeddings)| { + let vec: Vec<f32> = embeddings.first().vec.iter().map(|&x| x as f32).collect(); PointStruct::new( d.id, vec, Payload::try_from(json!({ - "document": d.document, + "document": d.definition, })) .unwrap(), ) diff --git a/rig-qdrant/src/lib.rs b/rig-qdrant/src/lib.rs index e88878cc..666d0a4f 100644 --- a/rig-qdrant/src/lib.rs +++ b/rig-qdrant/src/lib.rs @@ -36,7 +36,7 @@ impl<M: EmbeddingModel> QdrantVectorStore<M> { /// Embed query based on `QdrantVectorStore` model and modify the vector in the required format. async fn generate_query_vector(&self, query: &str) -> Result<Vec<f32>, VectorStoreError> { - let embedding = self.model.embed_document(query).await?; + let embedding = self.model.embed_text(query).await?; Ok(embedding.vec.iter().map(|&x| x as f32).collect()) } From 4363671912355de0d0da7b105df3f34782b13b47 Mon Sep 17 00:00:00 2001 From: Garance Buricatu <garance.mary@gmail.com> Date: Fri, 22 Nov 2024 22:53:32 -0500 Subject: [PATCH 83/91] fix: cargo clippy in examples --- rig-core/examples/gemini_embeddings.rs | 3 --- rig-core/examples/xai_embeddings.rs | 3 --- 2 files changed, 6 deletions(-) diff --git a/rig-core/examples/gemini_embeddings.rs b/rig-core/examples/gemini_embeddings.rs index a8d9f3a6..6f8badbe 100644 --- a/rig-core/examples/gemini_embeddings.rs +++ b/rig-core/examples/gemini_embeddings.rs @@ -3,7 +3,6 @@ use rig::Embed; #[derive(Embed, Debug)] struct Greetings { - id: String, #[embed] message: String, } @@ -17,11 +16,9 @@ async fn main() -> Result<(), anyhow::Error> { let embeddings = client .embeddings(gemini::embedding::EMBEDDING_001) .document(Greetings { - id: "doc0".to_string(), message: "Hello, world!".to_string(), })? .document(Greetings { - id: "doc1".to_string(), message: "Goodbye, world!".to_string(), })? .build() diff --git a/rig-core/examples/xai_embeddings.rs b/rig-core/examples/xai_embeddings.rs index 09c39796..a127c389 100644 --- a/rig-core/examples/xai_embeddings.rs +++ b/rig-core/examples/xai_embeddings.rs @@ -3,7 +3,6 @@ use rig::Embed; #[derive(Embed, Debug)] struct Greetings { - id: String, #[embed] message: String, } @@ -16,11 +15,9 @@ async fn main() -> Result<(), anyhow::Error> { let embeddings = client .embeddings(xai::embedding::EMBEDDING_V1) .document(Greetings { - id: "doc0".to_string(), message: "Hello, world!".to_string(), })? .document(Greetings { - id: "doc1".to_string(), message: "Goodbye, world!".to_string(), })? 
.build() From 2d6d7c4313dec0f4477a26d5c2c8de957e5383bd Mon Sep 17 00:00:00 2001 From: cvauclair <cvauclair@protonmail.com> Date: Fri, 29 Nov 2024 10:23:49 -0500 Subject: [PATCH 84/91] Feat: small improvements + fixes + tests (#128) * docs: Make examples+docstrings a bit more realistic * feat: Add Embed implementation for &impl Embed * test: Reorganize tests * misc: Add `derive` feature to `all` feature flag * test: Fix dead code warning * test: Improve embed macro tests * test: Add additional embed macro test * docs: Add logging output to rag example * docs: Fix logging output in tools example * feat: Improve token usage log messages * test: Small changes to embedding builder tests * style: cargo fmt * fix: Clippy + docstrings * docs: Fix docstring * test: Fix test --- rig-core/Cargo.toml | 2 +- rig-core/examples/rag.rs | 32 ++- rig-core/examples/rag_dynamic_tools.rs | 5 - rig-core/examples/vector_search.rs | 10 +- rig-core/examples/vector_search_cohere.rs | 10 +- rig-core/src/embeddings/builder.rs | 125 +++++----- rig-core/src/embeddings/embed.rs | 16 +- rig-core/src/lib.rs | 2 +- rig-core/src/loaders/file.rs | 4 +- rig-core/src/loaders/pdf.rs | 4 +- rig-core/src/providers/openai.rs | 6 +- rig-core/src/providers/perplexity.rs | 2 +- rig-core/tests/embed_macro.rs | 225 +++++++++++------- rig-lancedb/examples/fixtures/lib.rs | 18 +- .../examples/vector_search_local_ann.rs | 6 +- rig-lancedb/examples/vector_search_s3_ann.rs | 6 +- rig-lancedb/src/utils/deserializer.rs | 6 +- rig-mongodb/examples/vector_search_mongodb.rs | 12 +- rig-neo4j/examples/vector_search_simple.rs | 8 +- rig-qdrant/examples/qdrant_vector_search.rs | 8 +- 20 files changed, 281 insertions(+), 226 deletions(-) diff --git a/rig-core/Cargo.toml b/rig-core/Cargo.toml index 290e50ac..f8c44e8f 100644 --- a/rig-core/Cargo.toml +++ b/rig-core/Cargo.toml @@ -35,8 +35,8 @@ tracing-subscriber = "0.3.18" tokio-test = "0.4.4" [features] +all = ["derive", "pdf"] derive = ["dep:rig-derive"] -all = ["pdf"] pdf = ["dep:lopdf"] [[test]] diff --git a/rig-core/examples/rag.rs b/rig-core/examples/rag.rs index cecd20ce..376c37db 100644 --- a/rig-core/examples/rag.rs +++ b/rig-core/examples/rag.rs @@ -10,17 +10,24 @@ use rig::{ use serde::Serialize; // Data to be RAGged. -// A vector search needs to be performed on the `definitions` field, so we derive the `Embed` trait for `FakeDefinition` +// A vector search needs to be performed on the `definitions` field, so we derive the `Embed` trait for `WordDefinition` // and tag that field with `#[embed]`. #[derive(Embed, Serialize, Clone, Debug, Eq, PartialEq, Default)] -struct FakeDefinition { +struct WordDefinition { id: String, + word: String, #[embed] definitions: Vec<String>, } #[tokio::main] async fn main() -> Result<(), anyhow::Error> { + // Initialize tracing + tracing_subscriber::fmt() + .with_max_level(tracing::Level::INFO) + .with_target(false) + .init(); + // Create OpenAI client let openai_api_key = env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set"); let openai_client = Client::new(&openai_api_key); @@ -30,25 +37,28 @@ async fn main() -> Result<(), anyhow::Error> { // Generate embeddings for the definitions of all the documents using the specified embedding model.
let embeddings = EmbeddingsBuilder::new(embedding_model.clone()) .documents(vec![ - FakeDefinition { + WordDefinition { id: "doc0".to_string(), + word: "flurbo".to_string(), definitions: vec![ - "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets.".to_string(), - "Definition of a *flurbo*: A fictional digital currency that originated in the animated series Rick and Morty.".to_string() + "1. *flurbo* (name): A flurbo is a green alien that lives on cold planets.".to_string(), + "2. *flurbo* (name): A fictional digital currency that originated in the animated series Rick and Morty.".to_string() ] }, - FakeDefinition { + WordDefinition { id: "doc1".to_string(), + word: "glarb-glarb".to_string(), definitions: vec![ - "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string(), - "Definition of a *glarb-glarb*: A fictional creature found in the distant, swampy marshlands of the planet Glibbo in the Andromeda galaxy.".to_string() + "1. *glarb-glarb* (noun): A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string(), + "2. *glarb-glarb* (noun): A fictional creature found in the distant, swampy marshlands of the planet Glibbo in the Andromeda galaxy.".to_string() ] }, - FakeDefinition { + WordDefinition { id: "doc2".to_string(), + word: "linglingdong".to_string(), definitions: vec![ - "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.".to_string(), - "Definition of a *linglingdong*: A rare, mystical instrument crafted by the ancient monks of the Nebulon Mountain Ranges on the planet Quarm.".to_string() + "1. *linglingdong* (noun): A term used by inhabitants of the far side of the moon to describe humans.".to_string(), + "2. *linglingdong* (noun): A rare, mystical instrument crafted by the ancient monks of the Nebulon Mountain Ranges on the planet Quarm.".to_string() ] }, ])? diff --git a/rig-core/examples/rag_dynamic_tools.rs b/rig-core/examples/rag_dynamic_tools.rs index 459b017b..bc92f7c5 100644 --- a/rig-core/examples/rag_dynamic_tools.rs +++ b/rig-core/examples/rag_dynamic_tools.rs @@ -137,11 +137,6 @@ async fn main() -> Result<(), anyhow::Error> { .with_max_level(tracing::Level::INFO) // disable printing the name of the module in every log line. .with_target(false) - // this needs to be set to false, otherwise ANSI color codes will - // show up in a confusing manner in CloudWatch logs. - .with_ansi(false) - // disabling time is handy because CloudWatch will add the ingestion time. - .without_time() .init(); // Create OpenAI client diff --git a/rig-core/examples/vector_search.rs b/rig-core/examples/vector_search.rs index 622fa0ff..e8cbf894 100644 --- a/rig-core/examples/vector_search.rs +++ b/rig-core/examples/vector_search.rs @@ -11,7 +11,7 @@ use serde::{Deserialize, Serialize}; // Shape of data that needs to be RAG'ed. // The definition field will be used to generate embeddings. 
#[derive(Embed, Clone, Deserialize, Debug, Serialize, Eq, PartialEq, Default)] -struct FakeDefinition { +struct WordDefinition { id: String, word: String, #[embed] @@ -28,7 +28,7 @@ async fn main() -> Result<(), anyhow::Error> { let embeddings = EmbeddingsBuilder::new(model.clone()) .documents(vec![ - FakeDefinition { + WordDefinition { id: "doc0".to_string(), word: "flurbo".to_string(), definitions: vec![ @@ -36,7 +36,7 @@ async fn main() -> Result<(), anyhow::Error> { "A fictional digital currency that originated in the animated series Rick and Morty.".to_string() ] }, - FakeDefinition { + WordDefinition { id: "doc1".to_string(), word: "glarb-glarb".to_string(), definitions: vec![ @@ -44,7 +44,7 @@ async fn main() -> Result<(), anyhow::Error> { "A fictional creature found in the distant, swampy marshlands of the planet Glibbo in the Andromeda galaxy.".to_string() ] }, - FakeDefinition { + WordDefinition { id: "doc2".to_string(), word: "linglingdong".to_string(), definitions: vec![ @@ -61,7 +61,7 @@ async fn main() -> Result<(), anyhow::Error> { .index(model); let results = index - .top_n::<FakeDefinition>("I need to buy something in a fictional universe. What type of money can I use for this?", 1) + .top_n::<WordDefinition>("I need to buy something in a fictional universe. What type of money can I use for this?", 1) .await? .into_iter() .map(|(score, id, doc)| (score, id, doc.word)) diff --git a/rig-core/examples/vector_search_cohere.rs b/rig-core/examples/vector_search_cohere.rs index f3a97498..aace89fa 100644 --- a/rig-core/examples/vector_search_cohere.rs +++ b/rig-core/examples/vector_search_cohere.rs @@ -11,7 +11,7 @@ use serde::{Deserialize, Serialize}; // Shape of data that needs to be RAG'ed. // The definition field will be used to generate embeddings. #[derive(Embed, Clone, Deserialize, Debug, Serialize, Eq, PartialEq, Default)] -struct FakeDefinition { +struct WordDefinition { id: String, word: String, #[embed] @@ -29,7 +29,7 @@ async fn main() -> Result<(), anyhow::Error> { let embeddings = EmbeddingsBuilder::new(document_model.clone()) .documents(vec![ - FakeDefinition { + WordDefinition { id: "doc0".to_string(), word: "flurbo".to_string(), definitions: vec![ @@ -37,7 +37,7 @@ async fn main() -> Result<(), anyhow::Error> { "A fictional digital currency that originated in the animated series Rick and Morty.".to_string() ] }, - FakeDefinition { + WordDefinition { id: "doc1".to_string(), word: "glarb-glarb".to_string(), definitions: vec![ @@ -45,7 +45,7 @@ async fn main() -> Result<(), anyhow::Error> { "A fictional creature found in the distant, swampy marshlands of the planet Glibbo in the Andromeda galaxy.".to_string() ] }, - FakeDefinition { + WordDefinition { id: "doc2".to_string(), word: "linglingdong".to_string(), definitions: vec![ @@ -62,7 +62,7 @@ async fn main() -> Result<(), anyhow::Error> { .index(search_model); let results = index - .top_n::<FakeDefinition>( + .top_n::<WordDefinition>( "Which instrument is found in the Nebulon Mountain Ranges?", 1, ) diff --git a/rig-core/src/embeddings/builder.rs b/rig-core/src/embeddings/builder.rs index 8f0a5dd1..5a4b63c9 100644 --- a/rig-core/src/embeddings/builder.rs +++ b/rig-core/src/embeddings/builder.rs @@ -13,40 +13,7 @@ use crate::{ /// Builder for creating a collection of embeddings from a vector of documents of type `T`. /// Accumulate documents such that they can be embedded in a single batch to limit api calls to the provider. 
-pub struct EmbeddingsBuilder<M: EmbeddingModel, T: Embed> { - model: M, - documents: Vec<(T, Vec<String>)>, -} - -impl<M: EmbeddingModel, T: Embed> EmbeddingsBuilder<M, T> { - /// Create a new embedding builder with the given embedding model - pub fn new(model: M) -> Self { - Self { - model, - documents: vec![], - } - } - - /// Add a document that implements `Embed` to the builder. - pub fn document(mut self, document: T) -> Result<Self, EmbedError> { - let mut embedder = TextEmbedder::default(); - document.embed(&mut embedder)?; - - self.documents.push((document, embedder.texts)); - - Ok(self) - } - - /// Add many documents that implement `Embed` to the builder. - pub fn documents(self, documents: impl IntoIterator<Item = T>) -> Result<Self, EmbedError> { - let builder = documents - .into_iter() - .try_fold(self, |builder, doc| builder.document(doc))?; - - Ok(builder) - } -} - +/// /// # Example /// ```rust /// use std::env; @@ -62,7 +29,7 @@ impl<M: EmbeddingModel, T: Embed> EmbeddingsBuilder<M, T> { /// // Shape of data that needs to be RAG'ed. /// // The definition field will be used to generate embeddings. /// #[derive(Embed, Clone, Deserialize, Debug, Serialize, Eq, PartialEq, Default)] -/// struct FakeDefinition { +/// struct WordDefinition { /// id: String, /// word: String, /// #[embed] @@ -77,7 +44,7 @@ impl<M: EmbeddingModel, T: Embed> EmbeddingsBuilder<M, T> { /// /// let embeddings = EmbeddingsBuilder::new(model.clone()) /// .documents(vec![ -/// FakeDefinition { +/// WordDefinition { /// id: "doc0".to_string(), /// word: "flurbo".to_string(), /// definitions: vec![ @@ -85,7 +52,7 @@ impl<M: EmbeddingModel, T: Embed> EmbeddingsBuilder<M, T> { /// "A fictional digital currency that originated in the animated series Rick and Morty.".to_string() /// ] /// }, -/// FakeDefinition { +/// WordDefinition { /// id: "doc1".to_string(), /// word: "glarb-glarb".to_string(), /// definitions: vec![ @@ -93,7 +60,7 @@ impl<M: EmbeddingModel, T: Embed> EmbeddingsBuilder<M, T> { /// "A fictional creature found in the distant, swampy marshlands of the planet Glibbo in the Andromeda galaxy.".to_string() /// ] /// }, -/// FakeDefinition { +/// WordDefinition { /// id: "doc2".to_string(), /// word: "linglingdong".to_string(), /// definitions: vec![ @@ -105,6 +72,40 @@ impl<M: EmbeddingModel, T: Embed> EmbeddingsBuilder<M, T> { /// .build() /// .await?; /// ``` +pub struct EmbeddingsBuilder<M: EmbeddingModel, T: Embed> { + model: M, + documents: Vec<(T, Vec<String>)>, +} + +impl<M: EmbeddingModel, T: Embed> EmbeddingsBuilder<M, T> { + /// Create a new embedding builder with the given embedding model + pub fn new(model: M) -> Self { + Self { + model, + documents: vec![], + } + } + + /// Add a document that implements `Embed` to the builder. + pub fn document(mut self, document: T) -> Result<Self, EmbedError> { + let mut embedder = TextEmbedder::default(); + document.embed(&mut embedder)?; + + self.documents.push((document, embedder.texts)); + + Ok(self) + } + + /// Add many documents that implement `Embed` to the builder. + pub fn documents(self, documents: impl IntoIterator<Item = T>) -> Result<Self, EmbedError> { + let builder = documents + .into_iter() + .try_fold(self, |builder, doc| builder.document(doc))?; + + Ok(builder) + } +} + impl<M: EmbeddingModel, T: Embed + Send> EmbeddingsBuilder<M, T> { /// Generate embeddings for all documents in the builder. 
/// Returns a vector of tuples, where the first element is the document and the second element is the embeddings (either one embedding or many). @@ -174,9 +175,9 @@ mod tests { use super::EmbeddingsBuilder; #[derive(Clone)] - struct FakeModel; + struct Model; - impl EmbeddingModel for FakeModel { + impl EmbeddingModel for Model { const MAX_DOCUMENTS: usize = 5; fn ndims(&self) -> usize { @@ -198,12 +199,12 @@ mod tests { } #[derive(Clone, Debug)] - struct FakeDefinition { + struct WordDefinition { id: String, definitions: Vec<String>, } - impl Embed for FakeDefinition { + impl Embed for WordDefinition { fn embed(&self, embedder: &mut TextEmbedder) -> Result<(), EmbedError> { for definition in &self.definitions { embedder.embed(definition.clone()); @@ -212,16 +213,16 @@ mod tests { } } - fn fake_definitions_multiple_text() -> Vec<FakeDefinition> { + fn definitions_multiple_text() -> Vec<WordDefinition> { vec![ - FakeDefinition { + WordDefinition { id: "doc0".to_string(), definitions: vec![ "A green alien that lives on cold planets.".to_string(), "A fictional digital currency that originated in the animated series Rick and Morty.".to_string() ] }, - FakeDefinition { + WordDefinition { id: "doc1".to_string(), definitions: vec![ "An ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string(), @@ -231,13 +232,13 @@ mod tests { ] } - fn fake_definitions_multiple_text_2() -> Vec<FakeDefinition> { + fn definitions_multiple_text_2() -> Vec<WordDefinition> { vec![ - FakeDefinition { + WordDefinition { id: "doc2".to_string(), definitions: vec!["Another fake definitions".to_string()], }, - FakeDefinition { + WordDefinition { id: "doc3".to_string(), definitions: vec!["Some fake definition".to_string()], }, @@ -245,25 +246,25 @@ mod tests { } #[derive(Clone, Debug)] - struct FakeDefinitionSingle { + struct WordDefinitionSingle { id: String, definition: String, } - impl Embed for FakeDefinitionSingle { + impl Embed for WordDefinitionSingle { fn embed(&self, embedder: &mut TextEmbedder) -> Result<(), EmbedError> { embedder.embed(self.definition.clone()); Ok(()) } } - fn fake_definitions_single_text() -> Vec<FakeDefinitionSingle> { + fn definitions_single_text() -> Vec<WordDefinitionSingle> { vec![ - FakeDefinitionSingle { + WordDefinitionSingle { id: "doc0".to_string(), definition: "A green alien that lives on cold planets.".to_string(), }, - FakeDefinitionSingle { + WordDefinitionSingle { id: "doc1".to_string(), definition: "An ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string(), } @@ -272,9 +273,9 @@ mod tests { #[tokio::test] async fn test_build_multiple_text() { - let fake_definitions = fake_definitions_multiple_text(); + let fake_definitions = definitions_multiple_text(); - let fake_model = FakeModel; + let fake_model = Model; let mut result = EmbeddingsBuilder::new(fake_model) .documents(fake_definitions) .unwrap() @@ -306,9 +307,9 @@ mod tests { #[tokio::test] async fn test_build_single_text() { - let fake_definitions = fake_definitions_single_text(); + let fake_definitions = definitions_single_text(); - let fake_model = FakeModel; + let fake_model = Model; let mut result = EmbeddingsBuilder::new(fake_model) .documents(fake_definitions) .unwrap() @@ -340,10 +341,10 @@ mod tests { #[tokio::test] async fn test_build_multiple_and_single_text() { - let fake_definitions = fake_definitions_multiple_text(); - let fake_definitions_single = fake_definitions_multiple_text_2(); + let fake_definitions = 
definitions_multiple_text(); + let fake_definitions_single = definitions_multiple_text_2(); - let fake_model = FakeModel; + let fake_model = Model; let mut result = EmbeddingsBuilder::new(fake_model) .documents(fake_definitions) .unwrap() @@ -377,10 +378,10 @@ mod tests { #[tokio::test] async fn test_build_string() { - let bindings = fake_definitions_multiple_text(); + let bindings = definitions_multiple_text(); let fake_definitions = bindings.iter().map(|def| def.definitions.clone()); - let fake_model = FakeModel; + let fake_model = Model; let mut result = EmbeddingsBuilder::new(fake_model) .documents(fake_definitions) .unwrap() diff --git a/rig-core/src/embeddings/embed.rs b/rig-core/src/embeddings/embed.rs index 6b30bbc0..480a2930 100644 --- a/rig-core/src/embeddings/embed.rs +++ b/rig-core/src/embeddings/embed.rs @@ -29,15 +29,15 @@ impl EmbedError { /// use std::env; /// /// use serde::{Deserialize, Serialize}; -/// use rig::{Embed, embeddings::{TextEmbedder, EmbedError, to_texts}}; +/// use rig::{Embed, embeddings::{TextEmbedder, EmbedError}}; /// -/// struct FakeDefinition { +/// struct WordDefinition { /// id: String, /// word: String, /// definitions: String, /// } /// -/// impl Embed for FakeDefinition { +/// impl Embed for WordDefinition { /// fn embed(&self, embedder: &mut TextEmbedder) -> Result<(), EmbedError> { /// // Embeddings only need to be generated for `definition` field. /// // Split the definitions by comma and collect them into a vector of strings. @@ -52,13 +52,13 @@ impl EmbedError { /// } /// } /// -/// let fake_definition = FakeDefinition { +/// let fake_definition = WordDefinition { /// id: "1".to_string(), /// word: "apple".to_string(), /// definitions: "a fruit, a tech company".to_string(), /// }; /// -/// assert_eq!(to_texts(fake_definition).unwrap(), vec!["a fruit", " a tech company"]); +/// assert_eq!(embeddings::to_texts(fake_definition).unwrap(), vec!["a fruit", " a tech company"]); /// ``` pub trait Embed { fn embed(&self, embedder: &mut TextEmbedder) -> Result<(), EmbedError>; @@ -174,6 +174,12 @@ impl Embed for serde_json::Value { } } +impl<T: Embed> Embed for &T { + fn embed(&self, embedder: &mut TextEmbedder) -> Result<(), EmbedError> { + (*self).embed(embedder) + } +} + impl<T: Embed> Embed for Vec<T> { fn embed(&self, embedder: &mut TextEmbedder) -> Result<(), EmbedError> { for item in self { diff --git a/rig-core/src/lib.rs b/rig-core/src/lib.rs index 4fe84a6b..f1b5427b 100644 --- a/rig-core/src/lib.rs +++ b/rig-core/src/lib.rs @@ -91,7 +91,7 @@ pub mod tool; pub mod vector_store; // Re-export commonly used types and traits -pub use embeddings::{to_texts, Embed}; +pub use embeddings::Embed; pub use one_or_many::{EmptyListError, OneOrMany}; #[cfg(feature = "derive")] diff --git a/rig-core/src/loaders/file.rs b/rig-core/src/loaders/file.rs index 17c2f1f3..84eb1864 100644 --- a/rig-core/src/loaders/file.rs +++ b/rig-core/src/loaders/file.rs @@ -162,7 +162,7 @@ impl<'a, T: 'a> FileLoader<'a, Result<T, FileLoaderError>> { } } -impl<'a> FileLoader<'a, Result<PathBuf, FileLoaderError>> { +impl FileLoader<'_, Result<PathBuf, FileLoaderError>> { /// Creates a new [FileLoader] using a glob pattern to match files. 
/// /// # Example @@ -227,7 +227,7 @@ impl<'a, T> IntoIterator for FileLoader<'a, T> { } } -impl<'a, T> Iterator for IntoIter<'a, T> { +impl<T> Iterator for IntoIter<'_, T> { type Item = T; fn next(&mut self) -> Option<Self::Item> { diff --git a/rig-core/src/loaders/pdf.rs b/rig-core/src/loaders/pdf.rs index ea18e4e6..410643f8 100644 --- a/rig-core/src/loaders/pdf.rs +++ b/rig-core/src/loaders/pdf.rs @@ -335,7 +335,7 @@ impl<'a, T: 'a> PdfFileLoader<'a, Result<T, PdfLoaderError>> { } } -impl<'a> PdfFileLoader<'a, Result<PathBuf, FileLoaderError>> { +impl PdfFileLoader<'_, Result<PathBuf, FileLoaderError>> { /// Creates a new [PdfFileLoader] using a glob pattern to match files. /// /// # Example @@ -396,7 +396,7 @@ impl<'a, T> IntoIterator for PdfFileLoader<'a, T> { } } -impl<'a, T> Iterator for IntoIter<'a, T> { +impl<T> Iterator for IntoIter<'_, T> { type Item = T; fn next(&mut self) -> Option<Self::Item> { diff --git a/rig-core/src/providers/openai.rs b/rig-core/src/providers/openai.rs index 77d3c5a8..6b5a3079 100644 --- a/rig-core/src/providers/openai.rs +++ b/rig-core/src/providers/openai.rs @@ -219,7 +219,7 @@ pub struct EmbeddingData { pub index: usize, } -#[derive(Debug, Deserialize)] +#[derive(Clone, Debug, Deserialize)] pub struct Usage { pub prompt_tokens: usize, pub total_tokens: usize, @@ -229,7 +229,7 @@ impl std::fmt::Display for Usage { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!( f, - "Prompt tokens: {}\nTotal tokens: {}", + "Prompt tokens: {} Total tokens: {}", self.prompt_tokens, self.total_tokens ) } @@ -535,7 +535,7 @@ impl completion::CompletionModel for CompletionModel { ApiResponse::Ok(response) => { tracing::info!(target: "rig", "OpenAI completion token usage: {:?}", - response.usage + response.usage.clone().map(|usage| format!("{usage}")).unwrap_or("N/A".to_string()) ); response.try_into() } diff --git a/rig-core/src/providers/perplexity.rs b/rig-core/src/providers/perplexity.rs index 41be57f7..fa1e34fb 100644 --- a/rig-core/src/providers/perplexity.rs +++ b/rig-core/src/providers/perplexity.rs @@ -155,7 +155,7 @@ impl std::fmt::Display for Usage { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!( f, - "Prompt tokens: {}\nCompletion tokens: {}\nTotal tokens: {}", + "Prompt tokens: {}\nCompletion tokens: {} Total tokens: {}", self.prompt_tokens, self.completion_tokens, self.total_tokens ) } diff --git a/rig-core/tests/embed_macro.rs b/rig-core/tests/embed_macro.rs index 778b70bd..daf6dcf8 100644 --- a/rig-core/tests/embed_macro.rs +++ b/rig-core/tests/embed_macro.rs @@ -1,33 +1,38 @@ use rig::{ - embeddings::{embed::EmbedError, TextEmbedder}, - to_texts, Embed, + embeddings::{self, embed::EmbedError, TextEmbedder}, + Embed, }; use serde::Serialize; -fn serialize(embedder: &mut TextEmbedder, definition: Definition) -> Result<(), EmbedError> { - embedder.embed(serde_json::to_string(&definition).map_err(EmbedError::new)?); - - Ok(()) -} - -#[derive(Embed)] -struct FakeDefinition { - id: String, - word: String, - #[embed(embed_with = "serialize")] - definition: Definition, -} - -#[derive(Serialize, Clone)] -struct Definition { - word: String, - link: String, - speech: String, -} - #[test] fn test_custom_embed() { - let fake_definition = FakeDefinition { + #[derive(Embed)] + struct WordDefinition { + #[allow(dead_code)] + id: String, + #[allow(dead_code)] + word: String, + #[embed(embed_with = "custom_embedding_function")] + definition: Definition, + } + + #[derive(Serialize, Clone)] + struct Definition { + word: 
String, + link: String, + speech: String, + } + + fn custom_embedding_function( + embedder: &mut TextEmbedder, + definition: Definition, + ) -> Result<(), EmbedError> { + embedder.embed(serde_json::to_string(&definition).map_err(EmbedError::new)?); + + Ok(()) + } + + let definition = WordDefinition { id: "doc1".to_string(), word: "house".to_string(), definition: Definition { @@ -37,30 +42,41 @@ fn test_custom_embed() { }, }; - println!( - "FakeDefinition: {}, {}", - fake_definition.id, fake_definition.word - ); - assert_eq!( - to_texts(fake_definition).unwrap().first().unwrap().clone(), - "{\"word\":\"a building in which people live; residence for human beings.\",\"link\":\"https://www.dictionary.com/browse/house\",\"speech\":\"noun\"}".to_string() - + embeddings::to_texts(definition).unwrap(), + vec!["{\"word\":\"a building in which people live; residence for human beings.\",\"link\":\"https://www.dictionary.com/browse/house\",\"speech\":\"noun\"}".to_string()] ) } -#[derive(Embed)] -struct FakeDefinition2 { - id: String, - #[embed] - word: String, - #[embed(embed_with = "serialize")] - definition: Definition, -} - #[test] fn test_custom_and_basic_embed() { - let fake_definition = FakeDefinition2 { + #[derive(Embed)] + struct WordDefinition { + #[allow(dead_code)] + id: String, + #[embed] + word: String, + #[embed(embed_with = "custom_embedding_function")] + definition: Definition, + } + + #[derive(Serialize, Clone)] + struct Definition { + word: String, + link: String, + speech: String, + } + + fn custom_embedding_function( + embedder: &mut TextEmbedder, + definition: Definition, + ) -> Result<(), EmbedError> { + embedder.embed(serde_json::to_string(&definition).map_err(EmbedError::new)?); + + Ok(()) + } + + let definition = WordDefinition { id: "doc1".to_string(), word: "house".to_string(), definition: Definition { @@ -70,69 +86,63 @@ fn test_custom_and_basic_embed() { }, }; - println!( - "FakeDefinition: {}, {}", - fake_definition.id, fake_definition.word - ); - - let texts = to_texts(fake_definition).unwrap(); - - assert_eq!(texts.first().unwrap().clone(), "house".to_string()); + let texts = embeddings::to_texts(definition).unwrap(); assert_eq!( - texts.last().unwrap().clone(), - "{\"word\":\"a building in which people live; residence for human beings.\",\"link\":\"https://www.dictionary.com/browse/house\",\"speech\":\"noun\"}".to_string() - ) -} - -#[derive(Embed)] -struct FakeDefinition3 { - id: String, - word: String, - #[embed] - definition: String, + texts, + vec![ + "house".to_string(), + "{\"word\":\"a building in which people live; residence for human beings.\",\"link\":\"https://www.dictionary.com/browse/house\",\"speech\":\"noun\"}".to_string() + ] + ); } #[test] fn test_single_embed() { + #[derive(Embed)] + struct WordDefinition { + #[allow(dead_code)] + id: String, + #[allow(dead_code)] + word: String, + #[embed] + definition: String, + } + let definition = "a building in which people live; residence for human beings.".to_string(); - let fake_definition = FakeDefinition3 { + let word_definition = WordDefinition { id: "doc1".to_string(), word: "house".to_string(), definition: definition.clone(), }; - println!( - "FakeDefinition3: {}, {}", - fake_definition.id, fake_definition.word - ); assert_eq!( - to_texts(fake_definition).unwrap().first().unwrap().clone(), - definition + embeddings::to_texts(word_definition).unwrap(), + vec![definition] ) } -#[derive(Embed)] -struct Company { - id: String, - company: String, - #[embed] - employee_ages: Vec<i32>, -} - #[test] -fn 
test_multiple_embed_strings() { +fn test_embed_vec_non_string() { + #[derive(Embed)] + struct Company { + #[allow(dead_code)] + id: String, + #[allow(dead_code)] + company: String, + #[embed] + employee_ages: Vec<i32>, + } + let company = Company { id: "doc1".to_string(), company: "Google".to_string(), employee_ages: vec![25, 30, 35, 40], }; - println!("Company: {}, {}", company.id, company.company); - assert_eq!( - to_texts(company).unwrap(), + embeddings::to_texts(company).unwrap(), vec![ "25".to_string(), "30".to_string(), @@ -142,27 +152,60 @@ fn test_multiple_embed_strings() { ); } -#[derive(Embed)] -struct Company2 { - id: String, - #[embed] - company: String, - #[embed] - employee_ages: Vec<i32>, +#[test] +fn test_embed_vec_string() { + #[derive(Embed)] + struct Company { + #[allow(dead_code)] + id: String, + #[allow(dead_code)] + company: String, + #[embed] + employee_names: Vec<String>, + } + + let company = Company { + id: "doc1".to_string(), + company: "Google".to_string(), + employee_names: vec![ + "Alice".to_string(), + "Bob".to_string(), + "Charlie".to_string(), + "David".to_string(), + ], + }; + + assert_eq!( + embeddings::to_texts(company).unwrap(), + vec![ + "Alice".to_string(), + "Bob".to_string(), + "Charlie".to_string(), + "David".to_string() + ] + ); } #[test] fn test_multiple_embed_tags() { - let company = Company2 { + #[derive(Embed)] + struct Company { + #[allow(dead_code)] + id: String, + #[embed] + company: String, + #[embed] + employee_ages: Vec<i32>, + } + + let company = Company { id: "doc1".to_string(), company: "Google".to_string(), employee_ages: vec![25, 30, 35, 40], }; - println!("Company: {}", company.id); - assert_eq!( - to_texts(company).unwrap(), + embeddings::to_texts(company).unwrap(), vec![ "Google".to_string(), "25".to_string(), diff --git a/rig-lancedb/examples/fixtures/lib.rs b/rig-lancedb/examples/fixtures/lib.rs index 954494e5..b12156fb 100644 --- a/rig-lancedb/examples/fixtures/lib.rs +++ b/rig-lancedb/examples/fixtures/lib.rs @@ -7,23 +7,23 @@ use rig::{Embed, OneOrMany}; use serde::Deserialize; #[derive(Embed, Clone, Deserialize, Debug)] -pub struct FakeDefinition { +pub struct WordDefinition { pub id: String, #[embed] pub definition: String, } -pub fn fake_definitions() -> Vec<FakeDefinition> { +pub fn fake_definitions() -> Vec<WordDefinition> { vec![ - FakeDefinition { + WordDefinition { id: "doc0".to_string(), definition: "Definition of *flumbrel (noun)*: a small, seemingly insignificant item that you constantly lose or misplace, such as a pen, hair tie, or remote control.".to_string() }, - FakeDefinition { + WordDefinition { id: "doc1".to_string(), definition: "Definition of *zindle (verb)*: to pretend to be working on something important while actually doing something completely unrelated or unproductive.".to_string() }, - FakeDefinition { + WordDefinition { id: "doc2".to_string(), definition: "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.".to_string() } @@ -46,22 +46,22 @@ pub fn schema(dims: usize) -> Schema { ])) } -// Convert FakeDefinition objects and their embedding to a RecordBatch. +// Convert WordDefinition objects and their embedding to a RecordBatch. pub fn as_record_batch( - records: Vec<(FakeDefinition, OneOrMany<Embedding>)>, + records: Vec<(WordDefinition, OneOrMany<Embedding>)>, dims: usize, ) -> Result<RecordBatch, lancedb::arrow::arrow_schema::ArrowError> { let id = StringArray::from_iter_values( records .iter() - .map(|(FakeDefinition { id, .. 
}, _)| id) + .map(|(WordDefinition { id, .. }, _)| id) .collect::<Vec<_>>(), ); let definition = StringArray::from_iter_values( records .iter() - .map(|(FakeDefinition { definition, .. }, _)| definition) + .map(|(WordDefinition { definition, .. }, _)| definition) .collect::<Vec<_>>(), ); diff --git a/rig-lancedb/examples/vector_search_local_ann.rs b/rig-lancedb/examples/vector_search_local_ann.rs index 84679e3f..0b75f080 100644 --- a/rig-lancedb/examples/vector_search_local_ann.rs +++ b/rig-lancedb/examples/vector_search_local_ann.rs @@ -1,7 +1,7 @@ use std::{env, sync::Arc}; use arrow_array::RecordBatchIterator; -use fixture::{as_record_batch, fake_definitions, schema, FakeDefinition}; +use fixture::{as_record_batch, fake_definitions, schema, WordDefinition}; use lancedb::index::vector::IvfPqIndexBuilder; use rig::{ embeddings::{EmbeddingModel, EmbeddingsBuilder}, @@ -31,7 +31,7 @@ async fn main() -> Result<(), anyhow::Error> { // Note: need at least 256 rows in order to create an index so copy the definition 256 times for testing purposes. .documents( (0..256) - .map(|i| FakeDefinition { + .map(|i| WordDefinition { id: format!("doc{}", i), definition: "Definition of *flumbuzzle (noun)*: A sudden, inexplicable urge to rearrange or reorganize small objects, such as desk items or books, for no apparent reason.".to_string() }) @@ -65,7 +65,7 @@ async fn main() -> Result<(), anyhow::Error> { // Query the index let results = vector_store_index - .top_n::<FakeDefinition>("My boss says I zindle too much, what does that mean?", 1) + .top_n::<WordDefinition>("My boss says I zindle too much, what does that mean?", 1) .await?; println!("Results: {:?}", results); diff --git a/rig-lancedb/examples/vector_search_s3_ann.rs b/rig-lancedb/examples/vector_search_s3_ann.rs index 160dfa10..8aca722b 100644 --- a/rig-lancedb/examples/vector_search_s3_ann.rs +++ b/rig-lancedb/examples/vector_search_s3_ann.rs @@ -1,7 +1,7 @@ use std::{env, sync::Arc}; use arrow_array::RecordBatchIterator; -use fixture::{as_record_batch, fake_definitions, schema, FakeDefinition}; +use fixture::{as_record_batch, fake_definitions, schema, WordDefinition}; use lancedb::{index::vector::IvfPqIndexBuilder, DistanceType}; use rig::{ embeddings::{EmbeddingModel, EmbeddingsBuilder}, @@ -37,7 +37,7 @@ async fn main() -> Result<(), anyhow::Error> { // Note: need at least 256 rows in order to create an index so copy the definition 256 times for testing purposes. .documents( (0..256) - .map(|i| FakeDefinition { + .map(|i| WordDefinition { id: format!("doc{}", i), definition: "Definition of *flumbuzzle (noun)*: A sudden, inexplicable urge to rearrange or reorganize small objects, such as desk items or books, for no apparent reason.".to_string() }) @@ -77,7 +77,7 @@ async fn main() -> Result<(), anyhow::Error> { // Query the index let results = vector_store - .top_n::<FakeDefinition>("I'm always looking for my phone, I always seem to forget it in the most counterintuitive places. What's the word for this feeling?", 1) + .top_n::<WordDefinition>("I'm always looking for my phone, I always seem to forget it in the most counterintuitive places. 
What's the word for this feeling?", 1) .await?; println!("Results: {:?}", results); diff --git a/rig-lancedb/src/utils/deserializer.rs b/rig-lancedb/src/utils/deserializer.rs index fd890280..7f9d1d7d 100644 --- a/rig-lancedb/src/utils/deserializer.rs +++ b/rig-lancedb/src/utils/deserializer.rs @@ -356,9 +356,9 @@ fn type_matcher(column: &Arc<dyn Array>) -> Result<Vec<Value>, VectorStoreError> } } -/////////////////////////////////////////////////////////////////////////////////// -/// Everything below includes helpers for the recursive function `type_matcher`./// -/////////////////////////////////////////////////////////////////////////////////// +// ================================================================ +// Everything below includes helpers for the recursive function `type_matcher` +// ================================================================ /// Trait used to "deserialize" an arrow_array::Array as as list of primitive objects. trait DeserializePrimitiveArray { diff --git a/rig-mongodb/examples/vector_search_mongodb.rs b/rig-mongodb/examples/vector_search_mongodb.rs index ecdb7a9f..5d0ed81b 100644 --- a/rig-mongodb/examples/vector_search_mongodb.rs +++ b/rig-mongodb/examples/vector_search_mongodb.rs @@ -11,7 +11,7 @@ use rig_mongodb::{MongoDbVectorIndex, SearchParams}; // Shape of data that needs to be RAG'ed. // The definition field will be used to generate embeddings. #[derive(Embed, Clone, Deserialize, Debug)] -struct FakeDefinition { +struct WordDefinition { #[serde(rename = "_id")] id: String, #[embed] @@ -58,15 +58,15 @@ async fn main() -> Result<(), anyhow::Error> { let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); let fake_definitions = vec![ - FakeDefinition { + WordDefinition { id: "doc0".to_string(), definition: "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets".to_string(), }, - FakeDefinition { + WordDefinition { id: "doc1".to_string(), definition: "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string(), }, - FakeDefinition { + WordDefinition { id: "doc2".to_string(), definition: "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.".to_string(), } @@ -80,7 +80,7 @@ async fn main() -> Result<(), anyhow::Error> { let mongo_documents = embeddings .iter() .map( - |(FakeDefinition { id, definition, .. }, embedding)| Document { + |(WordDefinition { id, definition, .. 
}, embedding)| Document { id: id.clone(), definition: definition.clone(), embedding: embedding.first().vec.clone(), @@ -101,7 +101,7 @@ async fn main() -> Result<(), anyhow::Error> { // Query the index let results = index - .top_n::<FakeDefinition>("What is a linglingdong?", 1) + .top_n::<WordDefinition>("What is a linglingdong?", 1) .await?; println!("Results: {:?}", results); diff --git a/rig-neo4j/examples/vector_search_simple.rs b/rig-neo4j/examples/vector_search_simple.rs index 0d3acf81..fca43d27 100644 --- a/rig-neo4j/examples/vector_search_simple.rs +++ b/rig-neo4j/examples/vector_search_simple.rs @@ -21,7 +21,7 @@ use rig_neo4j::{ }; #[derive(Embed, Clone, Debug)] -pub struct FakeDefinition { +pub struct WordDefinition { pub id: String, #[embed] pub definition: String, @@ -44,15 +44,15 @@ async fn main() -> Result<(), anyhow::Error> { let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); let embeddings = EmbeddingsBuilder::new(model.clone()) - .document(FakeDefinition { + .document(WordDefinition { id: "doc0".to_string(), definition: "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets".to_string(), })? - .document(FakeDefinition { + .document(WordDefinition { id: "doc1".to_string(), definition: "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string(), })? - .document(FakeDefinition { + .document(WordDefinition { id: "doc2".to_string(), definition: "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.".to_string(), })? diff --git a/rig-qdrant/examples/qdrant_vector_search.rs b/rig-qdrant/examples/qdrant_vector_search.rs index b1dc2d96..b1a91349 100644 --- a/rig-qdrant/examples/qdrant_vector_search.rs +++ b/rig-qdrant/examples/qdrant_vector_search.rs @@ -25,7 +25,7 @@ use rig_qdrant::QdrantVectorStore; use serde_json::json; #[derive(Embed)] -struct FakeDefinition { +struct WordDefinition { id: String, #[embed] definition: String, @@ -57,15 +57,15 @@ async fn main() -> Result<(), anyhow::Error> { let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); let documents = EmbeddingsBuilder::new(model.clone()) - .document(FakeDefinition { + .document(WordDefinition { id: "0981d983-a5f8-49eb-89ea-f7d3b2196d2e".to_string(), definition: "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets".to_string(), })? - .document(FakeDefinition { + .document(WordDefinition { id: "62a36d43-80b6-4fd6-990c-f75bb02287d1".to_string(), definition: "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string(), })? - .document(FakeDefinition { + .document(WordDefinition { id: "f9e17d59-32e5-440c-be02-b2759a654824".to_string(), definition: "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.".to_string(), })? 
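The tests above exercise the two ways a field can be tagged for embedding: a plain `#[embed]` attribute (on strings and on `Vec`s of embeddable values) and `#[embed(embed_with = "...")]` pointing at a user-supplied function. A minimal free-standing sketch of the custom-plus-basic case follows; it assumes `serde` and `serde_json` are available as in the tests, the import paths mirror the in-crate modules touched by these patches rather than any guaranteed public re-export, and `Flashcard`, `Hint`, and `embed_hint` are invented names used purely for illustration.

```rust
// Sketch only: paths follow the in-crate layout shown in these patches
// (`embeddings::embed::{EmbedError, TextEmbedder}`); the crate's public
// re-exports may differ. `Flashcard`, `Hint`, and `embed_hint` are made up.
use rig::embeddings::{self, embed::{EmbedError, TextEmbedder}};
use rig::Embed;
use serde::Serialize;

// Nested data that should be embedded as a single JSON string.
#[derive(Serialize, Clone)]
struct Hint {
    text: String,
}

// Custom embedding function referenced by `embed_with`: serialize the value
// and queue the resulting string on the embedder (same shape as the tests).
fn embed_hint(embedder: &mut TextEmbedder, hint: Hint) -> Result<(), EmbedError> {
    embedder.embed(serde_json::to_string(&hint).map_err(EmbedError::new)?);
    Ok(())
}

// Hypothetical document type: only the tagged fields produce embedding texts.
#[derive(Embed)]
struct Flashcard {
    #[allow(dead_code)]
    id: String,
    #[embed]
    front: String,
    #[embed(embed_with = "embed_hint")]
    hint: Hint,
}

fn main() -> Result<(), EmbedError> {
    let card = Flashcard {
        id: "card0".to_string(),
        front: "house".to_string(),
        hint: Hint {
            text: "a building in which people live".to_string(),
        },
    };

    // One text per tagged field, in declaration order: the plain string,
    // then the JSON produced by `embed_hint`.
    let texts = embeddings::to_texts(card)?;
    assert_eq!(texts.len(), 2);
    assert_eq!(texts[0], "house");
    println!("{texts:?}");
    Ok(())
}
```

As in `test_custom_and_basic_embed`, `to_texts` yields one string per tagged field in declaration order, so the JSON text produced by the custom function lands after the plain `front` string.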
From a041699f792bb221ffaba1beb69c606ebd0f6df1 Mon Sep 17 00:00:00 2001 From: Christophe <cvauclair@protonmail.com> Date: Fri, 29 Nov 2024 10:28:20 -0500 Subject: [PATCH 85/91] style: Small renaming for consistency --- rig-core/src/embeddings/embedding.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/rig-core/src/embeddings/embedding.rs b/rig-core/src/embeddings/embedding.rs index a7d5f89f..7c8877d9 100644 --- a/rig-core/src/embeddings/embedding.rs +++ b/rig-core/src/embeddings/embedding.rs @@ -42,17 +42,17 @@ pub trait EmbeddingModel: Clone + Sync + Send { /// Embed multiple text documents in a single request fn embed_texts( &self, - documents: impl IntoIterator<Item = String> + Send, + texts: impl IntoIterator<Item = String> + Send, ) -> impl std::future::Future<Output = Result<Vec<Embedding>, EmbeddingError>> + Send; /// Embed a single text document. fn embed_text( &self, - document: &str, + text: &str, ) -> impl std::future::Future<Output = Result<Embedding, EmbeddingError>> + Send { async { Ok(self - .embed_texts(vec![document.to_string()]) + .embed_texts(vec![text.to_string()]) .await? .pop() .expect("There should be at least one embedding")) From bfc3291d599ae47e4822d563d615a235dbf6cec7 Mon Sep 17 00:00:00 2001 From: Christophe <cvauclair@protonmail.com> Date: Fri, 29 Nov 2024 10:55:16 -0500 Subject: [PATCH 86/91] docs: Improve docstrings --- rig-core/src/embeddings/builder.rs | 11 ++++++----- rig-core/src/embeddings/embed.rs | 26 +++++++++++++------------- rig-core/src/embeddings/tool.rs | 8 +++++--- 3 files changed, 24 insertions(+), 21 deletions(-) diff --git a/rig-core/src/embeddings/builder.rs b/rig-core/src/embeddings/builder.rs index 5a4b63c9..483ccb86 100644 --- a/rig-core/src/embeddings/builder.rs +++ b/rig-core/src/embeddings/builder.rs @@ -7,12 +7,12 @@ use std::{cmp::max, collections::HashMap}; use futures::{stream, StreamExt}; use crate::{ - embeddings::{Embed, EmbedError, Embedding, EmbeddingError, EmbeddingModel, TextEmbedder}, + embeddings::{Embed, EmbedError, Embedding, EmbeddingError, EmbeddingModel, embed::TextEmbedder}, OneOrMany, }; /// Builder for creating a collection of embeddings from a vector of documents of type `T`. -/// Accumulate documents such that they can be embedded in a single batch to limit api calls to the provider. +/// Accumulate documents such that they can be embedded in a single batch to limit api calls to the model provider. /// /// # Example /// ```rust @@ -86,7 +86,7 @@ impl<M: EmbeddingModel, T: Embed> EmbeddingsBuilder<M, T> { } } - /// Add a document that implements `Embed` to the builder. + /// Add a document to be embedded to the builder. `document` must implement the [Embed] trait. pub fn document(mut self, document: T) -> Result<Self, EmbedError> { let mut embedder = TextEmbedder::default(); document.embed(&mut embedder)?; @@ -96,7 +96,8 @@ impl<M: EmbeddingModel, T: Embed> EmbeddingsBuilder<M, T> { Ok(self) } - /// Add many documents that implement `Embed` to the builder. + /// Add multiple documents to be embedded to the builder. `documents` must be iteratable + /// with items that implement the [Embed] trait. 
pub fn documents(self, documents: impl IntoIterator<Item = T>) -> Result<Self, EmbedError> { let builder = documents .into_iter() @@ -168,7 +169,7 @@ impl<M: EmbeddingModel, T: Embed + Send> EmbeddingsBuilder<M, T> { #[cfg(test)] mod tests { use crate::{ - embeddings::{embed::EmbedError, Embedding, EmbeddingModel, TextEmbedder}, + embeddings::{embed::EmbedError, Embedding, EmbeddingModel, embed::TextEmbedder}, Embed, }; diff --git a/rig-core/src/embeddings/embed.rs b/rig-core/src/embeddings/embed.rs index 480a2930..659c38f8 100644 --- a/rig-core/src/embeddings/embed.rs +++ b/rig-core/src/embeddings/embed.rs @@ -1,16 +1,16 @@ //! The module defines the [Embed] trait, which must be implemented for types //! that can be embedded by the [crate::embeddings::EmbeddingsBuilder]. //! -//! The module also defines the [EmbedError] struct which is used for when the `embed` -//! method of the `Embed` trait fails. +//! The module also defines the [EmbedError] struct which is used for when the [Embed::embed] +//! method of the [Embed] trait fails. //! //! The module also defines the [TextEmbedder] struct which accumulates string values that need to be embedded. -//! It is used directly with the `Embed` trait. +//! It is used directly with the [Embed] trait. //! //! Finally, the module implements [Embed] for many common primitive types. -/// Error type used for when the `embed` method fo the `Embed` trait fails. -/// Used by default implementations of `Embed` for common types. +/// Error type used for when the [Embed::embed] method fo the [Embed] trait fails. +/// Used by default implementations of [Embed] for common types. #[derive(Debug, thiserror::Error)] #[error("{0}")] pub struct EmbedError(#[from] Box<dyn std::error::Error + Send + Sync>); @@ -22,8 +22,8 @@ impl EmbedError { } /// Derive this trait for objects that need to be converted to vector embeddings. -/// The `embed` method accumulates string values that need to be embedded by adding them to the `TextEmbedder`. -/// If an error occurs, the method should return `EmbedError`. +/// The [Embed::embed] method accumulates string values that need to be embedded by adding them to the [TextEmbedder]. +/// If an error occurs, the method should return [EmbedError]. /// # Example /// ```rust /// use std::env; @@ -41,7 +41,7 @@ impl EmbedError { /// fn embed(&self, embedder: &mut TextEmbedder) -> Result<(), EmbedError> { /// // Embeddings only need to be generated for `definition` field. /// // Split the definitions by comma and collect them into a vector of strings. -/// // That way, different embeddings can be generated for each definition in the definitions string. +/// // That way, different embeddings can be generated for each definition in the `definitions` string. /// self.definitions /// .split(",") /// .for_each(|s| { @@ -65,21 +65,21 @@ pub trait Embed { } /// Accumulates string values that need to be embedded. -/// Used by the `Embed` trait. +/// Used by the [Embed] trait. #[derive(Default)] pub struct TextEmbedder { pub(crate) texts: Vec<String>, } impl TextEmbedder { - /// Adds input `text` string to the list of texts in the `TextEmbedder` that need to be embedded. - pub fn embed(&mut self, text: String) { + /// Adds input `text` string to the list of texts in the [TextEmbedder] that need to be embedded. + pub(crate) fn embed(&mut self, text: String) { self.texts.push(text); } } -/// Client-side function to convert an object that implements the `Embed` trait to a vector of strings. 
-/// Similar to `serde`'s `serde_json::to_string()` function +/// Utility function that returns a vector of strings that need to be embedded for a +/// given object that implements the [Embed] trait. pub fn to_texts(item: impl Embed) -> Result<Vec<String>, EmbedError> { let mut embedder = TextEmbedder::default(); item.embed(&mut embedder)?; diff --git a/rig-core/src/embeddings/tool.rs b/rig-core/src/embeddings/tool.rs index 0b45c66e..74f7e9b6 100644 --- a/rig-core/src/embeddings/tool.rs +++ b/rig-core/src/embeddings/tool.rs @@ -1,11 +1,12 @@ -//! The module defines the [ToolSchema] struct, which is used to embed an object that implements [crate::tool::Tool] +//! The module defines the [ToolSchema] struct, which is used to embed an object that implements [crate::tool::ToolEmbedding] use crate::{tool::ToolEmbeddingDyn, Embed}; use serde::Serialize; use super::embed::EmbedError; -/// Used by EmbeddingsBuilder to embed anything that implements ToolEmbedding. +/// Embeddable document that is used as an intermediate representation of a tool when +/// RAGging tools. #[derive(Clone, Serialize, Default, Eq, PartialEq)] pub struct ToolSchema { pub name: String, @@ -23,7 +24,8 @@ impl Embed for ToolSchema { } impl ToolSchema { - /// Convert item that implements ToolEmbeddingDyn to an ToolSchema. + /// Convert item that implements [ToolEmbeddingDyn] to an [ToolSchema]. + /// /// # Example /// ```rust /// use rig::{ From c21f5499d9eac6012fa1e489d2290ee69a012bc1 Mon Sep 17 00:00:00 2001 From: Christophe <cvauclair@protonmail.com> Date: Fri, 29 Nov 2024 10:56:39 -0500 Subject: [PATCH 87/91] style: fmt --- rig-core/src/embeddings/builder.rs | 8 +++++--- rig-core/src/embeddings/tool.rs | 2 +- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/rig-core/src/embeddings/builder.rs b/rig-core/src/embeddings/builder.rs index 483ccb86..d427d35a 100644 --- a/rig-core/src/embeddings/builder.rs +++ b/rig-core/src/embeddings/builder.rs @@ -7,7 +7,9 @@ use std::{cmp::max, collections::HashMap}; use futures::{stream, StreamExt}; use crate::{ - embeddings::{Embed, EmbedError, Embedding, EmbeddingError, EmbeddingModel, embed::TextEmbedder}, + embeddings::{ + embed::TextEmbedder, Embed, EmbedError, Embedding, EmbeddingError, EmbeddingModel, + }, OneOrMany, }; @@ -96,7 +98,7 @@ impl<M: EmbeddingModel, T: Embed> EmbeddingsBuilder<M, T> { Ok(self) } - /// Add multiple documents to be embedded to the builder. `documents` must be iteratable + /// Add multiple documents to be embedded to the builder. `documents` must be iteratable /// with items that implement the [Embed] trait. pub fn documents(self, documents: impl IntoIterator<Item = T>) -> Result<Self, EmbedError> { let builder = documents @@ -169,7 +171,7 @@ impl<M: EmbeddingModel, T: Embed + Send> EmbeddingsBuilder<M, T> { #[cfg(test)] mod tests { use crate::{ - embeddings::{embed::EmbedError, Embedding, EmbeddingModel, embed::TextEmbedder}, + embeddings::{embed::EmbedError, embed::TextEmbedder, Embedding, EmbeddingModel}, Embed, }; diff --git a/rig-core/src/embeddings/tool.rs b/rig-core/src/embeddings/tool.rs index 74f7e9b6..a8441a23 100644 --- a/rig-core/src/embeddings/tool.rs +++ b/rig-core/src/embeddings/tool.rs @@ -25,7 +25,7 @@ impl Embed for ToolSchema { impl ToolSchema { /// Convert item that implements [ToolEmbeddingDyn] to an [ToolSchema]. 
- /// + /// /// # Example /// ```rust /// use rig::{ From 1637c8a7aba61a826e6ecee360df422689fd83d0 Mon Sep 17 00:00:00 2001 From: Christophe <cvauclair@protonmail.com> Date: Fri, 29 Nov 2024 11:06:12 -0500 Subject: [PATCH 88/91] fix: `TextEmbedder::embed` visibility --- rig-core/src/embeddings/embed.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rig-core/src/embeddings/embed.rs b/rig-core/src/embeddings/embed.rs index 659c38f8..e8acc613 100644 --- a/rig-core/src/embeddings/embed.rs +++ b/rig-core/src/embeddings/embed.rs @@ -73,7 +73,7 @@ pub struct TextEmbedder { impl TextEmbedder { /// Adds input `text` string to the list of texts in the [TextEmbedder] that need to be embedded. - pub(crate) fn embed(&mut self, text: String) { + pub fn embed(&mut self, text: String) { self.texts.push(text); } } From ca4d9dd5e2c957ecb68d6e7b83d2cdb1ca1aacd1 Mon Sep 17 00:00:00 2001 From: Christophe <cvauclair@protonmail.com> Date: Fri, 29 Nov 2024 11:18:31 -0500 Subject: [PATCH 89/91] docs: Simplified the `EmbeddingsBuilder` docstring example to focus on the builder --- rig-core/src/embeddings/builder.rs | 49 +++++++----------------------- 1 file changed, 11 insertions(+), 38 deletions(-) diff --git a/rig-core/src/embeddings/builder.rs b/rig-core/src/embeddings/builder.rs index d427d35a..92760c13 100644 --- a/rig-core/src/embeddings/builder.rs +++ b/rig-core/src/embeddings/builder.rs @@ -13,8 +13,11 @@ use crate::{ OneOrMany, }; -/// Builder for creating a collection of embeddings from a vector of documents of type `T`. -/// Accumulate documents such that they can be embedded in a single batch to limit api calls to the model provider. +/// Builder for creating embeddings from one or more documents of type `T`. +/// Note: `T` can be any type that implements the [Embed] trait. +/// +/// Using the builder is preferred over using [EmbeddingModel::embed_text] directly as +/// it will batch the documents in a single request to the model provider. /// /// # Example /// ```rust @@ -23,21 +26,9 @@ use crate::{ /// use rig::{ /// embeddings::EmbeddingsBuilder, /// providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, -/// vector_store::{in_memory_store::InMemoryVectorStore, VectorStoreIndex}, -/// Embed, /// }; /// use serde::{Deserialize, Serialize}; /// -/// // Shape of data that needs to be RAG'ed. -/// // The definition field will be used to generate embeddings. 
-/// #[derive(Embed, Clone, Deserialize, Debug, Serialize, Eq, PartialEq, Default)] -/// struct WordDefinition { -/// id: String, -/// word: String, -/// #[embed] -/// definitions: Vec<String>, -/// } -/// /// // Create OpenAI client /// let openai_api_key = env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set"); /// let openai_client = Client::new(&openai_api_key); @@ -46,30 +37,12 @@ use crate::{ /// /// let embeddings = EmbeddingsBuilder::new(model.clone()) /// .documents(vec![ -/// WordDefinition { -/// id: "doc0".to_string(), -/// word: "flurbo".to_string(), -/// definitions: vec![ -/// "A green alien that lives on cold planets.".to_string(), -/// "A fictional digital currency that originated in the animated series Rick and Morty.".to_string() -/// ] -/// }, -/// WordDefinition { -/// id: "doc1".to_string(), -/// word: "glarb-glarb".to_string(), -/// definitions: vec![ -/// "An ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string(), -/// "A fictional creature found in the distant, swampy marshlands of the planet Glibbo in the Andromeda galaxy.".to_string() -/// ] -/// }, -/// WordDefinition { -/// id: "doc2".to_string(), -/// word: "linglingdong".to_string(), -/// definitions: vec![ -/// "A term used by inhabitants of the sombrero galaxy to describe humans.".to_string(), -/// "A rare, mystical instrument crafted by the ancient monks of the Nebulon Mountain Ranges on the planet Quarm.".to_string() -/// ] -/// }, +/// "1. *flurbo* (noun): A green alien that lives on cold planets.".to_string(), +/// "2. *flurbo* (noun): A fictional digital currency that originated in the animated series Rick and Morty.".to_string() +/// "1. *glarb-glarb* (noun): An ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string(), +/// "2. *glarb-glarb* (noun): A fictional creature found in the distant, swampy marshlands of the planet Glibbo in the Andromeda galaxy.".to_string() +/// "1. *linlingdong* (noun): A term used by inhabitants of the sombrero galaxy to describe humans.".to_string(), +/// "2. *linlingdong* (noun): A rare, mystical instrument crafted by the ancient monks of the Nebulon Mountain Ranges on the planet Quarm.".to_string() /// ])? /// .build() /// .await?; From f760a5fd31ac0360748e31817baa07962bba0f57 Mon Sep 17 00:00:00 2001 From: Christophe <cvauclair@protonmail.com> Date: Fri, 29 Nov 2024 11:25:09 -0500 Subject: [PATCH 90/91] style: cargo fmt --- rig-core/src/embeddings/builder.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rig-core/src/embeddings/builder.rs b/rig-core/src/embeddings/builder.rs index 92760c13..f9e80779 100644 --- a/rig-core/src/embeddings/builder.rs +++ b/rig-core/src/embeddings/builder.rs @@ -16,7 +16,7 @@ use crate::{ /// Builder for creating embeddings from one or more documents of type `T`. /// Note: `T` can be any type that implements the [Embed] trait. /// -/// Using the builder is preferred over using [EmbeddingModel::embed_text] directly as +/// Using the builder is preferred over using [EmbeddingModel::embed_text] directly as /// it will batch the documents in a single request to the model provider. 
/// /// # Example From f21da45a707738e7b2bb9641bde254f4dd86f5bf Mon Sep 17 00:00:00 2001 From: Christophe <cvauclair@protonmail.com> Date: Fri, 29 Nov 2024 11:25:26 -0500 Subject: [PATCH 91/91] docs: Small edit to lancedb examples --- rig-lancedb/examples/fixtures/lib.rs | 2 +- rig-lancedb/examples/vector_search_local_ann.rs | 4 ++-- rig-lancedb/examples/vector_search_local_enn.rs | 4 ++-- rig-lancedb/examples/vector_search_s3_ann.rs | 4 ++-- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/rig-lancedb/examples/fixtures/lib.rs b/rig-lancedb/examples/fixtures/lib.rs index b12156fb..94704822 100644 --- a/rig-lancedb/examples/fixtures/lib.rs +++ b/rig-lancedb/examples/fixtures/lib.rs @@ -13,7 +13,7 @@ pub struct WordDefinition { pub definition: String, } -pub fn fake_definitions() -> Vec<WordDefinition> { +pub fn word_definitions() -> Vec<WordDefinition> { vec![ WordDefinition { id: "doc0".to_string(), diff --git a/rig-lancedb/examples/vector_search_local_ann.rs b/rig-lancedb/examples/vector_search_local_ann.rs index 0b75f080..03636089 100644 --- a/rig-lancedb/examples/vector_search_local_ann.rs +++ b/rig-lancedb/examples/vector_search_local_ann.rs @@ -1,7 +1,7 @@ use std::{env, sync::Arc}; use arrow_array::RecordBatchIterator; -use fixture::{as_record_batch, fake_definitions, schema, WordDefinition}; +use fixture::{as_record_batch, schema, word_definitions, WordDefinition}; use lancedb::index::vector::IvfPqIndexBuilder; use rig::{ embeddings::{EmbeddingModel, EmbeddingsBuilder}, @@ -27,7 +27,7 @@ async fn main() -> Result<(), anyhow::Error> { // Generate embeddings for the test data. let embeddings = EmbeddingsBuilder::new(model.clone()) - .documents(fake_definitions())? + .documents(word_definitions())? // Note: need at least 256 rows in order to create an index so copy the definition 256 times for testing purposes. .documents( (0..256) diff --git a/rig-lancedb/examples/vector_search_local_enn.rs b/rig-lancedb/examples/vector_search_local_enn.rs index 859442be..0244d33e 100644 --- a/rig-lancedb/examples/vector_search_local_enn.rs +++ b/rig-lancedb/examples/vector_search_local_enn.rs @@ -1,7 +1,7 @@ use std::{env, sync::Arc}; use arrow_array::RecordBatchIterator; -use fixture::{as_record_batch, fake_definitions, schema}; +use fixture::{as_record_batch, schema, word_definitions}; use rig::{ embeddings::{EmbeddingModel, EmbeddingsBuilder}, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, @@ -23,7 +23,7 @@ async fn main() -> Result<(), anyhow::Error> { // Generate embeddings for the test data. let embeddings = EmbeddingsBuilder::new(model.clone()) - .documents(fake_definitions())? + .documents(word_definitions())? .build() .await?; diff --git a/rig-lancedb/examples/vector_search_s3_ann.rs b/rig-lancedb/examples/vector_search_s3_ann.rs index 8aca722b..f296d1d7 100644 --- a/rig-lancedb/examples/vector_search_s3_ann.rs +++ b/rig-lancedb/examples/vector_search_s3_ann.rs @@ -1,7 +1,7 @@ use std::{env, sync::Arc}; use arrow_array::RecordBatchIterator; -use fixture::{as_record_batch, fake_definitions, schema, WordDefinition}; +use fixture::{as_record_batch, schema, word_definitions, WordDefinition}; use lancedb::{index::vector::IvfPqIndexBuilder, DistanceType}; use rig::{ embeddings::{EmbeddingModel, EmbeddingsBuilder}, @@ -33,7 +33,7 @@ async fn main() -> Result<(), anyhow::Error> { // Generate embeddings for the test data. let embeddings = EmbeddingsBuilder::new(model.clone()) - .documents(fake_definitions())? + .documents(word_definitions())? 
// Note: need at least 256 rows in order to create an index so copy the definition 256 times for testing purposes. .documents( (0..256)