From b7d146ff57f3ac2ac279834dab432c993e8baf7b Mon Sep 17 00:00:00 2001 From: Garance Date: Wed, 2 Oct 2024 18:07:02 -0400 Subject: [PATCH 01/47] feat: setup derive macro --- Cargo.lock | 53 ++++++++++++++++--------- Cargo.toml | 2 +- rig-macros/Cargo.toml | 8 ++++ rig-macros/rig-macros-derive/Cargo.toml | 11 +++++ rig-macros/rig-macros-derive/src/lib.rs | 52 ++++++++++++++++++++++++ rig-macros/src/lib.rs | 34 ++++++++++++++++ 6 files changed, 141 insertions(+), 19 deletions(-) create mode 100644 rig-macros/Cargo.toml create mode 100644 rig-macros/rig-macros-derive/Cargo.toml create mode 100644 rig-macros/rig-macros-derive/src/lib.rs create mode 100644 rig-macros/src/lib.rs diff --git a/Cargo.lock b/Cargo.lock index 80967e9d..2da34761 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -59,7 +59,7 @@ checksum = "c6fa2087f2753a7da8cc1c0dbfcf89579dd57458e36769de5ac750b4671737ca" dependencies = [ "proc-macro2", "quote", - "syn 2.0.65", + "syn 2.0.79", ] [[package]] @@ -458,7 +458,7 @@ checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac" dependencies = [ "proc-macro2", "quote", - "syn 2.0.65", + "syn 2.0.79", ] [[package]] @@ -986,7 +986,7 @@ checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" dependencies = [ "proc-macro2", "quote", - "syn 2.0.65", + "syn 2.0.79", ] [[package]] @@ -1107,9 +1107,9 @@ checksum = "a1d01941d82fa2ab50be1e79e6714289dd7cde78eba4c074bc5a4374f650dfe0" [[package]] name = "quote" -version = "1.0.36" +version = "1.0.37" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0fa76aaf39101c457836aec0ce2316dbdc3ab723cdda1c6bd4e6ad4208acaca7" +checksum = "b5b9d34b8991d19d98081b46eacdd8eb58c6f2b201139f7c5f643cc155a633af" dependencies = [ "proc-macro2", ] @@ -1226,6 +1226,22 @@ dependencies = [ "tracing-subscriber", ] +[[package]] +name = "rig-macros" +version = "0.1.0" +dependencies = [ + "rig-macros-derive", + "serde_json", +] + +[[package]] +name = "rig-macros-derive" +version = 
"0.1.0" +dependencies = [ + "quote", + "syn 2.0.79", +] + [[package]] name = "rig-mongodb" version = "0.1.2" @@ -1369,7 +1385,7 @@ dependencies = [ "proc-macro2", "quote", "serde_derive_internals", - "syn 2.0.65", + "syn 2.0.79", ] [[package]] @@ -1458,7 +1474,7 @@ checksum = "500cbc0ebeb6f46627f50f3f5811ccf6bf00643be300b4c3eabc0ef55dc5b5ba" dependencies = [ "proc-macro2", "quote", - "syn 2.0.65", + "syn 2.0.79", ] [[package]] @@ -1469,17 +1485,18 @@ checksum = "18d26a20a969b9e3fdf2fc2d9f21eda6c40e2de84c9408bb5d3b05d499aae711" dependencies = [ "proc-macro2", "quote", - "syn 2.0.65", + "syn 2.0.79", ] [[package]] name = "serde_json" -version = "1.0.117" +version = "1.0.128" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "455182ea6142b14f93f4bc5320a2b31c1f266b66a4a5c858b013302a5d8cbfc3" +checksum = "6ff5456707a1de34e7e37f2a6fd3d3f808c318259cbd01ab6377795054b483d8" dependencies = [ "indexmap", "itoa", + "memchr", "ryu", "serde", ] @@ -1635,9 +1652,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.65" +version = "2.0.79" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d2863d96a84c6439701d7a38f9de935ec562c8832cc55d1dde0f513b52fad106" +checksum = "89132cd0bf050864e1d38dc3bbc07a0eb8e7530af26344d3d2bbbef83499f590" dependencies = [ "proc-macro2", "quote", @@ -1712,7 +1729,7 @@ checksum = "46c3384250002a6d5af4d114f2845d37b57521033f30d5c3f46c4d70e1197533" dependencies = [ "proc-macro2", "quote", - "syn 2.0.65", + "syn 2.0.79", ] [[package]] @@ -1798,7 +1815,7 @@ checksum = "5f5ae998a069d4b5aba8ee9dad856af7d520c3699e6159b185c2acd48155d39a" dependencies = [ "proc-macro2", "quote", - "syn 2.0.65", + "syn 2.0.79", ] [[package]] @@ -1860,7 +1877,7 @@ checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.65", + "syn 2.0.79", ] [[package]] @@ -2068,7 +2085,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn 2.0.65", + "syn 
2.0.79", "wasm-bindgen-shared", ] @@ -2102,7 +2119,7 @@ checksum = "e94f17b526d0a461a191c78ea52bbce64071ed5c04c9ffe424dcb38f74171bb7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.65", + "syn 2.0.79", "wasm-bindgen-backend", "wasm-bindgen-shared", ] @@ -2341,5 +2358,5 @@ checksum = "15e934569e47891f7d9411f1a451d947a60e000ab3bd24fbb970f000387d1b3b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.65", + "syn 2.0.79", ] diff --git a/Cargo.toml b/Cargo.toml index 2501b86f..d3a0c372 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [workspace] resolver = "2" members = [ - "rig-core", + "rig-core", "rig-macros", "rig-macros/rig-macros-derive", "rig-mongodb", ] diff --git a/rig-macros/Cargo.toml b/rig-macros/Cargo.toml new file mode 100644 index 00000000..729582f1 --- /dev/null +++ b/rig-macros/Cargo.toml @@ -0,0 +1,8 @@ +[package] +name = "rig-macros" +version = "0.1.0" +edition = "2021" + +[dependencies] +serde_json = "1.0.128" +rig-macros-derive = { path = "./rig-macros-derive" } diff --git a/rig-macros/rig-macros-derive/Cargo.toml b/rig-macros/rig-macros-derive/Cargo.toml new file mode 100644 index 00000000..51807c28 --- /dev/null +++ b/rig-macros/rig-macros-derive/Cargo.toml @@ -0,0 +1,11 @@ +[package] +name = "rig-macros-derive" +version = "0.1.0" +edition = "2021" + +[dependencies] +quote = "1.0.37" +syn = "2.0.79" + +[lib] +proc-macro = true diff --git a/rig-macros/rig-macros-derive/src/lib.rs b/rig-macros/rig-macros-derive/src/lib.rs new file mode 100644 index 00000000..356546a8 --- /dev/null +++ b/rig-macros/rig-macros-derive/src/lib.rs @@ -0,0 +1,52 @@ +extern crate proc_macro; +use proc_macro::TokenStream; +use quote::quote; +use syn::{Attribute, Meta}; + +// https://doc.rust-lang.org/book/ch19-06-macros.html#how-to-write-a-custom-derive-macro +// https://doc.rust-lang.org/reference/procedural-macros.html + +#[proc_macro_derive(Embedding, attributes(embed))] +pub fn derive_embed_trait(item: TokenStream) -> TokenStream { + let ast = 
syn::parse(item).unwrap(); + + impl_embeddable_macro(&ast) +} + +fn impl_embeddable_macro(ast: &syn::DeriveInput) -> TokenStream { + let name = &ast.ident; + + match &ast.data { + syn::Data::Struct(data_struct) => { + let field_to_embed = data_struct.fields.clone().into_iter().find(|field| { + field + .attrs + .clone() + .into_iter() + .find(|attribute| match attribute { + Attribute { + meta: Meta::Path(path), + .. + } => match path.get_ident() { + Some(attribute_name) => attribute_name == "embed", + None => false, + }, + _ => return false, + }) + .is_some() + }); + } + _ => {} + }; + + let gen = quote! { + impl Embeddable for #name { + type Kind = String; + + fn embeddable(&self) { + println!("{}", stringify!(#name)); + } + } + }; + gen.into() +} diff --git a/rig-macros/src/lib.rs b/rig-macros/src/lib.rs new file mode 100644 index 00000000..0512c177 --- /dev/null +++ b/rig-macros/src/lib.rs @@ -0,0 +1,34 @@ +enum Kind { + Single, + Many, +} + +trait Embeddable { + type Kind; + fn embeddable(&self); +} + +#[cfg(test)] +mod tests { + use super::Embeddable; + use rig_macros_derive::Embedding; + + #[derive(Embedding)] + struct MyStruct { + id: String, + #[embed] + name: String, + } + + #[test] + fn test_macro() { + let my_struct = MyStruct { + id: "1".to_string(), + name: "John".to_string(), + }; + + my_struct.embeddable(); + + assert!(false) + } +} From 5904734d4fe0f51a324fdea1465fee3683e26add Mon Sep 17 00:00:00 2001 From: Garance Date: Thu, 3 Oct 2024 16:13:16 -0400 Subject: [PATCH 02/47] test: test out writing embeddable macro --- rig-macros/rig-macros-derive/src/lib.rs | 61 ++++++---- rig-macros/src/lib.rs | 27 ++++- rig-macros/test.rs | 150 ++++++++++++++++++++++++ 3 files changed, 207 insertions(+), 31 deletions(-) create mode 100644 rig-macros/test.rs diff --git a/rig-macros/rig-macros-derive/src/lib.rs b/rig-macros/rig-macros-derive/src/lib.rs index 356546a8..2cf7f251 100644 --- a/rig-macros/rig-macros-derive/src/lib.rs +++ 
b/rig-macros/rig-macros-derive/src/lib.rs @@ -1,50 +1,59 @@ extern crate proc_macro; use proc_macro::TokenStream; use quote::quote; -use syn::{Attribute, Meta}; +use syn::{parse_macro_input, Attribute, DeriveInput, Meta}; // https://doc.rust-lang.org/book/ch19-06-macros.html#how-to-write-a-custom-derive-macro // https://doc.rust-lang.org/reference/procedural-macros.html #[proc_macro_derive(Embedding, attributes(embed))] pub fn derive_embed_trait(item: TokenStream) -> TokenStream { - let ast = syn::parse(item).unwrap(); + let input = parse_macro_input!(item as DeriveInput); - impl_embeddable_macro(&ast) + impl_embeddable_macro(&input) } -fn impl_embeddable_macro(ast: &syn::DeriveInput) -> TokenStream { - let name = &ast.ident; +fn impl_embeddable_macro(input: &syn::DeriveInput) -> TokenStream { + let name = &input.ident; - match &ast.data { + let embeddings = match &input.data { syn::Data::Struct(data_struct) => { - let field_to_embed = data_struct.fields.clone().into_iter().find(|field| { - field - .attrs - .clone() - .into_iter() - .find(|attribute| match attribute { - Attribute { - meta: Meta::Path(path), - .. - } => match path.get_ident() { - Some(attribute_name) => attribute_name == "embed", - None => false, - }, - _ => return false, - }) - .is_some() - }); + data_struct.fields.clone().into_iter().filter(|field| { + field + .attrs + .clone() + .into_iter() + .any(|attribute| match attribute { + Attribute { + meta: Meta::Path(path), + .. + } => match path.get_ident() { + Some(attribute_name) => attribute_name == "embed", + None => false + } + _ => false, + }) + }).map(|field| { + let field_name = field.ident.expect(""); + + quote! { + self.#field_name.embeddable() + } + + }).collect::>() } - _ => {} + _ => vec![] }; let gen = quote! 
{ impl Embeddable for #name { type Kind = String; - fn embeddable(&self) { - println!("{}", stringify!(#name)); + fn embeddable(&self) -> Vec { + vec![ + #(#embeddings),* + ].into_iter().flatten().collect() + } } }; diff --git a/rig-macros/src/lib.rs b/rig-macros/src/lib.rs index 0512c177..fedf2cde 100644 --- a/rig-macros/src/lib.rs +++ b/rig-macros/src/lib.rs @@ -5,17 +5,34 @@ enum Kind { trait Embeddable { type Kind; - fn embeddable(&self); + fn embeddable(&self) -> Vec; } #[cfg(test)] mod tests { - use super::Embeddable; + use super::{Embeddable, Kind}; use rig_macros_derive::Embedding; + impl Embeddable for usize { + type Kind = Kind; + + fn embeddable(&self) -> Vec { + vec![self.to_string()] + } + } + + impl Embeddable for String { + type Kind = Kind; + + fn embeddable(&self) -> Vec { + vec![self.clone()] + } + } + #[derive(Embedding)] struct MyStruct { - id: String, + #[embed] + id: usize, #[embed] name: String, } @@ -23,11 +40,11 @@ mod tests { #[test] fn test_macro() { let my_struct = MyStruct { - id: "1".to_string(), + id: 1, name: "John".to_string(), }; - my_struct.embeddable(); + println!("{:?}", my_struct.embeddable()); assert!(false) } diff --git a/rig-macros/test.rs b/rig-macros/test.rs new file mode 100644 index 00000000..4f9e80e9 --- /dev/null +++ b/rig-macros/test.rs @@ -0,0 +1,150 @@ +/// Builder for creating a collection of embeddings +pub struct EmbeddingsBuilder { + model: M, + documents: Vec<(T, Vec)>, +} + +trait Embeddable { + // Return list of strings that need to be embedded. 
+ // Instead of Vec, should be Vec + fn embeddable(&self) -> Vec; +} + +type EmbeddingVector = Vec; + +impl EmbeddingsBuilder { + /// Create a new embedding builder with the given embedding model + pub fn new(model: M) -> Self { + Self { + model, + documents: vec![], + } + } + + pub fn add( + mut self, + document: T, + ) -> Self { + let embed_documents: Vec = document.embeddable(); + + self.documents.push(( + document, + embed_documents, + )); + self + } + + pub fn build(&self) -> Result)>, EmbeddingError> { + self.documents.iter().map(|(doc, values_to_embed)| { + values_to_embed.iter().map(|value| { + let value_str = serde_json::to_string(value)?; + generate_embedding(value_str) + }) + }) + } + + pub fn build_simple(&self) -> Result, EmbeddingError> { + self.documents.iter().map(|(doc, value_to_embed)| { + let value_str = serde_json::to_string(value_to_embed)?; + generate_embedding(value_str) + }) + } +} + + +// Example +#[derive(Embeddable)] +struct DictionaryEntry { + word: String, + #[embed] + definitions: String, +} + +#[derive(Embeddable)] +struct MetadataEmbedding { + pub id: String, + #[embed(with = serde_json::to_value)] + pub content: CategoryMetadata, + pub created: Option>, + pub modified: Option>, + pub dataset_ids: Vec, +} + +#[derive(serde::Serialize)] +struct CategoryMetadata { + pub name: String, + pub description: String, + pub tags: Vec, + pub links: Vec, +} + +// Inside macro: +impl Embeddable for DictionaryEntry { + fn embeddable(&self) -> Vec { + // Find the field tagged with #[embed] and return its value + // If there are no embedding tags, return the entire struct + } +} + +fn main() { + let embeddings: Vec<(DictionaryEntry, Vec)> = EmbeddingsBuilder::new(model.clone()) + .add(DictionaryEntry::new("blah", vec!["definition of blah"])) + .add(DictionaryEntry::new("foo", vec!["definition of foo"])) + .build()?; + + // In relational vector store like LanceDB, need to flatten result (create row for each item in definitions vector): + // Column: 
word (string) + // Column: definition (vector) + + // In document vector store like MongoDB, might need to merge the vector results back with their corresponding definition string: + // Field: word (string) + // Field: definitions + // // Field: definition (string) + // // Field: vector + + Ok(()) +} + + + +// Iterations: +// 1 - Multiple fields to embed? +#[derive(Embedding)] +struct DictionaryEntry { + word: String, + #[embed] + definitions: Vec, + #[embed] + synonyms: Vec +} + +// 2 - Embed recursion? Ex: +#[derive(Embedding)] +struct DictionaryEntry { + word: String, + #[embed] + definitions: Vec, +} +struct Definition { + definition: String, + #[embed] + links: Vec +} + +// { +// word: "blah", +// definitions: [ +// { +// definition: "definition of blah", +// links: ["link1", "link2"] +// }, +// { +// definition: "another definition for blah", +// links: ["link3"] +// } +// ] +// } + +// blah | definition of blah | link1 | embedding for link1 +// blah | definition of blah | link2 | embedding for link2 +// blah | another definition for blah | link3 | embedding for link3 \ No newline at end of file From ee9b5c332b322621395757ef4708e010776a8362 Mon Sep 17 00:00:00 2001 From: Garance Date: Fri, 4 Oct 2024 17:23:39 -0400 Subject: [PATCH 03/47] test: continue testing custom macro implementation --- Cargo.lock | 11 +- rig-macros/rig-macros-derive/Cargo.toml | 3 +- rig-macros/rig-macros-derive/src/lib.rs | 129 ++++++++++++++++++------ rig-macros/src/lib.rs | 6 +- rig-macros/test.rs | 12 +-- 5 files changed, 119 insertions(+), 42 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 2da34761..e62e3101 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -712,6 +712,12 @@ dependencies = [ "hashbrown", ] +[[package]] +name = "indoc" +version = "2.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b248f5224d1d606005e02c97f5aa4e88eeb230488bcc03bc9ca4d7991399f2b5" + [[package]] name = "ipconfig" version = "0.3.2" @@ -1092,9 +1098,9 @@ checksum = 
"5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" [[package]] name = "proc-macro2" -version = "1.0.83" +version = "1.0.86" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0b33eb56c327dec362a9e55b3ad14f9d2f0904fb5a5b03b513ab5465399e9f43" +checksum = "5e719e8df665df0d1c8fbfd238015744736151d4445ec0836b8e628aae103b77" dependencies = [ "unicode-ident", ] @@ -1238,6 +1244,7 @@ dependencies = [ name = "rig-macros-derive" version = "0.1.0" dependencies = [ + "indoc", "quote", "syn 2.0.79", ] diff --git a/rig-macros/rig-macros-derive/Cargo.toml b/rig-macros/rig-macros-derive/Cargo.toml index 51807c28..f2f22f6d 100644 --- a/rig-macros/rig-macros-derive/Cargo.toml +++ b/rig-macros/rig-macros-derive/Cargo.toml @@ -4,8 +4,9 @@ version = "0.1.0" edition = "2021" [dependencies] +indoc = "2.0.5" quote = "1.0.37" -syn = "2.0.79" +syn = { version = "2.0.79", features = ["full"]} [lib] proc-macro = true diff --git a/rig-macros/rig-macros-derive/src/lib.rs b/rig-macros/rig-macros-derive/src/lib.rs index 2cf7f251..f9a5bd57 100644 --- a/rig-macros/rig-macros-derive/src/lib.rs +++ b/rig-macros/rig-macros-derive/src/lib.rs @@ -1,11 +1,17 @@ extern crate proc_macro; +use indoc::indoc; use proc_macro::TokenStream; -use quote::quote; -use syn::{parse_macro_input, Attribute, DeriveInput, Meta}; +use quote::{quote, ToTokens}; +use syn::{ + meta::ParseNestedMeta, parse_macro_input, spanned::Spanned, Attribute, DataStruct, DeriveInput, Meta, ExprPath +}; // https://doc.rust-lang.org/book/ch19-06-macros.html#how-to-write-a-custom-derive-macro // https://doc.rust-lang.org/reference/procedural-macros.html +const EMBED: &str = "embed"; +const EMBED_WITH: &str = "embed_with"; + #[proc_macro_derive(Embedding, attributes(embed))] pub fn derive_embed_trait(item: TokenStream) -> TokenStream { let input = parse_macro_input!(item as DeriveInput); @@ -18,44 +24,105 @@ fn impl_embeddable_macro(input: &syn::DeriveInput) -> TokenStream { let embeddings = match 
&input.data { syn::Data::Struct(data_struct) => { - data_struct.fields.clone().into_iter().filter(|field| { - field - .attrs - .clone() - .into_iter() - .any(|attribute| match attribute { - Attribute { - meta: Meta::Path(path), - .. - } => match path.get_ident() { - Some(attribute_name) => attribute_name == "embed", - None => false - } - _ => false, - }) - }).map(|field| { - let field_name = field.ident.expect(""); - - quote! { - self.#field_name.embeddable() - } - - }).collect::>() + // let invoke_trait = invoke_trait(data_struct) + // .map(|field_name| { + // quote! { + // self.#field_name.embeddable() + // } + // }) + // .collect::>(); + custom_trait_implementation(data_struct) } - _ => vec![] - }; + _ => Ok(false), + } + .unwrap(); let gen = quote! { impl Embeddable for #name { type Kind = String; fn embeddable(&self) -> Vec { - vec![ - #(#embeddings),* - ].into_iter().flatten().collect() - + // vec![ + // #(#embeddings),* + // ].into_iter().flatten().collect() + println!("{}", #embeddings); + vec![] } } }; gen.into() } + +fn custom_trait_implementation(data_struct: &DataStruct) -> Result { + let t = data_struct + .fields + .clone() + .into_iter() + .for_each(|field| { + let _t = field.attrs.clone().into_iter().map(|attr| { + let t = if attr.path().is_ident(EMBED) { + attr.parse_nested_meta(|meta| { + if meta.path.is_ident(EMBED_WITH) { + let path = parse_embed_with(&meta)?; + + let tokens = meta.path.into_token_stream(); + }; + Ok(()) + }) + } else { + todo!() + }; + }).collect::>(); + }); + Ok(false) +} + +fn parse_embed_with(meta: &ParseNestedMeta) -> Result { + // #[embed(embed_with = "...")] + let expr = meta.value().unwrap().parse::().unwrap(); + let mut value = &expr; + while let syn::Expr::Group(e) = value { + value = &e.expr; + } + let string = if let syn::Expr::Lit(syn::ExprLit { + lit: syn::Lit::Str(lit_str), + .. 
+ }) = value + { + let suffix = lit_str.suffix(); + if !suffix.is_empty() { + return Err(syn::Error::new( + lit_str.span(), + format!("unexpected suffix `{}` on string literal", suffix) + )) + } + lit_str.clone() + } else { + return Err(syn::Error::new( + value.span(), + format!("expected {} attribute to be a string: `{} = \"...\"`", EMBED_WITH, EMBED_WITH) + )) + }; + + string.parse() +} + +fn invoke_trait(data_struct: &DataStruct) -> impl Iterator { + data_struct.fields.clone().into_iter().filter_map(|field| { + let found_embed = field + .attrs + .clone() + .into_iter() + .any(|attribute| match attribute { + Attribute { + meta: Meta::Path(path), + .. + } => path.is_ident("embed"), + _ => false, + }); + match found_embed { + true => Some(field.ident.expect("")), + false => None, + } + }) +} diff --git a/rig-macros/src/lib.rs b/rig-macros/src/lib.rs index fedf2cde..628fe6d2 100644 --- a/rig-macros/src/lib.rs +++ b/rig-macros/src/lib.rs @@ -33,7 +33,7 @@ mod tests { struct MyStruct { #[embed] id: usize, - #[embed] + #[embed(embed_with = "something")] name: String, } @@ -44,7 +44,9 @@ mod tests { name: "John".to_string(), }; - println!("{:?}", my_struct.embeddable()); + my_struct.embeddable(); + + // println!("{:?}", my_struct.embeddable()); assert!(false) } diff --git a/rig-macros/test.rs b/rig-macros/test.rs index 4f9e80e9..f230af56 100644 --- a/rig-macros/test.rs +++ b/rig-macros/test.rs @@ -1,18 +1,18 @@ /// Builder for creating a collection of embeddings -pub struct EmbeddingsBuilder { +pub struct EmbeddingsBuilder { model: M, - documents: Vec<(T, Vec)>, + documents: Vec<(T, Vec)>, } -trait Embeddable { +trait Embeddable { // Return list of strings that need to be embedded. 
// Instead of Vec, should be Vec - fn embeddable(&self) -> Vec; + fn embeddable(&self) -> Vec; } type EmbeddingVector = Vec; -impl EmbeddingsBuilder { +impl EmbeddingsBuilder { /// Create a new embedding builder with the given embedding model pub fn new(model: M) -> Self { Self { @@ -25,7 +25,7 @@ impl EmbeddingsBuilder mut self, document: T, ) -> Self { - let embed_documents: Vec = document.embeddable(); + let embed_documents: Vec = document.embeddable(); self.documents.push(( document, From c8c1e9ca0c38bba8fc1388f07377e4526f549963 Mon Sep 17 00:00:00 2001 From: Garance Date: Sat, 5 Oct 2024 23:24:02 -0400 Subject: [PATCH 04/47] feat: macro generate trait bounds --- rig-macros/rig-macros-derive/src/lib.rs | 74 +++++++++++++++++-------- rig-macros/src/lib.rs | 8 --- 2 files changed, 50 insertions(+), 32 deletions(-) diff --git a/rig-macros/rig-macros-derive/src/lib.rs b/rig-macros/rig-macros-derive/src/lib.rs index f9a5bd57..2a6e4aaa 100644 --- a/rig-macros/rig-macros-derive/src/lib.rs +++ b/rig-macros/rig-macros-derive/src/lib.rs @@ -3,7 +3,7 @@ use indoc::indoc; use proc_macro::TokenStream; use quote::{quote, ToTokens}; use syn::{ - meta::ParseNestedMeta, parse_macro_input, spanned::Spanned, Attribute, DataStruct, DeriveInput, Meta, ExprPath + meta::ParseNestedMeta, parse_macro_input, parse_quote, spanned::Spanned, Attribute, DataStruct, DeriveInput, ExprPath, Meta, Path }; // https://doc.rust-lang.org/book/ch19-06-macros.html#how-to-write-a-custom-derive-macro @@ -14,31 +14,39 @@ const EMBED_WITH: &str = "embed_with"; #[proc_macro_derive(Embedding, attributes(embed))] pub fn derive_embed_trait(item: TokenStream) -> TokenStream { - let input = parse_macro_input!(item as DeriveInput); + let mut input = parse_macro_input!(item as DeriveInput); - impl_embeddable_macro(&input) + impl_embeddable_macro(&mut input) } -fn impl_embeddable_macro(input: &syn::DeriveInput) -> TokenStream { +fn impl_embeddable_macro(input: &mut syn::DeriveInput) -> TokenStream { let name = 
&input.ident; let embeddings = match &input.data { syn::Data::Struct(data_struct) => { - // let invoke_trait = invoke_trait(data_struct) + basic_embed_fields(data_struct).for_each(|(_, field_type)| { + add_struct_bounds(&mut input.generics, &field_type) + }); + + // let basic_embed_fields = basic_embed_fields(data_struct) // .map(|field_name| { // quote! { // self.#field_name.embeddable() // } // }) // .collect::>(); - custom_trait_implementation(data_struct) + + let func_names = custom_trait_implementation(data_struct).unwrap(); + + false } - _ => Ok(false), - } - .unwrap(); + _ => false, + }; + + let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl(); let gen = quote! { - impl Embeddable for #name { + impl #impl_generics Embeddable for #name #ty_generics #where_clause { type Kind = String; fn embeddable(&self) -> Vec { @@ -50,31 +58,37 @@ fn impl_embeddable_macro(input: &syn::DeriveInput) -> TokenStream { } } }; + eprintln!("Generated code:\n{}", gen); + gen.into() } -fn custom_trait_implementation(data_struct: &DataStruct) -> Result { - let t = data_struct +fn custom_trait_implementation(data_struct: &DataStruct) -> Result, syn::Error> { + Ok(data_struct .fields .clone() .into_iter() - .for_each(|field| { - let _t = field.attrs.clone().into_iter().map(|attr| { - let t = if attr.path().is_ident(EMBED) { + .map(|field| { + let mut path = None; + field.attrs.clone().into_iter().map(|attr| { + if attr.path().is_ident(EMBED) { attr.parse_nested_meta(|meta| { if meta.path.is_ident(EMBED_WITH) { - let path = parse_embed_with(&meta)?; + path = Some(parse_embed_with(&meta)?); - let tokens = meta.path.into_token_stream(); + // let tokens = meta.path.into_token_stream(); }; Ok(()) }) } else { - todo!() - }; - }).collect::>(); - }); - Ok(false) + Ok(()) + } + }).collect::,_>>()?; + Ok::<_, syn::Error>(path) + }).collect::,_>>()? 
+ .into_iter() + .filter_map(|i| i) + .collect()) } fn parse_embed_with(meta: &ParseNestedMeta) -> Result { @@ -107,7 +121,16 @@ fn parse_embed_with(meta: &ParseNestedMeta) -> Result { string.parse() } -fn invoke_trait(data_struct: &DataStruct) -> impl Iterator { +fn add_struct_bounds(generics: &mut syn::Generics, field_type: &syn::Type) { + let where_clause = generics.make_where_clause(); + + where_clause.predicates.push(parse_quote! { + #field_type: Embeddable + }); +} + + +fn basic_embed_fields(data_struct: &DataStruct) -> impl Iterator { data_struct.fields.clone().into_iter().filter_map(|field| { let found_embed = field .attrs @@ -121,7 +144,10 @@ fn invoke_trait(data_struct: &DataStruct) -> impl Iterator { _ => false, }); match found_embed { - true => Some(field.ident.expect("")), + true => Some(( + field.ident.expect(""), + field.ty + )), false => None, } }) diff --git a/rig-macros/src/lib.rs b/rig-macros/src/lib.rs index 628fe6d2..b57bf150 100644 --- a/rig-macros/src/lib.rs +++ b/rig-macros/src/lib.rs @@ -21,14 +21,6 @@ mod tests { } } - impl Embeddable for String { - type Kind = Kind; - - fn embeddable(&self) -> Vec { - vec![self.clone()] - } - } - #[derive(Embedding)] struct MyStruct { #[embed] From dff0aebb32986ae3df32ce748987c3272e6c9aa6 Mon Sep 17 00:00:00 2001 From: Garance Date: Mon, 7 Oct 2024 14:51:07 -0400 Subject: [PATCH 05/47] refactor: split up macro into multiple files --- Cargo.lock | 9 +- rig-macros/Cargo.toml | 1 + rig-macros/rig-macros-derive/src/embedding.rs | 216 ++++++++++++++++++ rig-macros/rig-macros-derive/src/lib.rs | 147 +----------- rig-macros/src/lib.rs | 32 ++- rig-macros/test.rs | 39 ++-- 6 files changed, 275 insertions(+), 169 deletions(-) create mode 100644 rig-macros/rig-macros-derive/src/embedding.rs diff --git a/Cargo.lock b/Cargo.lock index e62e3101..48f46588 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1237,6 +1237,7 @@ name = "rig-macros" version = "0.1.0" dependencies = [ "rig-macros-derive", + "serde", "serde_json", ] 
@@ -1457,9 +1458,9 @@ checksum = "388a1df253eca08550bef6c72392cfe7c30914bf41df5269b68cbd6ff8f570a3" [[package]] name = "serde" -version = "1.0.203" +version = "1.0.210" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7253ab4de971e72fb7be983802300c30b5a7f0c2e56fab8abfc6a214307c0094" +checksum = "c8e3592472072e6e22e0a54d5904d9febf8508f65fb8552499a1abc7d1078c3a" dependencies = [ "serde_derive", ] @@ -1475,9 +1476,9 @@ dependencies = [ [[package]] name = "serde_derive" -version = "1.0.203" +version = "1.0.210" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "500cbc0ebeb6f46627f50f3f5811ccf6bf00643be300b4c3eabc0ef55dc5b5ba" +checksum = "243902eda00fad750862fc144cea25caca5e20d615af0a81bee94ca738f1df1f" dependencies = [ "proc-macro2", "quote", diff --git a/rig-macros/Cargo.toml b/rig-macros/Cargo.toml index 729582f1..d1cd2cd9 100644 --- a/rig-macros/Cargo.toml +++ b/rig-macros/Cargo.toml @@ -6,3 +6,4 @@ edition = "2021" [dependencies] serde_json = "1.0.128" rig-macros-derive = { path = "./rig-macros-derive" } +serde = {version = "1.0.210", features = ["derive"]} diff --git a/rig-macros/rig-macros-derive/src/embedding.rs b/rig-macros/rig-macros-derive/src/embedding.rs new file mode 100644 index 00000000..2bfdb0ab --- /dev/null +++ b/rig-macros/rig-macros-derive/src/embedding.rs @@ -0,0 +1,216 @@ +use proc_macro::TokenStream; +use quote::quote; +use syn::{ + meta::ParseNestedMeta, parse_quote, punctuated::Punctuated, spanned::Spanned, Attribute, + DataStruct, ExprPath, Meta, Token, +}; + +const EMBED: &str = "embed"; +const EMBED_WITH: &str = "embed_with"; + +pub fn expand_derive_embedding(input: &mut syn::DeriveInput) -> TokenStream { + let name = &input.ident; + + let func_calls = + match &input.data { + syn::Data::Struct(data_struct) => { + // Handles fields tagged with #[embed] + let mut function_calls = data_struct + .basic_embed_fields() + .map(|field| { + add_struct_bounds(&mut input.generics, &field.ty); + + 
let field_name = field.ident; + quote! { + self.#field_name.embeddable() + } + }) + .collect::>(); + + // Handles fields tagged with #[embed(embed_with = "...")] + function_calls.extend(data_struct.custom_embed_fields().unwrap().map( + |(field, _)| { + let field_name = field.ident; + + quote! { + embeddable(&self.#field_name) + } + }, + )); + + function_calls + } + _ => vec![], + }; + + // Import the paths to the custom functions. + let custom_func_paths = match &input.data { + syn::Data::Struct(data_struct) => data_struct + .custom_embed_fields() + .unwrap() + .map(|(_, custom_func_path)| { + quote! { + use #custom_func_path::embeddable; + } + }) + .collect::>(), + _ => vec![], + }; + + let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl(); + + let gen = quote! { + #(#custom_func_paths);* + + impl #impl_generics Embeddable for #name #ty_generics #where_clause { + type Kind = String; + + fn embeddable(&self) -> Vec { + vec![ + #(#func_calls),* + ].into_iter().flatten().collect() + } + } + }; + eprintln!("Generated code:\n{}", gen); + + gen.into() +} + +// Adds bounds to where clause that force all fields tagged with #[embed] to implement the Embeddable trait. +fn add_struct_bounds(generics: &mut syn::Generics, field_type: &syn::Type) { + let where_clause = generics.make_where_clause(); + + where_clause.predicates.push(parse_quote! { + #field_type: Embeddable + }); +} + +trait AttributeParser { + /// Finds and returns fields with simple #[embed] attribute tags only. + fn basic_embed_fields(&self) -> impl Iterator; + /// Finds and returns fields with #[embed(embed_with = "...")] attribute tags only. + /// Also returns the attribute in question. 
+ fn custom_embed_fields( + &self, + ) -> Result, syn::Error>; +} + +impl AttributeParser for DataStruct { + fn basic_embed_fields(&self) -> impl Iterator { + self.fields.clone().into_iter().filter(|field| { + field + .attrs + .clone() + .into_iter() + .any(|attribute| match attribute { + Attribute { + meta: Meta::Path(path), + .. + } => path.is_ident(EMBED), + _ => false, + }) + }) + } + + fn custom_embed_fields( + &self, + ) -> Result, syn::Error> { + // Determine if field is tagged with #[embed(embed_with = "...")] attribute. + fn is_custom_embed(attribute: &syn::Attribute) -> Result { + let is_custom_embed = match attribute.meta { + Meta::List(_) => attribute + .parse_args_with(Punctuated::::parse_terminated)? + .into_iter() + .any(|meta| meta.path().is_ident(EMBED_WITH)), + _ => false, + }; + + Ok(attribute.path().is_ident(EMBED) && is_custom_embed) + } + + // Get the "..." part of the #[embed(embed_with = "...")] attribute. + // Ex: If attribute is tagged with #[embed(embed_with = "my_embed")], returns "my_embed". + fn expand_tag(attribute: &syn::Attribute) -> Result { + let mut custom_func_path = None; + + attribute.parse_nested_meta(|meta| { + custom_func_path = Some(meta.function_path()?); + Ok(()) + })?; + + match custom_func_path { + Some(path) => Ok(path), + None => Err(syn::Error::new( + attribute.span(), + format!( + "expected {} attribute to have format: `#[embed(embed_with = \"...\")]`", + EMBED_WITH + ), + )), + } + } + + Ok(self + .fields + .clone() + .into_iter() + .map(|field| { + field + .attrs + .clone() + .into_iter() + .map(|attribute| { + if is_custom_embed(&attribute)? { + Ok::<_, syn::Error>(Some((field.clone(), expand_tag(&attribute)?))) + } else { + Ok(None) + } + }) + .collect::, _>>() + }) + .collect::, _>>()? 
+ .into_iter() + .flatten() + .flatten()) + } +} + +trait CustomFunction { + fn function_path(&self) -> Result; +} + +impl CustomFunction for ParseNestedMeta<'_> { + fn function_path(&self) -> Result { + // #[embed(embed_with = "...")] + let expr = self.value().unwrap().parse::().unwrap(); + let mut value = &expr; + while let syn::Expr::Group(e) = value { + value = &e.expr; + } + let string = if let syn::Expr::Lit(syn::ExprLit { + lit: syn::Lit::Str(lit_str), + .. + }) = value + { + let suffix = lit_str.suffix(); + if !suffix.is_empty() { + return Err(syn::Error::new( + lit_str.span(), + format!("unexpected suffix `{}` on string literal", suffix), + )); + } + lit_str.clone() + } else { + return Err(syn::Error::new( + value.span(), + format!( + "expected {} attribute to be a string: `{} = \"...\"`", + EMBED_WITH, EMBED_WITH + ), + )); + }; + + string.parse() + } +} diff --git a/rig-macros/rig-macros-derive/src/lib.rs b/rig-macros/rig-macros-derive/src/lib.rs index 2a6e4aaa..239d9b3a 100644 --- a/rig-macros/rig-macros-derive/src/lib.rs +++ b/rig-macros/rig-macros-derive/src/lib.rs @@ -1,154 +1,15 @@ extern crate proc_macro; -use indoc::indoc; use proc_macro::TokenStream; -use quote::{quote, ToTokens}; -use syn::{ - meta::ParseNestedMeta, parse_macro_input, parse_quote, spanned::Spanned, Attribute, DataStruct, DeriveInput, ExprPath, Meta, Path -}; +use syn::{parse_macro_input, DeriveInput}; + +mod embedding; // https://doc.rust-lang.org/book/ch19-06-macros.html#how-to-write-a-custom-derive-macro // https://doc.rust-lang.org/reference/procedural-macros.html -const EMBED: &str = "embed"; -const EMBED_WITH: &str = "embed_with"; - #[proc_macro_derive(Embedding, attributes(embed))] pub fn derive_embed_trait(item: TokenStream) -> TokenStream { let mut input = parse_macro_input!(item as DeriveInput); - impl_embeddable_macro(&mut input) -} - -fn impl_embeddable_macro(input: &mut syn::DeriveInput) -> TokenStream { - let name = &input.ident; - - let embeddings = match 
&input.data { - syn::Data::Struct(data_struct) => { - basic_embed_fields(data_struct).for_each(|(_, field_type)| { - add_struct_bounds(&mut input.generics, &field_type) - }); - - // let basic_embed_fields = basic_embed_fields(data_struct) - // .map(|field_name| { - // quote! { - // self.#field_name.embeddable() - // } - // }) - // .collect::>(); - - let func_names = custom_trait_implementation(data_struct).unwrap(); - - false - } - _ => false, - }; - - let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl(); - - let gen = quote! { - impl #impl_generics Embeddable for #name #ty_generics #where_clause { - type Kind = String; - - fn embeddable(&self) -> Vec { - // vec![ - // #(#embeddings),* - // ].into_iter().flatten().collect() - println!("{}", #embeddings); - vec![] - } - } - }; - eprintln!("Generated code:\n{}", gen); - - gen.into() -} - -fn custom_trait_implementation(data_struct: &DataStruct) -> Result, syn::Error> { - Ok(data_struct - .fields - .clone() - .into_iter() - .map(|field| { - let mut path = None; - field.attrs.clone().into_iter().map(|attr| { - if attr.path().is_ident(EMBED) { - attr.parse_nested_meta(|meta| { - if meta.path.is_ident(EMBED_WITH) { - path = Some(parse_embed_with(&meta)?); - - // let tokens = meta.path.into_token_stream(); - }; - Ok(()) - }) - } else { - Ok(()) - } - }).collect::,_>>()?; - Ok::<_, syn::Error>(path) - }).collect::,_>>()? - .into_iter() - .filter_map(|i| i) - .collect()) -} - -fn parse_embed_with(meta: &ParseNestedMeta) -> Result { - // #[embed(embed_with = "...")] - let expr = meta.value().unwrap().parse::().unwrap(); - let mut value = &expr; - while let syn::Expr::Group(e) = value { - value = &e.expr; - } - let string = if let syn::Expr::Lit(syn::ExprLit { - lit: syn::Lit::Str(lit_str), - .. 
- }) = value - { - let suffix = lit_str.suffix(); - if !suffix.is_empty() { - return Err(syn::Error::new( - lit_str.span(), - format!("unexpected suffix `{}` on string literal", suffix) - )) - } - lit_str.clone() - } else { - return Err(syn::Error::new( - value.span(), - format!("expected {} attribute to be a string: `{} = \"...\"`", EMBED_WITH, EMBED_WITH) - )) - }; - - string.parse() -} - -fn add_struct_bounds(generics: &mut syn::Generics, field_type: &syn::Type) { - let where_clause = generics.make_where_clause(); - - where_clause.predicates.push(parse_quote! { - #field_type: Embeddable - }); -} - - -fn basic_embed_fields(data_struct: &DataStruct) -> impl Iterator { - data_struct.fields.clone().into_iter().filter_map(|field| { - let found_embed = field - .attrs - .clone() - .into_iter() - .any(|attribute| match attribute { - Attribute { - meta: Meta::Path(path), - .. - } => path.is_ident("embed"), - _ => false, - }); - match found_embed { - true => Some(( - field.ident.expect(""), - field.ty - )), - false => None, - } - }) + embedding::expand_derive_embedding(&mut input) } diff --git a/rig-macros/src/lib.rs b/rig-macros/src/lib.rs index b57bf150..d7151f03 100644 --- a/rig-macros/src/lib.rs +++ b/rig-macros/src/lib.rs @@ -8,8 +8,24 @@ trait Embeddable { fn embeddable(&self) -> Vec; } +#[derive(serde::Serialize)] +pub struct JobStruct { + job_title: String, + company: String, +} + +mod something { + use super::JobStruct; + + pub fn embeddable(input: &JobStruct) -> Vec { + vec![serde_json::to_string(input).unwrap()] + } +} + #[cfg(test)] mod tests { + use crate::JobStruct; + use super::{Embeddable, Kind}; use rig_macros_derive::Embedding; @@ -22,23 +38,27 @@ mod tests { } #[derive(Embedding)] - struct MyStruct { + struct SomeStruct { #[embed] id: usize, - #[embed(embed_with = "something")] name: String, + #[embed(embed_with = "super::something")] + job: JobStruct, } #[test] fn test_macro() { - let my_struct = MyStruct { + let job_struct = JobStruct { + job_title: 
"developer".to_string(), + company: "playgrounds".to_string(), + }; + let some_struct = SomeStruct { id: 1, name: "John".to_string(), + job: job_struct, }; - my_struct.embeddable(); - - // println!("{:?}", my_struct.embeddable()); + println!("{:?}", some_struct.embeddable()); assert!(false) } diff --git a/rig-macros/test.rs b/rig-macros/test.rs index f230af56..225a9b13 100644 --- a/rig-macros/test.rs +++ b/rig-macros/test.rs @@ -5,6 +5,7 @@ pub struct EmbeddingsBuilder { } trait Embeddable { + type Kind; // Return list of strings that need to be embedded. // Instead of Vec, should be Vec fn embeddable(&self) -> Vec; @@ -12,6 +13,28 @@ trait Embeddable { type EmbeddingVector = Vec; +impl> EmbeddingsBuilder { + pub fn build(&self) -> Result, EmbeddingError> { + self.documents.iter().map(|(doc, values_to_embed)| { + values_to_embed.iter().map(|value| { + let value_str = serde_json::to_string(value)?; + generate_embedding(value_str) + }) + }) + } +} + +impl> EmbeddingsBuilder { + pub fn build(&self) -> Result)>, EmbeddingError> { + self.documents.iter().map(|(doc, values_to_embed)| { + values_to_embed.iter().map(|value| { + let value_str = serde_json::to_string(value)?; + generate_embedding(value_str) + }) + }) + } +} + impl EmbeddingsBuilder { /// Create a new embedding builder with the given embedding model pub fn new(model: M) -> Self { @@ -33,22 +56,6 @@ impl EmbeddingsBuilder { )); self } - - pub fn build(&self) -> Result)>, EmbeddingError> { - self.documents.iter().map(|(doc, values_to_embed)| { - values_to_embed.iter().map(|value| { - let value_str = serde_json::to_string(value)?; - generate_embedding(value_str) - }) - }) - } - - pub fn build_simple(&self) -> Result, EmbeddingError> { - self.documents.iter().map(|(doc, value_to_embed)| { - let value_str = serde_json::to_string(value_to_embed)?; - generate_embedding(value_str) - }) - } } From 0d1001119a3a850cdfb189e4cecaeda20bc9d1da Mon Sep 17 00:00:00 2001 From: Garance Date: Mon, 7 Oct 2024 16:56:29 -0400 
Subject: [PATCH 06/47] refactor: move macro derive crate inside rig-core --- Cargo.lock | 12 +- Cargo.toml | 2 +- rig-core/Cargo.toml | 1 + .../rig-core-derive}/Cargo.toml | 2 +- .../rig-core-derive}/src/embedding.rs | 0 .../rig-core-derive}/src/lib.rs | 2 +- rig-core/src/embeddings.rs | 223 +++++++----------- rig-macros/Cargo.toml | 9 - rig-macros/src/lib.rs | 65 ----- rig-macros/test.rs | 157 ------------ 10 files changed, 87 insertions(+), 386 deletions(-) rename {rig-macros/rig-macros-derive => rig-core/rig-core-derive}/Cargo.toml (86%) rename {rig-macros/rig-macros-derive => rig-core/rig-core-derive}/src/embedding.rs (100%) rename {rig-macros/rig-macros-derive => rig-core/rig-core-derive}/src/lib.rs (86%) delete mode 100644 rig-macros/Cargo.toml delete mode 100644 rig-macros/src/lib.rs delete mode 100644 rig-macros/test.rs diff --git a/Cargo.lock b/Cargo.lock index 48f46588..ced6668c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1223,6 +1223,7 @@ dependencies = [ "futures", "ordered-float", "reqwest", + "rig-core-derive", "schemars", "serde", "serde_json", @@ -1233,16 +1234,7 @@ dependencies = [ ] [[package]] -name = "rig-macros" -version = "0.1.0" -dependencies = [ - "rig-macros-derive", - "serde", - "serde_json", -] - -[[package]] -name = "rig-macros-derive" +name = "rig-core-derive" version = "0.1.0" dependencies = [ "indoc", diff --git a/Cargo.toml b/Cargo.toml index d3a0c372..a37ba0e4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [workspace] resolver = "2" members = [ - "rig-core", "rig-macros", "rig-macros/rig-macros-derive", + "rig-core", "rig-core/rig-core-derive", "rig-mongodb", ] diff --git a/rig-core/Cargo.toml b/rig-core/Cargo.toml index 28561465..be83f04f 100644 --- a/rig-core/Cargo.toml +++ b/rig-core/Cargo.toml @@ -22,6 +22,7 @@ futures = "0.3.29" ordered-float = "4.2.0" schemars = "0.8.16" thiserror = "1.0.61" +rig-core-derive = { path = "./rig-core-derive" } [dev-dependencies] anyhow = "1.0.75" diff --git 
a/rig-macros/rig-macros-derive/Cargo.toml b/rig-core/rig-core-derive/Cargo.toml similarity index 86% rename from rig-macros/rig-macros-derive/Cargo.toml rename to rig-core/rig-core-derive/Cargo.toml index f2f22f6d..8f8a45d7 100644 --- a/rig-macros/rig-macros-derive/Cargo.toml +++ b/rig-core/rig-core-derive/Cargo.toml @@ -1,5 +1,5 @@ [package] -name = "rig-macros-derive" +name = "rig-core-derive" version = "0.1.0" edition = "2021" diff --git a/rig-macros/rig-macros-derive/src/embedding.rs b/rig-core/rig-core-derive/src/embedding.rs similarity index 100% rename from rig-macros/rig-macros-derive/src/embedding.rs rename to rig-core/rig-core-derive/src/embedding.rs diff --git a/rig-macros/rig-macros-derive/src/lib.rs b/rig-core/rig-core-derive/src/lib.rs similarity index 86% rename from rig-macros/rig-macros-derive/src/lib.rs rename to rig-core/rig-core-derive/src/lib.rs index 239d9b3a..eb7abaa9 100644 --- a/rig-macros/rig-macros-derive/src/lib.rs +++ b/rig-core/rig-core-derive/src/lib.rs @@ -8,7 +8,7 @@ mod embedding; // https://doc.rust-lang.org/reference/procedural-macros.html #[proc_macro_derive(Embedding, attributes(embed))] -pub fn derive_embed_trait(item: TokenStream) -> TokenStream { +pub fn derive_embedding_trait(item: TokenStream) -> TokenStream { let mut input = parse_macro_input!(item as DeriveInput); embedding::expand_derive_embedding(&mut input) diff --git a/rig-core/src/embeddings.rs b/rig-core/src/embeddings.rs index 2d40bbc5..c72255df 100644 --- a/rig-core/src/embeddings.rs +++ b/rig-core/src/embeddings.rs @@ -147,15 +147,21 @@ pub struct DocumentEmbeddings { pub embeddings: Vec, } -type Embeddings = Vec; +struct SingleEmbedding; +struct ManyEmbedding; + +pub trait Embeddable { + type Kind; + fn embeddable(&self) -> Vec; +} /// Builder for creating a collection of embeddings -pub struct EmbeddingsBuilder { +pub struct EmbeddingsBuilder { model: M, - documents: Vec<(String, serde_json::Value, Vec)>, + documents: Vec<(D, Vec)>, } -impl EmbeddingsBuilder { 
+impl EmbeddingsBuilder { /// Create a new embedding builder with the given embedding model pub fn new(model: M) -> Self { Self { @@ -164,169 +170,102 @@ impl EmbeddingsBuilder { } } - /// Add a simple document to the embedding collection. - /// The provided document string will be used for the embedding. - pub fn simple_document(mut self, id: &str, document: &str) -> Self { - self.documents.push(( - id.to_string(), - serde_json::Value::String(document.to_string()), - vec![document.to_string()], - )); - self - } - - /// Add multiple documents to the embedding collection. - /// Each element of the vector is a tuple of the form (id, document). - pub fn simple_documents(mut self, documents: Vec<(String, String)>) -> Self { - self.documents - .extend(documents.into_iter().map(|(id, document)| { - ( - id, - serde_json::Value::String(document.clone()), - vec![document], - ) - })); - self - } - - /// Add a tool to the embedding collection. - /// The `tool.context()` corresponds to the document being stored while - /// `tool.embedding_docs()` corresponds to the documents that will be used to generate the embeddings. - pub fn tool(mut self, tool: impl ToolEmbedding + 'static) -> Result { - self.documents.push(( - tool.name(), - serde_json::to_value(tool.context())?, - tool.embedding_docs(), - )); - Ok(self) - } - - /// Add the tools from the given toolset to the embedding collection. - pub fn tools(mut self, toolset: &ToolSet) -> Result { - for (name, tool) in toolset.tools.iter() { - if let ToolType::Embedding(tool) = tool { - self.documents.push(( - name.clone(), - tool.context().map_err(|e| { - EmbeddingError::DocumentError(format!( - "Failed to generate context for tool {}: {}", - name, e - )) - })?, - tool.embedding_docs(), - )); - } - } - Ok(self) - } - - /// Add a document to the embedding collection. - /// `embed_documents` are the documents that will be used to generate the embeddings - /// for `document`. 
- pub fn document( - mut self, - id: &str, - document: T, - embed_documents: Vec, - ) -> Self { - self.documents.push(( - id.to_string(), - serde_json::to_value(document).expect("Document should serialize"), - embed_documents, - )); - self - } + pub fn document(mut self, document: D) -> Self { + let embed_targets = document.embeddable(); - /// Add multiple documents to the embedding collection. - /// Each element of the vector is a tuple of the form (id, document, embed_documents). - pub fn documents(mut self, documents: Vec<(String, T, Vec)>) -> Self { - self.documents.extend( - documents - .into_iter() - .map(|(id, document, embed_documents)| { - ( - id, - serde_json::to_value(document).expect("Document should serialize"), - embed_documents, - ) - }), - ); - self - } - - /// Add a json document to the embedding collection. - pub fn json_document( - mut self, - id: &str, - document: serde_json::Value, - embed_documents: Vec, - ) -> Self { - self.documents - .push((id.to_string(), document, embed_documents)); - self - } - - /// Add multiple json documents to the embedding collection. 
- pub fn json_documents( - mut self, - documents: Vec<(String, serde_json::Value, Vec)>, - ) -> Self { - self.documents.extend(documents); + self.documents.push((document, embed_targets)); self } +} - /// Generate the embeddings for the given documents - pub async fn build(self) -> Result { - // Create a temporary store for the documents +impl + Send + Clone> + EmbeddingsBuilder +{ + pub async fn build(self) -> Result)>, EmbeddingError> { let documents_map = self .documents + .clone() .into_iter() - .map(|(id, document, docs)| (id, (document, docs))) + .enumerate() + .map(|(id, (document, _))| (id, document)) .collect::>(); - let embeddings = stream::iter(documents_map.iter()) + let embeddings = stream::iter(self.documents.into_iter().enumerate()) // Flatten the documents - .flat_map(|(id, (_, docs))| { - stream::iter(docs.iter().map(|doc| (id.clone(), doc.clone()))) + .flat_map(|(i, (_, embed_targets))| { + stream::iter( + embed_targets + .into_iter() + .map(move |target| (i, target.clone())), + ) }) // Chunk them into N (the emebdding API limit per request) .chunks(M::MAX_DOCUMENTS) // Generate the embeddings .map(|docs| async { - let (ids, docs): (Vec<_>, Vec<_>) = docs.into_iter().unzip(); + let (documents, embed_targets): (Vec<_>, Vec<_>) = docs.into_iter().unzip(); Ok::<_, EmbeddingError>( - ids.into_iter() - .zip(self.model.embed_documents(docs).await?.into_iter()) + documents + .into_iter() + .zip(self.model.embed_documents(embed_targets).await?.into_iter()) .collect::>(), ) }) .boxed() // Parallelize the embeddings generation over 10 concurrent requests .buffer_unordered(max(1, 1024 / M::MAX_DOCUMENTS)) - .try_fold(vec![], |mut acc, mut embeddings| async move { - Ok({ - acc.append(&mut embeddings); - acc - }) + // .try_collect::>() + // .await; + .try_fold(HashMap::new(), |mut acc, mut embeddings| async move { + embeddings.into_iter().for_each(|(i, embedding)| { + acc.entry(i).or_insert(vec![]).push(embedding); + }); + + Ok(acc) }) - .await?; + .await? 
+ .iter() + .fold(vec![], |mut acc, (i, embeddings_vec)| { + acc.push(( + documents_map.get(i).cloned().unwrap(), + embeddings_vec.clone(), + )); + acc + }); - // Assemble the DocumentEmbeddings - let mut document_embeddings: HashMap = HashMap::new(); - embeddings.into_iter().for_each(|(id, embedding)| { - let (document, _) = documents_map.get(&id).expect("Document not found"); - let document_embedding = - document_embeddings - .entry(id.clone()) - .or_insert_with(|| DocumentEmbeddings { - id: id.clone(), - document: document.clone(), - embeddings: vec![], - }); + Ok(embeddings) + } +} - document_embedding.embeddings.push(embedding); - }); +impl + Send + Clone> + EmbeddingsBuilder +{ + pub async fn build(self) -> Result, EmbeddingError> { + let embeddings = + stream::iter(self.documents.into_iter().map(|(document, embed_target)| { + (document, embed_target.first().cloned().unwrap()) + })) + // Chunk them into N (the emebdding API limit per request) + .chunks(M::MAX_DOCUMENTS) + // Generate the embeddings + .map(|docs| async { + let (documents, embed_targets): (Vec<_>, Vec<_>) = docs.into_iter().unzip(); + Ok::<_, EmbeddingError>( + documents + .into_iter() + .zip(self.model.embed_documents(embed_targets).await?.into_iter()) + .collect::>(), + ) + }) + .boxed() + // Parallelize the embeddings generation over 10 concurrent requests + .buffer_unordered(max(1, 1024 / M::MAX_DOCUMENTS)) + .try_fold(vec![], |mut acc, embeddings| async move { + acc.extend(embeddings); + Ok(acc) + }) + .await?; - Ok(document_embeddings.into_values().collect()) + Ok(embeddings) } } diff --git a/rig-macros/Cargo.toml b/rig-macros/Cargo.toml deleted file mode 100644 index d1cd2cd9..00000000 --- a/rig-macros/Cargo.toml +++ /dev/null @@ -1,9 +0,0 @@ -[package] -name = "rig-macros" -version = "0.1.0" -edition = "2021" - -[dependencies] -serde_json = "1.0.128" -rig-macros-derive = { path = "./rig-macros-derive" } -serde = {version = "1.0.210", features = ["derive"]} diff --git 
a/rig-macros/src/lib.rs b/rig-macros/src/lib.rs deleted file mode 100644 index d7151f03..00000000 --- a/rig-macros/src/lib.rs +++ /dev/null @@ -1,65 +0,0 @@ -enum Kind { - Single, - Many, -} - -trait Embeddable { - type Kind; - fn embeddable(&self) -> Vec; -} - -#[derive(serde::Serialize)] -pub struct JobStruct { - job_title: String, - company: String, -} - -mod something { - use super::JobStruct; - - pub fn embeddable(input: &JobStruct) -> Vec { - vec![serde_json::to_string(input).unwrap()] - } -} - -#[cfg(test)] -mod tests { - use crate::JobStruct; - - use super::{Embeddable, Kind}; - use rig_macros_derive::Embedding; - - impl Embeddable for usize { - type Kind = Kind; - - fn embeddable(&self) -> Vec { - vec![self.to_string()] - } - } - - #[derive(Embedding)] - struct SomeStruct { - #[embed] - id: usize, - name: String, - #[embed(embed_with = "super::something")] - job: JobStruct, - } - - #[test] - fn test_macro() { - let job_struct = JobStruct { - job_title: "developer".to_string(), - company: "playgrounds".to_string(), - }; - let some_struct = SomeStruct { - id: 1, - name: "John".to_string(), - job: job_struct, - }; - - println!("{:?}", some_struct.embeddable()); - - assert!(false) - } -} diff --git a/rig-macros/test.rs b/rig-macros/test.rs deleted file mode 100644 index 225a9b13..00000000 --- a/rig-macros/test.rs +++ /dev/null @@ -1,157 +0,0 @@ -/// Builder for creating a collection of embeddings -pub struct EmbeddingsBuilder { - model: M, - documents: Vec<(T, Vec)>, -} - -trait Embeddable { - type Kind; - // Return list of strings that need to be embedded. 
- // Instead of Vec, should be Vec - fn embeddable(&self) -> Vec; -} - -type EmbeddingVector = Vec; - -impl> EmbeddingsBuilder { - pub fn build(&self) -> Result, EmbeddingError> { - self.documents.iter().map(|(doc, values_to_embed)| { - values_to_embed.iter().map(|value| { - let value_str = serde_json::to_string(value)?; - generate_embedding(value_str) - }) - }) - } -} - -impl> EmbeddingsBuilder { - pub fn build(&self) -> Result)>, EmbeddingError> { - self.documents.iter().map(|(doc, values_to_embed)| { - values_to_embed.iter().map(|value| { - let value_str = serde_json::to_string(value)?; - generate_embedding(value_str) - }) - }) - } -} - -impl EmbeddingsBuilder { - /// Create a new embedding builder with the given embedding model - pub fn new(model: M) -> Self { - Self { - model, - documents: vec![], - } - } - - pub fn add( - mut self, - document: T, - ) -> Self { - let embed_documents: Vec = document.embeddable(); - - self.documents.push(( - document, - embed_documents, - )); - self - } -} - - -// Example -#[derive(Embeddable)] -struct DictionaryEntry { - word: String, - #[embed] - definitions: String, -} - -#[derive(Embeddable)] -struct MetadataEmbedding { - pub id: String, - #[embed(with = serde_json::to_value)] - pub content: CategoryMetadata, - pub created: Option>, - pub modified: Option>, - pub dataset_ids: Vec, -} - -#[derive(serde::Serialize)] -struct CategoryMetadata { - pub name: String, - pub description: String, - pub tags: Vec, - pub links: Vec, -} - -// Inside macro: -impl Embeddable for DictionaryEntry { - fn embeddable(&self) -> Vec { - // Find the field tagged with #[embed] and return its value - // If there are no embedding tags, return the entire struct - } -} - -fn main() { - let embeddings: Vec<(DictionaryEntry, Vec)> = EmbeddingsBuilder::new(model.clone()) - .add(DictionaryEntry::new("blah", vec!["definition of blah"])) - .add(DictionaryEntry::new("foo", vec!["definition of foo"])) - .build()?; - - // In relational vector store like 
LanceDB, need to flatten result (create row for each item in definitions vector): - // Column: word (string) - // Column: definition (vector) - - // In document vector store like MongoDB, might need to merge the vector results back with their corresponding definition string: - // Field: word (string) - // Field: definitions - // // Field: definition (string) - // // Field: vector - - Ok(()) -} - - - -// Iterations: -// 1 - Multiple fields to embed? -#[derive(Embedding)] -struct DictionaryEntry { - word: String, - #[embed] - definitions: Vec, - #[embed] - synonyms: Vec -} - -// 2 - Embed recursion? Ex: -#[derive(Embedding)] -struct DictionaryEntry { - word: String, - #[embed] - definitions: Vec, -} -struct Definition { - definition: String, - #[embed] - links: Vec -} - -// { -// word: "blah", -// definitions: [ -// { -// definition: "definition of blah", -// links: ["link1", "link2"] -// }, -// { -// definition: "another definition for blah", -// links: ["link3"] -// } -// ] -// } - -// blah | definition of blah | link1 | embedding for link1 -// blah | definition of blah | link2 | embedding for link2 -// blah | another definition for blah | link3 | embedding for link3 \ No newline at end of file From 79754aa954b55e5d6e5fc466f5f880c9d9535f13 Mon Sep 17 00:00:00 2001 From: Garance Date: Tue, 8 Oct 2024 14:28:03 -0400 Subject: [PATCH 07/47] feat: replace embedding logic with new embeddable trait and macro --- Cargo.lock | 6 +- rig-core/Cargo.toml | 4 +- rig-core/rig-core-derive/Cargo.toml | 2 +- .../src/{embedding.rs => embeddable.rs} | 40 +++++- rig-core/rig-core-derive/src/lib.rs | 6 +- rig-core/src/embeddings.rs | 135 ++++++++++++------ rig-core/src/providers/cohere.rs | 8 +- rig-core/src/providers/openai.rs | 7 +- rig-lancedb/Cargo.toml | 1 + rig-lancedb/examples/fixtures/lib.rs | 67 ++++++--- .../examples/vector_search_local_ann.rs | 26 ++-- .../examples/vector_search_local_enn.rs | 7 +- rig-lancedb/examples/vector_search_s3_ann.rs | 26 ++-- rig-mongodb/Cargo.toml 
| 2 + rig-mongodb/examples/vector_search_mongodb.rs | 71 ++++++--- rig-mongodb/src/lib.rs | 102 +++---------- 16 files changed, 287 insertions(+), 223 deletions(-) rename rig-core/rig-core-derive/src/{embedding.rs => embeddable.rs} (83%) diff --git a/Cargo.lock b/Cargo.lock index 19c004d7..311b2862 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3997,7 +3997,7 @@ dependencies = [ "futures", "ordered-float", "reqwest 0.11.27", - "rig-core-derive", + "rig-derive", "schemars", "serde", "serde_json", @@ -4008,7 +4008,7 @@ dependencies = [ ] [[package]] -name = "rig-core-derive" +name = "rig-derive" version = "0.1.0" dependencies = [ "indoc", @@ -4025,6 +4025,7 @@ dependencies = [ "futures", "lancedb", "rig-core", + "rig-derive", "serde", "serde_json", "tokio", @@ -4038,6 +4039,7 @@ dependencies = [ "futures", "mongodb", "rig-core", + "rig-derive", "serde", "serde_json", "tokio", diff --git a/rig-core/Cargo.toml b/rig-core/Cargo.toml index 49efafd0..658faf4b 100644 --- a/rig-core/Cargo.toml +++ b/rig-core/Cargo.toml @@ -23,9 +23,9 @@ futures = "0.3.29" ordered-float = "4.2.0" schemars = "0.8.16" thiserror = "1.0.61" -rig-core-derive = { path = "./rig-core-derive" } +rig-derive = { path = "./rig-core-derive" } [dev-dependencies] anyhow = "1.0.75" tokio = { version = "1.34.0", features = ["full"] } -tracing-subscriber = "0.3.18" +tracing-subscriber = "0.3.18" \ No newline at end of file diff --git a/rig-core/rig-core-derive/Cargo.toml b/rig-core/rig-core-derive/Cargo.toml index 8f8a45d7..008f492b 100644 --- a/rig-core/rig-core-derive/Cargo.toml +++ b/rig-core/rig-core-derive/Cargo.toml @@ -1,5 +1,5 @@ [package] -name = "rig-core-derive" +name = "rig-derive" version = "0.1.0" edition = "2021" diff --git a/rig-core/rig-core-derive/src/embedding.rs b/rig-core/rig-core-derive/src/embeddable.rs similarity index 83% rename from rig-core/rig-core-derive/src/embedding.rs rename to rig-core/rig-core-derive/src/embeddable.rs index 2bfdb0ab..e2c34afe 100644 --- 
a/rig-core/rig-core-derive/src/embedding.rs +++ b/rig-core/rig-core-derive/src/embeddable.rs @@ -1,8 +1,8 @@ use proc_macro::TokenStream; use quote::quote; use syn::{ - meta::ParseNestedMeta, parse_quote, punctuated::Punctuated, spanned::Spanned, Attribute, - DataStruct, ExprPath, Meta, Token, + meta::ParseNestedMeta, parse_quote, parse_str, punctuated::Punctuated, spanned::Spanned, + Attribute, DataStruct, ExprPath, Meta, Token, }; const EMBED: &str = "embed"; @@ -11,7 +11,7 @@ const EMBED_WITH: &str = "embed_with"; pub fn expand_derive_embedding(input: &mut syn::DeriveInput) -> TokenStream { let name = &input.ident; - let func_calls = + let (func_calls, embed_kind) = match &input.data { syn::Data::Struct(data_struct) => { // Handles fields tagged with #[embed] @@ -21,6 +21,7 @@ pub fn expand_derive_embedding(input: &mut syn::DeriveInput) -> TokenStream { add_struct_bounds(&mut input.generics, &field.ty); let field_name = field.ident; + quote! { self.#field_name.embeddable() } @@ -38,9 +39,9 @@ pub fn expand_derive_embedding(input: &mut syn::DeriveInput) -> TokenStream { }, )); - function_calls + (function_calls, data_struct.embed_kind().unwrap()) } - _ => vec![], + _ => panic!("Embeddable can only be derived for structs"), }; // Import the paths to the custom functions. @@ -60,10 +61,13 @@ pub fn expand_derive_embedding(input: &mut syn::DeriveInput) -> TokenStream { let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl(); let gen = quote! 
{ + use rig::embeddings::Embeddable; + use rig::embeddings::#embed_kind; + #(#custom_func_paths);* impl #impl_generics Embeddable for #name #ty_generics #where_clause { - type Kind = String; + type Kind = #embed_kind; fn embeddable(&self) -> Vec { vec![ @@ -86,6 +90,13 @@ fn add_struct_bounds(generics: &mut syn::Generics, field_type: &syn::Type) { }); } +fn embed_kind(field: &syn::Field) -> Result { + match &field.ty { + syn::Type::Array(_) => parse_str("ManyEmbedding"), + _ => parse_str("SingleEmbedding"), + } +} + trait AttributeParser { /// Finds and returns fields with simple #[embed] attribute tags only. fn basic_embed_fields(&self) -> impl Iterator; @@ -94,6 +105,23 @@ trait AttributeParser { fn custom_embed_fields( &self, ) -> Result, syn::Error>; + + /// If the total number of fields tagged with #[embed] or #[embed(embed_with = "...")] is 1, + /// returns the kind of embedding that field should be. + /// If the total number of fields tagged with #[embed] or #[embed(embed_with = "...")] is greater than 1, + /// return ManyEmbedding. 
+ fn embed_kind(&self) -> Result { + let fields = self + .basic_embed_fields() + .chain(self.custom_embed_fields().unwrap().map(|(f, _)| f)) + .collect::>(); + + if fields.len() == 1 { + fields.iter().map(embed_kind).next().unwrap() + } else { + parse_str("ManyEmbedding") + } + } } impl AttributeParser for DataStruct { diff --git a/rig-core/rig-core-derive/src/lib.rs b/rig-core/rig-core-derive/src/lib.rs index eb7abaa9..19f6a845 100644 --- a/rig-core/rig-core-derive/src/lib.rs +++ b/rig-core/rig-core-derive/src/lib.rs @@ -2,14 +2,14 @@ extern crate proc_macro; use proc_macro::TokenStream; use syn::{parse_macro_input, DeriveInput}; -mod embedding; +mod embeddable; // https://doc.rust-lang.org/book/ch19-06-macros.html#how-to-write-a-custom-derive-macro // https://doc.rust-lang.org/reference/procedural-macros.html -#[proc_macro_derive(Embedding, attributes(embed))] +#[proc_macro_derive(Embed, attributes(embed))] pub fn derive_embedding_trait(item: TokenStream) -> TokenStream { let mut input = parse_macro_input!(item as DeriveInput); - embedding::expand_derive_embedding(&mut input) + embeddable::expand_derive_embedding(&mut input) } diff --git a/rig-core/src/embeddings.rs b/rig-core/src/embeddings.rs index c1d281db..b044b406 100644 --- a/rig-core/src/embeddings.rs +++ b/rig-core/src/embeddings.rs @@ -38,13 +38,11 @@ //! // ... //! ``` -use std::{cmp::max, collections::HashMap}; +use std::{cmp::max, collections::HashMap, marker::PhantomData}; use futures::{stream, StreamExt, TryStreamExt}; use serde::{Deserialize, Serialize}; -use crate::tool::{ToolEmbedding, ToolSet, ToolType}; - #[derive(Debug, thiserror::Error)] pub enum EmbeddingError { /// Http error (e.g.: connection error, timeout, etc.) @@ -102,7 +100,7 @@ pub trait EmbeddingModel: Clone + Sync + Send { } /// Struct that holds a single document and its embedding. 
-#[derive(Clone, Default, Deserialize, Serialize)] +#[derive(Clone, Default, Deserialize, Serialize, Debug)] pub struct Embedding { /// The document that was embedded pub document: String, @@ -149,25 +147,29 @@ pub struct DocumentEmbeddings { pub document: serde_json::Value, pub embeddings: Vec, } - -struct SingleEmbedding; -struct ManyEmbedding; +pub trait EmbeddingKind {} +pub struct SingleEmbedding; +impl EmbeddingKind for SingleEmbedding {} +pub struct ManyEmbedding; +impl EmbeddingKind for ManyEmbedding {} pub trait Embeddable { - type Kind; + type Kind: EmbeddingKind; fn embeddable(&self) -> Vec; } /// Builder for creating a collection of embeddings -pub struct EmbeddingsBuilder { +pub struct EmbeddingsBuilder { + kind: PhantomData, model: M, documents: Vec<(D, Vec)>, } -impl EmbeddingsBuilder { +impl, K: EmbeddingKind> EmbeddingsBuilder { /// Create a new embedding builder with the given embedding model pub fn new(model: M) -> Self { Self { + kind: PhantomData, model, documents: vec![], } @@ -179,12 +181,22 @@ impl EmbeddingsBuilder { self.documents.push((document, embed_targets)); self } + + pub fn documents(mut self, documents: Vec) -> EmbeddingsBuilder { + documents.into_iter().for_each(|doc| { + let embed_targets = doc.embeddable(); + + self.documents.push((doc, embed_targets)); + }); + + self + } } -impl + Send + Clone> - EmbeddingsBuilder +impl + EmbeddingsBuilder { - pub async fn build(self) -> Result)>, EmbeddingError> { + pub async fn build(&self) -> Result)>, EmbeddingError> { let documents_map = self .documents .clone() @@ -193,7 +205,7 @@ impl + Send + Clone> .map(|(id, (document, _))| (id, document)) .collect::>(); - let embeddings = stream::iter(self.documents.into_iter().enumerate()) + let embeddings = stream::iter(self.documents.clone().into_iter().enumerate()) // Flatten the documents .flat_map(|(i, (_, embed_targets))| { stream::iter( @@ -219,13 +231,16 @@ impl + Send + Clone> .buffer_unordered(max(1, 1024 / M::MAX_DOCUMENTS)) // 
.try_collect::>() // .await; - .try_fold(HashMap::new(), |mut acc, mut embeddings| async move { - embeddings.into_iter().for_each(|(i, embedding)| { - acc.entry(i).or_insert(vec![]).push(embedding); - }); + .try_fold( + HashMap::new(), + |mut acc: HashMap<_, Vec<_>>, embeddings| async move { + embeddings.into_iter().for_each(|(i, embedding)| { + acc.entry(i).or_default().push(embedding); + }); - Ok(acc) - }) + Ok(acc) + }, + ) .await? .iter() .fold(vec![], |mut acc, (i, embeddings_vec)| { @@ -240,35 +255,61 @@ impl + Send + Clone> } } -impl + Send + Clone> - EmbeddingsBuilder +impl + EmbeddingsBuilder { - pub async fn build(self) -> Result, EmbeddingError> { - let embeddings = - stream::iter(self.documents.into_iter().map(|(document, embed_target)| { - (document, embed_target.first().cloned().unwrap()) - })) - // Chunk them into N (the emebdding API limit per request) - .chunks(M::MAX_DOCUMENTS) - // Generate the embeddings - .map(|docs| async { - let (documents, embed_targets): (Vec<_>, Vec<_>) = docs.into_iter().unzip(); - Ok::<_, EmbeddingError>( - documents - .into_iter() - .zip(self.model.embed_documents(embed_targets).await?.into_iter()) - .collect::>(), - ) - }) - .boxed() - // Parallelize the embeddings generation over 10 concurrent requests - .buffer_unordered(max(1, 1024 / M::MAX_DOCUMENTS)) - .try_fold(vec![], |mut acc, embeddings| async move { - acc.extend(embeddings); - Ok(acc) - }) - .await?; + pub async fn build(&self) -> Result, EmbeddingError> { + let embeddings = stream::iter( + self.documents + .clone() + .into_iter() + .map(|(document, embed_target)| (document, embed_target.first().cloned().unwrap())), + ) + // Chunk them into N (the emebdding API limit per request) + .chunks(M::MAX_DOCUMENTS) + // Generate the embeddings + .map(|docs| async { + let (documents, embed_targets): (Vec<_>, Vec<_>) = docs.into_iter().unzip(); + Ok::<_, EmbeddingError>( + documents + .into_iter() + .zip(self.model.embed_documents(embed_targets).await?.into_iter()) + 
.collect::>(), + ) + }) + .boxed() + // Parallelize the embeddings generation over 10 concurrent requests + .buffer_unordered(max(1, 1024 / M::MAX_DOCUMENTS)) + .try_fold(vec![], |mut acc, embeddings| async move { + acc.extend(embeddings); + Ok(acc) + }) + .await?; Ok(embeddings) } } + +impl Embeddable for String { + type Kind = SingleEmbedding; + + fn embeddable(&self) -> Vec { + vec![self.clone()] + } +} + +impl Embeddable for i32 { + type Kind = SingleEmbedding; + + fn embeddable(&self) -> Vec { + vec![self.to_string()] + } +} + +impl Embeddable for Vec { + type Kind = ManyEmbedding; + + fn embeddable(&self) -> Vec { + self.iter().flat_map(|i| i.embeddable()).collect() + } +} diff --git a/rig-core/src/providers/cohere.rs b/rig-core/src/providers/cohere.rs index ae874b21..c93709c7 100644 --- a/rig-core/src/providers/cohere.rs +++ b/rig-core/src/providers/cohere.rs @@ -13,7 +13,7 @@ use std::collections::HashMap; use crate::{ agent::AgentBuilder, completion::{self, CompletionError}, - embeddings::{self, EmbeddingError, EmbeddingsBuilder}, + embeddings::{self, Embeddable, EmbeddingError, EmbeddingsBuilder}, extractor::ExtractorBuilder, json_utils, }; @@ -85,7 +85,11 @@ impl Client { EmbeddingModel::new(self.clone(), model, input_type, ndims) } - pub fn embeddings(&self, model: &str, input_type: &str) -> EmbeddingsBuilder { + pub fn embeddings( + &self, + model: &str, + input_type: &str, + ) -> EmbeddingsBuilder { EmbeddingsBuilder::new(self.embedding_model(model, input_type)) } diff --git a/rig-core/src/providers/openai.rs b/rig-core/src/providers/openai.rs index 8262e6ce..6bd1711f 100644 --- a/rig-core/src/providers/openai.rs +++ b/rig-core/src/providers/openai.rs @@ -11,7 +11,7 @@ use crate::{ agent::AgentBuilder, completion::{self, CompletionError, CompletionRequest}, - embeddings::{self, EmbeddingError}, + embeddings::{self, Embeddable, EmbeddingError}, extractor::ExtractorBuilder, json_utils, }; @@ -121,7 +121,10 @@ impl Client { /// .await /// .expect("Failed 
to embed documents"); /// ``` - pub fn embeddings(&self, model: &str) -> embeddings::EmbeddingsBuilder { + pub fn embeddings( + &self, + model: &str, + ) -> embeddings::EmbeddingsBuilder { embeddings::EmbeddingsBuilder::new(self.embedding_model(model)) } diff --git a/rig-lancedb/Cargo.toml b/rig-lancedb/Cargo.toml index 031df2d3..6ee41d8d 100644 --- a/rig-lancedb/Cargo.toml +++ b/rig-lancedb/Cargo.toml @@ -6,6 +6,7 @@ edition = "2021" [dependencies] lancedb = "0.10.0" rig-core = { path = "../rig-core", version = "0.2.1" } +rig-derive = { path = "../rig-core/rig-core-derive" } arrow-array = "52.2.0" serde_json = "1.0.128" serde = "1.0.210" diff --git a/rig-lancedb/examples/fixtures/lib.rs b/rig-lancedb/examples/fixtures/lib.rs index 9a91432e..322c0a6a 100644 --- a/rig-lancedb/examples/fixtures/lib.rs +++ b/rig-lancedb/examples/fixtures/lib.rs @@ -2,13 +2,52 @@ use std::sync::Arc; use arrow_array::{types::Float64Type, ArrayRef, FixedSizeListArray, RecordBatch, StringArray}; use lancedb::arrow::arrow_schema::{DataType, Field, Fields, Schema}; -use rig::embeddings::DocumentEmbeddings; +use rig::embeddings::Embedding; +use rig_derive::Embed; +use serde::Deserialize; + +#[derive(Embed, Clone)] +pub struct FakeDefinition { + id: String, + #[embed] + definition: String, +} + +#[derive(Deserialize, Debug)] +pub struct VectorSearchResult { + pub id: String, + pub definition: String, +} + +pub fn fake_definitions() -> Vec { + vec![ + FakeDefinition { + id: "doc0".to_string(), + definition: "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets".to_string() + }, + FakeDefinition { + id: "doc1".to_string(), + definition: "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string() + }, + FakeDefinition { + id: "doc2".to_string(), + definition: "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.".to_string() + 
} + ] +} + +pub fn fake_definition(id: String) -> FakeDefinition { + FakeDefinition { + id, + definition: "Definition of *flumbuzzle (verb)*: to bewilder or confuse someone completely, often by using nonsensical or overly complex explanations or instructions.".to_string() + } +} // Schema of table in LanceDB. pub fn schema(dims: usize) -> Schema { Schema::new(Fields::from(vec![ Field::new("id", DataType::Utf8, false), - Field::new("content", DataType::Utf8, false), + Field::new("definition", DataType::Utf8, false), Field::new( "embedding", DataType::FixedSizeList( @@ -22,46 +61,34 @@ pub fn schema(dims: usize) -> Schema { // Convert DocumentEmbeddings objects to a RecordBatch. pub fn as_record_batch( - records: Vec, + records: Vec<(FakeDefinition, Embedding)>, dims: usize, ) -> Result { let id = StringArray::from_iter_values( records .iter() - .flat_map(|record| (0..record.embeddings.len()).map(|i| format!("{}-{i}", record.id))) + .map(|(FakeDefinition { id, .. }, _)| id) .collect::>(), ); - let content = StringArray::from_iter_values( + let definition = StringArray::from_iter_values( records .iter() - .flat_map(|record| { - record - .embeddings - .iter() - .map(|embedding| embedding.document.clone()) - }) + .map(|(_, Embedding { document, .. })| document) .collect::>(), ); let embedding = FixedSizeListArray::from_iter_primitive::( records .into_iter() - .flat_map(|record| { - record - .embeddings - .into_iter() - .map(|embedding| embedding.vec.into_iter().map(Some).collect::>()) - .map(Some) - .collect::>() - }) + .map(|(_, Embedding { vec, .. 
})| Some(vec.into_iter().map(Some).collect::>())) .collect::>(), dims as i32, ); RecordBatch::try_from_iter(vec![ ("id", Arc::new(id) as ArrayRef), - ("content", Arc::new(content) as ArrayRef), + ("definition", Arc::new(definition) as ArrayRef), ("embedding", Arc::new(embedding) as ArrayRef), ]) } diff --git a/rig-lancedb/examples/vector_search_local_ann.rs b/rig-lancedb/examples/vector_search_local_ann.rs index 3ecd6b23..3ac925bb 100644 --- a/rig-lancedb/examples/vector_search_local_ann.rs +++ b/rig-lancedb/examples/vector_search_local_ann.rs @@ -1,7 +1,7 @@ use std::{env, sync::Arc}; use arrow_array::RecordBatchIterator; -use fixture::{as_record_batch, schema}; +use fixture::{as_record_batch, fake_definition, fake_definitions, schema, VectorSearchResult}; use lancedb::index::vector::IvfPqIndexBuilder; use rig::vector_store::VectorStoreIndex; use rig::{ @@ -9,17 +9,10 @@ use rig::{ providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, }; use rig_lancedb::{LanceDbVectorStore, SearchParams}; -use serde::Deserialize; #[path = "./fixtures/lib.rs"] mod fixture; -#[derive(Deserialize, Debug)] -pub struct VectorSearchResult { - pub id: String, - pub content: String, -} - #[tokio::main] async fn main() -> Result<(), anyhow::Error> { // Initialize OpenAI client. Use this to generate embeddings (and generate test data for RAG demo). @@ -32,18 +25,15 @@ async fn main() -> Result<(), anyhow::Error> { // Initialize LanceDB locally. let db = lancedb::connect("data/lancedb-store").execute().await?; - // Set up test data for RAG demo - let definition = "Definition of *flumbuzzle (verb)*: to bewilder or confuse someone completely, often by using nonsensical or overly complex explanations or instructions.".to_string(); - - // Note: need at least 256 rows in order to create an index so copy the definition 256 times for testing purposes. - let definitions = vec![definition; 256]; - // Generate embeddings for the test data. 
let embeddings = EmbeddingsBuilder::new(model.clone()) - .simple_document("doc0", "Definition of *flumbrel (noun)*: a small, seemingly insignificant item that you constantly lose or misplace, such as a pen, hair tie, or remote control.") - .simple_document("doc1", "Definition of *zindle (verb)*: to pretend to be working on something important while actually doing something completely unrelated or unproductive") - .simple_document("doc2", "Definition of *glimber (adjective)*: describing a state of excitement mixed with nervousness, often experienced before an important event or decision.") - .simple_documents(definitions.clone().into_iter().enumerate().map(|(i, def)| (format!("doc{}", i+3), def)).collect()) + .documents(fake_definitions()) + // Note: need at least 256 rows in order to create an index so copy the definition 256 times for testing purposes. + .documents( + (0..256) + .map(|i| fake_definition(format!("doc{}", i))) + .collect(), + ) .build() .await?; diff --git a/rig-lancedb/examples/vector_search_local_enn.rs b/rig-lancedb/examples/vector_search_local_enn.rs index 5932dcd0..1bf69481 100644 --- a/rig-lancedb/examples/vector_search_local_enn.rs +++ b/rig-lancedb/examples/vector_search_local_enn.rs @@ -1,7 +1,7 @@ use std::{env, sync::Arc}; use arrow_array::RecordBatchIterator; -use fixture::{as_record_batch, schema}; +use fixture::{as_record_batch, fake_definitions, schema}; use rig::{ embeddings::{EmbeddingModel, EmbeddingsBuilder}, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, @@ -21,10 +21,9 @@ async fn main() -> Result<(), anyhow::Error> { // Select the embedding model and generate our embeddings let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); + // Generate embeddings for the test data. 
let embeddings = EmbeddingsBuilder::new(model.clone()) - .simple_document("doc0", "Definition of *flumbrel (noun)*: a small, seemingly insignificant item that you constantly lose or misplace, such as a pen, hair tie, or remote control.") - .simple_document("doc1", "Definition of *zindle (verb)*: to pretend to be working on something important while actually doing something completely unrelated or unproductive") - .simple_document("doc2", "Definition of *glimber (adjective)*: describing a state of excitement mixed with nervousness, often experienced before an important event or decision.") + .documents(fake_definitions()) .build() .await?; diff --git a/rig-lancedb/examples/vector_search_s3_ann.rs b/rig-lancedb/examples/vector_search_s3_ann.rs index 70f0c8c5..358a1755 100644 --- a/rig-lancedb/examples/vector_search_s3_ann.rs +++ b/rig-lancedb/examples/vector_search_s3_ann.rs @@ -1,7 +1,7 @@ use std::{env, sync::Arc}; use arrow_array::RecordBatchIterator; -use fixture::{as_record_batch, schema}; +use fixture::{as_record_batch, fake_definition, fake_definitions, schema, VectorSearchResult}; use lancedb::{index::vector::IvfPqIndexBuilder, DistanceType}; use rig::{ embeddings::{EmbeddingModel, EmbeddingsBuilder}, @@ -9,17 +9,10 @@ use rig::{ vector_store::VectorStoreIndex, }; use rig_lancedb::{LanceDbVectorStore, SearchParams}; -use serde::Deserialize; #[path = "./fixtures/lib.rs"] mod fixture; -#[derive(Deserialize, Debug)] -pub struct VectorSearchResult { - pub id: String, - pub content: String, -} - // Note: see docs to deploy LanceDB on other cloud providers such as google and azure. 
// https://lancedb.github.io/lancedb/guides/storage/ #[tokio::main] @@ -38,18 +31,15 @@ async fn main() -> Result<(), anyhow::Error> { .execute() .await?; - // Set up test data for RAG demo - let definition = "Definition of *flumbuzzle (verb)*: to bewilder or confuse someone completely, often by using nonsensical or overly complex explanations or instructions.".to_string(); - - // Note: need at least 256 rows in order to create an index so copy the definition 256 times for testing purposes. - let definitions = vec![definition; 256]; - // Generate embeddings for the test data. let embeddings = EmbeddingsBuilder::new(model.clone()) - .simple_document("doc0", "Definition of *flumbrel (noun)*: a small, seemingly insignificant item that you constantly lose or misplace, such as a pen, hair tie, or remote control.") - .simple_document("doc1", "Definition of *zindle (verb)*: to pretend to be working on something important while actually doing something completely unrelated or unproductive") - .simple_document("doc2", "Definition of *glimber (adjective)*: describing a state of excitement mixed with nervousness, often experienced before an important event or decision.") - .simple_documents(definitions.clone().into_iter().enumerate().map(|(i, def)| (format!("doc{}", i+3), def)).collect()) + .documents(fake_definitions()) + // Note: need at least 256 rows in order to create an index so copy the definition 256 times for testing purposes. 
+ .documents( + (0..256) + .map(|i| fake_definition(format!("doc{}", i))) + .collect(), + ) .build() .await?; diff --git a/rig-mongodb/Cargo.toml b/rig-mongodb/Cargo.toml index 6f313838..78b48892 100644 --- a/rig-mongodb/Cargo.toml +++ b/rig-mongodb/Cargo.toml @@ -13,6 +13,8 @@ repository = "https://github.com/0xPlaygrounds/rig" futures = "0.3.30" mongodb = "2.8.2" rig-core = { path = "../rig-core", version = "0.2.1" } +rig-derive = { path = "../rig-core/rig-core-derive" } + serde = { version = "1.0.203", features = ["derive"] } serde_json = "1.0.117" tracing = "0.1.40" diff --git a/rig-mongodb/examples/vector_search_mongodb.rs b/rig-mongodb/examples/vector_search_mongodb.rs index 3d062de3..4ddf1f99 100644 --- a/rig-mongodb/examples/vector_search_mongodb.rs +++ b/rig-mongodb/examples/vector_search_mongodb.rs @@ -1,13 +1,28 @@ -use mongodb::{options::ClientOptions, Client as MongoClient, Collection}; +use mongodb::{bson::doc, options::ClientOptions, Client as MongoClient, Collection}; +use rig::{embeddings::Embedding, providers::openai::TEXT_EMBEDDING_ADA_002}; +use rig_derive::Embed; +use serde::{Deserialize, Serialize}; use std::env; use rig::{ - embeddings::{DocumentEmbeddings, EmbeddingsBuilder}, - providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, - vector_store::{VectorStore, VectorStoreIndex}, + embeddings::EmbeddingsBuilder, providers::openai::Client, vector_store::VectorStoreIndex, }; use rig_mongodb::{MongoDbVectorStore, SearchParams}; +#[derive(Embed, Clone)] +struct FakeDefinition { + id: String, + #[embed] + definition: String, +} + +#[derive(Serialize, Debug, Deserialize)] +struct Document { + #[serde(rename = "_id")] + id: String, + definition: Embedding, +} + #[tokio::main] async fn main() -> Result<(), anyhow::Error> { // Initialize OpenAI client @@ -25,38 +40,62 @@ async fn main() -> Result<(), anyhow::Error> { MongoClient::with_options(options).expect("MongoDB client options should be valid"); // Initialize MongoDB vector store - let collection: 
Collection = mongodb_client + let collection: Collection = mongodb_client .database("knowledgebase") .collection("context"); - let mut vector_store = MongoDbVectorStore::new(collection); - // Select the embedding model and generate our embeddings let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); + let fake_definitions = vec![ + FakeDefinition { + id: "doc0".to_string(), + definition: "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets".to_string() + }, + FakeDefinition { + id: "doc1".to_string(), + definition: "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string() + }, + FakeDefinition { + id: "doc2".to_string(), + definition: "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.".to_string() + } + ]; + let embeddings = EmbeddingsBuilder::new(model.clone()) - .simple_document("doc0", "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets") - .simple_document("doc1", "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.") - .simple_document("doc2", "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.") + .documents(fake_definitions) .build() .await?; - // Add embeddings to vector store - match vector_store.add_documents(embeddings).await { + let mongo_documents = embeddings + .iter() + .map(|(FakeDefinition { id, .. 
}, embedding)| Document { + id: id.clone(), + definition: embedding.clone(), + }) + .collect::>(); + + match collection.insert_many(mongo_documents, None).await { Ok(_) => println!("Documents added successfully"), Err(e) => println!("Error adding documents: {:?}", e), - } + }; + + let vector_store = MongoDbVectorStore::new(collection); // Create a vector index on our vector store // IMPORTANT: Reuse the same model that was used to generate the embeddings - let index = vector_store.index(model, "vector_index", SearchParams::default()); + let index = vector_store.index( + model, + "definitions_vector_index", + SearchParams::new("definition.vec"), + ); // Query the index let results = index - .top_n::("What is a linglingdong?", 1) + .top_n::("What is a linglingdong?", 1) .await? .into_iter() - .map(|(score, id, doc)| (score, id, doc.document)) + .map(|(score, id, doc)| (score, id, doc.definition.document)) .collect::>(); println!("Results: {:?}", results); diff --git a/rig-mongodb/src/lib.rs b/rig-mongodb/src/lib.rs index 43869989..5ce33105 100644 --- a/rig-mongodb/src/lib.rs +++ b/rig-mongodb/src/lib.rs @@ -2,84 +2,23 @@ use futures::StreamExt; use mongodb::bson::{self, doc}; use rig::{ - embeddings::{DocumentEmbeddings, Embedding, EmbeddingModel}, - vector_store::{VectorStore, VectorStoreError, VectorStoreIndex}, + embeddings::{Embedding, EmbeddingModel}, + vector_store::{VectorStoreError, VectorStoreIndex}, }; use serde::Deserialize; /// A MongoDB vector store. 
-pub struct MongoDbVectorStore { - collection: mongodb::Collection, +pub struct MongoDbVectorStore { + collection: mongodb::Collection, } fn mongodb_to_rig_error(e: mongodb::error::Error) -> VectorStoreError { VectorStoreError::DatastoreError(Box::new(e)) } -impl VectorStore for MongoDbVectorStore { - type Q = mongodb::bson::Document; - - async fn add_documents( - &mut self, - documents: Vec, - ) -> Result<(), VectorStoreError> { - self.collection - .insert_many(documents, None) - .await - .map_err(mongodb_to_rig_error)?; - Ok(()) - } - - async fn get_document_embeddings( - &self, - id: &str, - ) -> Result, VectorStoreError> { - self.collection - .find_one(doc! { "_id": id }, None) - .await - .map_err(mongodb_to_rig_error) - } - - async fn get_document serde::Deserialize<'a>>( - &self, - id: &str, - ) -> Result, VectorStoreError> { - Ok(self - .collection - .clone_with_type::() - .aggregate( - [ - doc! {"$match": { "_id": id}}, - doc! {"$project": { "document": 1 }}, - doc! {"$replaceRoot": { "newRoot": "$document" }}, - ], - None, - ) - .await - .map_err(mongodb_to_rig_error)? - .with_type::() - .next() - .await - .transpose() - .map_err(mongodb_to_rig_error)? - .map(|doc| serde_json::from_str(&doc)) - .transpose()?) - } - - async fn get_document_by_query( - &self, - query: Self::Q, - ) -> Result, VectorStoreError> { - self.collection - .find_one(query, None) - .await - .map_err(mongodb_to_rig_error) - } -} - -impl MongoDbVectorStore { +impl MongoDbVectorStore { /// Create a new `MongoDbVectorStore` from a MongoDB collection. - pub fn new(collection: mongodb::Collection) -> Self { + pub fn new(collection: mongodb::Collection) -> Self { Self { collection } } @@ -92,20 +31,20 @@ impl MongoDbVectorStore { model: M, index_name: &str, search_params: SearchParams, - ) -> MongoDbVectorIndex { + ) -> MongoDbVectorIndex { MongoDbVectorIndex::new(self.collection.clone(), model, index_name, search_params) } } /// A vector index for a MongoDB collection. 
-pub struct MongoDbVectorIndex { - collection: mongodb::Collection, +pub struct MongoDbVectorIndex { + collection: mongodb::Collection, model: M, index_name: String, search_params: SearchParams, } -impl MongoDbVectorIndex { +impl MongoDbVectorIndex { /// Vector search stage of aggregation pipeline of mongoDB collection. /// To be used by implementations of top_n and top_n_ids methods on VectorStoreIndex trait for MongoDbVectorIndex. fn pipeline_search_stage(&self, prompt_embedding: &Embedding, n: usize) -> bson::Document { @@ -113,12 +52,13 @@ impl MongoDbVectorIndex { filter, exact, num_candidates, + path, } = &self.search_params; doc! { "$vectorSearch": { "index": &self.index_name, - "path": "embeddings.vec", + "path": path, "queryVector": &prompt_embedding.vec, "numCandidates": num_candidates.unwrap_or((n * 10) as u32), "limit": n as u32, @@ -139,9 +79,9 @@ impl MongoDbVectorIndex { } } -impl MongoDbVectorIndex { +impl MongoDbVectorIndex { pub fn new( - collection: mongodb::Collection, + collection: mongodb::Collection, model: M, index_name: &str, search_params: SearchParams, @@ -159,17 +99,19 @@ impl MongoDbVectorIndex { /// on each of the fields pub struct SearchParams { filter: mongodb::bson::Document, + path: String, exact: Option, num_candidates: Option, } impl SearchParams { /// Initializes a new `SearchParams` with default values. - pub fn new() -> Self { + pub fn new(path: &str) -> Self { Self { filter: doc! 
{}, exact: None, num_candidates: None, + path: path.to_string(), } } @@ -199,13 +141,9 @@ impl SearchParams { } } -impl Default for SearchParams { - fn default() -> Self { - Self::new() - } -} - -impl VectorStoreIndex for MongoDbVectorIndex { +impl VectorStoreIndex + for MongoDbVectorIndex +{ async fn top_n Deserialize<'a> + std::marker::Send>( &self, query: &str, From 3438fab667cfd787eb16e70d2266e2930e2c3fa0 Mon Sep 17 00:00:00 2001 From: Garance Date: Tue, 8 Oct 2024 17:47:07 -0400 Subject: [PATCH 08/47] refactor: refactor rag examples, delete document embedding struct --- rig-core/rig-core-derive/src/embeddable.rs | 9 ++++++- rig-core/src/embeddings.rs | 16 ------------ rig-lancedb/examples/fixtures/lib.rs | 18 +++++-------- .../examples/vector_search_local_ann.rs | 4 +-- rig-lancedb/examples/vector_search_s3_ann.rs | 4 +-- rig-mongodb/examples/vector_search_mongodb.rs | 26 ++++++++++--------- 6 files changed, 32 insertions(+), 45 deletions(-) diff --git a/rig-core/rig-core-derive/src/embeddable.rs b/rig-core/rig-core-derive/src/embeddable.rs index e2c34afe..ec52d952 100644 --- a/rig-core/rig-core-derive/src/embeddable.rs +++ b/rig-core/rig-core-derive/src/embeddable.rs @@ -7,6 +7,7 @@ use syn::{ const EMBED: &str = "embed"; const EMBED_WITH: &str = "embed_with"; +const VEC_TYPE: &str = "Vec"; pub fn expand_derive_embedding(input: &mut syn::DeriveInput) -> TokenStream { let name = &input.ident; @@ -92,7 +93,13 @@ fn add_struct_bounds(generics: &mut syn::Generics, field_type: &syn::Type) { fn embed_kind(field: &syn::Field) -> Result { match &field.ty { - syn::Type::Array(_) => parse_str("ManyEmbedding"), + syn::Type::Path(path) => { + if path.path.segments.first().unwrap().ident == VEC_TYPE { + parse_str("ManyEmbedding") + } else { + parse_str("SingleEmbedding") + } + }, _ => parse_str("SingleEmbedding"), } } diff --git a/rig-core/src/embeddings.rs b/rig-core/src/embeddings.rs index b044b406..458c6eb3 100644 --- a/rig-core/src/embeddings.rs +++ 
b/rig-core/src/embeddings.rs @@ -131,22 +131,6 @@ impl Embedding { } } -/// Struct that holds a document and its embeddings. -/// -/// The struct is designed to model any kind of documents that can be serialized to JSON -/// (including a simple string). -/// -/// Moreover, it can hold multiple embeddings for the same document, thus allowing a -/// large document to be retrieved from a query that matches multiple smaller and -/// distinct text documents. For example, if the document is a textbook, a summary of -/// each chapter could serve as the book's embeddings. -#[derive(Clone, Eq, PartialEq, Serialize, Deserialize)] -pub struct DocumentEmbeddings { - #[serde(rename = "_id")] - pub id: String, - pub document: serde_json::Value, - pub embeddings: Vec, -} pub trait EmbeddingKind {} pub struct SingleEmbedding; impl EmbeddingKind for SingleEmbedding {} diff --git a/rig-lancedb/examples/fixtures/lib.rs b/rig-lancedb/examples/fixtures/lib.rs index 322c0a6a..e8ebead9 100644 --- a/rig-lancedb/examples/fixtures/lib.rs +++ b/rig-lancedb/examples/fixtures/lib.rs @@ -6,28 +6,22 @@ use rig::embeddings::Embedding; use rig_derive::Embed; use serde::Deserialize; -#[derive(Embed, Clone)] +#[derive(Embed, Clone, Deserialize, Debug)] pub struct FakeDefinition { id: String, #[embed] definition: String, } -#[derive(Deserialize, Debug)] -pub struct VectorSearchResult { - pub id: String, - pub definition: String, -} - pub fn fake_definitions() -> Vec { vec![ FakeDefinition { id: "doc0".to_string(), - definition: "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets".to_string() + definition: "Definition of *flumbrel (noun)*: a small, seemingly insignificant item that you constantly lose or misplace, such as a pen, hair tie, or remote control.".to_string() }, FakeDefinition { id: "doc1".to_string(), - definition: "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the 
land.".to_string() + definition: "Definition of *zindle (verb)*: to pretend to be working on something important while actually doing something completely unrelated or unproductive.".to_string() }, FakeDefinition { id: "doc2".to_string(), @@ -39,7 +33,7 @@ pub fn fake_definitions() -> Vec { pub fn fake_definition(id: String) -> FakeDefinition { FakeDefinition { id, - definition: "Definition of *flumbuzzle (verb)*: to bewilder or confuse someone completely, often by using nonsensical or overly complex explanations or instructions.".to_string() + definition: "Definition of *flumbuzzle (noun)*: A sudden, inexplicable urge to rearrange or reorganize small objects, such as desk items or books, for no apparent reason.".to_string() } } @@ -59,7 +53,7 @@ pub fn schema(dims: usize) -> Schema { ])) } -// Convert DocumentEmbeddings objects to a RecordBatch. +// Convert FakeDefinition objects and their embedding to a RecordBatch. pub fn as_record_batch( records: Vec<(FakeDefinition, Embedding)>, dims: usize, @@ -74,7 +68,7 @@ pub fn as_record_batch( let definition = StringArray::from_iter_values( records .iter() - .map(|(_, Embedding { document, .. })| document) + .map(|(FakeDefinition { definition, .. 
}, _)| definition) .collect::>(), ); diff --git a/rig-lancedb/examples/vector_search_local_ann.rs b/rig-lancedb/examples/vector_search_local_ann.rs index 3ac925bb..af82ad49 100644 --- a/rig-lancedb/examples/vector_search_local_ann.rs +++ b/rig-lancedb/examples/vector_search_local_ann.rs @@ -1,7 +1,7 @@ use std::{env, sync::Arc}; use arrow_array::RecordBatchIterator; -use fixture::{as_record_batch, fake_definition, fake_definitions, schema, VectorSearchResult}; +use fixture::{as_record_batch, fake_definition, fake_definitions, schema, FakeDefinition}; use lancedb::index::vector::IvfPqIndexBuilder; use rig::vector_store::VectorStoreIndex; use rig::{ @@ -62,7 +62,7 @@ async fn main() -> Result<(), anyhow::Error> { // Query the index let results = vector_store - .top_n::("My boss says I zindle too much, what does that mean?", 1) + .top_n::("My boss says I zindle too much, what does that mean?", 1) .await?; println!("Results: {:?}", results); diff --git a/rig-lancedb/examples/vector_search_s3_ann.rs b/rig-lancedb/examples/vector_search_s3_ann.rs index 358a1755..17c4cd7f 100644 --- a/rig-lancedb/examples/vector_search_s3_ann.rs +++ b/rig-lancedb/examples/vector_search_s3_ann.rs @@ -1,7 +1,7 @@ use std::{env, sync::Arc}; use arrow_array::RecordBatchIterator; -use fixture::{as_record_batch, fake_definition, fake_definitions, schema, VectorSearchResult}; +use fixture::{as_record_batch, fake_definition, fake_definitions, schema, FakeDefinition}; use lancedb::{index::vector::IvfPqIndexBuilder, DistanceType}; use rig::{ embeddings::{EmbeddingModel, EmbeddingsBuilder}, @@ -74,7 +74,7 @@ async fn main() -> Result<(), anyhow::Error> { // Query the index let results = vector_store - .top_n::("I'm always looking for my phone, I always seem to forget it in the most counterintuitive places. What's the word for this feeling?", 1) + .top_n::("I'm always looking for my phone, I always seem to forget it in the most counterintuitive places. 
What's the word for this feeling?", 1) .await?; println!("Results: {:?}", results); diff --git a/rig-mongodb/examples/vector_search_mongodb.rs b/rig-mongodb/examples/vector_search_mongodb.rs index 4ddf1f99..18e0873a 100644 --- a/rig-mongodb/examples/vector_search_mongodb.rs +++ b/rig-mongodb/examples/vector_search_mongodb.rs @@ -1,5 +1,5 @@ use mongodb::{bson::doc, options::ClientOptions, Client as MongoClient, Collection}; -use rig::{embeddings::Embedding, providers::openai::TEXT_EMBEDDING_ADA_002}; +use rig::providers::openai::TEXT_EMBEDDING_ADA_002; use rig_derive::Embed; use serde::{Deserialize, Serialize}; use std::env; @@ -9,18 +9,22 @@ use rig::{ }; use rig_mongodb::{MongoDbVectorStore, SearchParams}; -#[derive(Embed, Clone)] +// Shape of data that needs to be RAG'ed. +// The definition field will be used to generate embeddings. +#[derive(Embed, Clone, Deserialize, Debug)] struct FakeDefinition { id: String, #[embed] definition: String, } -#[derive(Serialize, Debug, Deserialize)] +// Shape of the document to be stored in MongoDB. +#[derive(Serialize, Debug)] struct Document { #[serde(rename = "_id")] id: String, - definition: Embedding, + definition: String, + embedding: Vec, } #[tokio::main] @@ -69,9 +73,10 @@ async fn main() -> Result<(), anyhow::Error> { let mongo_documents = embeddings .iter() - .map(|(FakeDefinition { id, .. }, embedding)| Document { + .map(|(FakeDefinition { id, definition }, embedding)| Document { id: id.clone(), - definition: embedding.clone(), + definition: definition.clone(), + embedding: embedding.vec.clone(), }) .collect::>(); @@ -87,16 +92,13 @@ async fn main() -> Result<(), anyhow::Error> { let index = vector_store.index( model, "definitions_vector_index", - SearchParams::new("definition.vec"), + SearchParams::new("embedding"), ); // Query the index let results = index - .top_n::("What is a linglingdong?", 1) - .await? 
- .into_iter() - .map(|(score, id, doc)| (score, id, doc.definition.document)) - .collect::>(); + .top_n::("What is a linglingdong?", 1) + .await?; println!("Results: {:?}", results); From 597e6c3e7e82b91a2df179cb04c66f29d83b1696 Mon Sep 17 00:00:00 2001 From: Garance Date: Tue, 8 Oct 2024 21:45:43 -0400 Subject: [PATCH 09/47] feat: remove document embedding from in memory store --- rig-core/examples/rag.rs | 48 ++++- rig-core/examples/vector_search.rs | 59 +++++- rig-core/examples/vector_search_cohere.rs | 57 +++++- rig-core/rig-core-derive/src/embeddable.rs | 2 +- rig-core/src/vector_store/in_memory_store.rs | 179 ++++-------------- rig-core/src/vector_store/mod.rs | 18 +- rig-mongodb/examples/vector_search_mongodb.rs | 2 +- 7 files changed, 191 insertions(+), 174 deletions(-) diff --git a/rig-core/examples/rag.rs b/rig-core/examples/rag.rs index 3abd8ee9..55a99e63 100644 --- a/rig-core/examples/rag.rs +++ b/rig-core/examples/rag.rs @@ -6,6 +6,15 @@ use rig::{ providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, vector_store::{in_memory_store::InMemoryVectorStore, VectorStore}, }; +use rig_derive::Embed; +use serde::Serialize; + +#[derive(Embed, Clone, Serialize, Eq, PartialEq, Default)] +struct FakeDefinition { + id: String, + #[embed] + definitions: Vec, +} #[tokio::main] async fn main() -> Result<(), anyhow::Error> { @@ -18,14 +27,45 @@ async fn main() -> Result<(), anyhow::Error> { // Create vector store, compute embeddings and load them in the store let mut vector_store = InMemoryVectorStore::default(); + let fake_definitions = vec![ + FakeDefinition { + id: "doc0".to_string(), + definitions: vec![ + "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets".to_string(), + "Definition of a *flurbo*: A unit of currency used in a bizarre or fantastical world, often associated with eccentric societies or sci-fi settings.".to_string() + ] + }, + FakeDefinition { + id: "doc1".to_string(), + definitions: vec![ + "Definition of a 
*glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string(), + "Definition of a *glarb-glarb*: A mysterious, bubbling substance often found in swamps, alien planets, or under mysterious circumstances.".to_string() + ] + }, + FakeDefinition { + id: "doc2".to_string(), + definitions: vec![ + "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.".to_string() + ] + } + ]; + let embeddings = EmbeddingsBuilder::new(embedding_model.clone()) - .simple_document("doc0", "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets") - .simple_document("doc1", "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.") - .simple_document("doc2", "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.") + .documents(fake_definitions) .build() .await?; - vector_store.add_documents(embeddings).await?; + vector_store + .add_documents( + embeddings + .into_iter() + .enumerate() + .map(|(i, (fake_definition, embeddings))| { + (format!("doc{i}"), fake_definition, embeddings) + }) + .collect(), + ) + .await?; // Create vector store index let index = vector_store.index(embedding_model); diff --git a/rig-core/examples/vector_search.rs b/rig-core/examples/vector_search.rs index e0d68bef..095ca296 100644 --- a/rig-core/examples/vector_search.rs +++ b/rig-core/examples/vector_search.rs @@ -1,10 +1,19 @@ use std::env; use rig::{ - embeddings::{DocumentEmbeddings, EmbeddingsBuilder}, + embeddings::EmbeddingsBuilder, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, - vector_store::{in_memory_store::InMemoryVectorIndex, VectorStoreIndex}, + vector_store::{in_memory_store::InMemoryVectorStore, VectorStore, VectorStoreIndex}, }; +use rig_derive::Embed; +use serde::{Deserialize, Serialize}; + +#[derive(Embed, Clone, 
Serialize, Default, Eq, PartialEq, Deserialize, Debug)] +struct FakeDefinition { + id: String, + #[embed] + definitions: Vec, +} #[tokio::main] async fn main() -> Result<(), anyhow::Error> { @@ -14,20 +23,54 @@ async fn main() -> Result<(), anyhow::Error> { let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); + let fake_definitions = vec![ + FakeDefinition { + id: "doc0".to_string(), + definitions: vec![ + "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets".to_string(), + "Definition of a *flurbo*: A unit of currency used in a bizarre or fantastical world, often associated with eccentric societies or sci-fi settings.".to_string() + ] + }, + FakeDefinition { + id: "doc1".to_string(), + definitions: vec![ + "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string(), + "Definition of a *glarb-glarb*: A mysterious, bubbling substance often found in swamps, alien planets, or under mysterious circumstances.".to_string() + ] + }, + FakeDefinition { + id: "doc2".to_string(), + definitions: vec![ + "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.".to_string() + ] + } + ]; + let embeddings = EmbeddingsBuilder::new(model.clone()) - .simple_document("doc0", "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets") - .simple_document("doc1", "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.") - .simple_document("doc2", "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.") + .documents(fake_definitions) .build() .await?; - let index = InMemoryVectorIndex::from_embeddings(model, embeddings).await?; + let mut store = InMemoryVectorStore::default(); + store + .add_documents( + embeddings + .into_iter() + .enumerate() + 
.map(|(i, (fake_definition, embeddings))| { + (format!("doc{i}"), fake_definition, embeddings) + }) + .collect(), + ) + .await?; + + let index = store.index(model); let results = index - .top_n::("What is a linglingdong?", 1) + .top_n::("What is a linglingdong?", 1) .await? .into_iter() - .map(|(score, id, doc)| (score, id, doc.document)) + .map(|(score, id, doc)| (score, id, doc)) .collect::>(); println!("Results: {:?}", results); diff --git a/rig-core/examples/vector_search_cohere.rs b/rig-core/examples/vector_search_cohere.rs index a49ac231..6be094b4 100644 --- a/rig-core/examples/vector_search_cohere.rs +++ b/rig-core/examples/vector_search_cohere.rs @@ -1,10 +1,19 @@ use std::env; use rig::{ - embeddings::{DocumentEmbeddings, EmbeddingsBuilder}, + embeddings::EmbeddingsBuilder, providers::cohere::{Client, EMBED_ENGLISH_V3}, vector_store::{in_memory_store::InMemoryVectorStore, VectorStore, VectorStoreIndex}, }; +use rig_derive::Embed; +use serde::{Deserialize, Serialize}; + +#[derive(Embed, Clone, Serialize, Default, Eq, PartialEq, Deserialize, Debug)] +struct FakeDefinition { + id: String, + #[embed] + definitions: Vec, +} #[tokio::main] async fn main() -> Result<(), anyhow::Error> { @@ -15,24 +24,54 @@ async fn main() -> Result<(), anyhow::Error> { let document_model = cohere_client.embedding_model(EMBED_ENGLISH_V3, "search_document"); let search_model = cohere_client.embedding_model(EMBED_ENGLISH_V3, "search_query"); - let mut vector_store = InMemoryVectorStore::default(); + let fake_definitions = vec![ + FakeDefinition { + id: "doc0".to_string(), + definitions: vec![ + "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets".to_string(), + "Definition of a *flurbo*: A unit of currency used in a bizarre or fantastical world, often associated with eccentric societies or sci-fi settings.".to_string() + ] + }, + FakeDefinition { + id: "doc1".to_string(), + definitions: vec![ + "Definition of a *glarb-glarb*: A glarb-glarb is a ancient 
tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string(), + "Definition of a *glarb-glarb*: A mysterious, bubbling substance often found in swamps, alien planets, or under mysterious circumstances.".to_string() + ] + }, + FakeDefinition { + id: "doc2".to_string(), + definitions: vec![ + "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.".to_string() + ] + } + ]; let embeddings = EmbeddingsBuilder::new(document_model) - .simple_document("doc0", "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets") - .simple_document("doc1", "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.") - .simple_document("doc2", "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.") + .documents(fake_definitions) .build() .await?; - vector_store.add_documents(embeddings).await?; + let mut store = InMemoryVectorStore::default(); + store + .add_documents( + embeddings + .into_iter() + .enumerate() + .map(|(i, (fake_definition, embeddings))| { + (format!("doc{i}"), fake_definition, embeddings) + }) + .collect(), + ) + .await?; - let index = vector_store.index(search_model); + let index = store.index(search_model); let results = index - .top_n::("What is a linglingdong?", 1) + .top_n::("What is a linglingdong?", 1) .await? 
.into_iter() - .map(|(score, id, doc)| (score, id, doc.document)) + .map(|(score, id, doc)| (score, id, doc)) .collect::>(); println!("Results: {:?}", results); diff --git a/rig-core/rig-core-derive/src/embeddable.rs b/rig-core/rig-core-derive/src/embeddable.rs index ec52d952..8546f84a 100644 --- a/rig-core/rig-core-derive/src/embeddable.rs +++ b/rig-core/rig-core-derive/src/embeddable.rs @@ -99,7 +99,7 @@ fn embed_kind(field: &syn::Field) -> Result { } else { parse_str("SingleEmbedding") } - }, + } _ => parse_str("SingleEmbedding"), } } diff --git a/rig-core/src/vector_store/in_memory_store.rs b/rig-core/src/vector_store/in_memory_store.rs index a5db505f..5ed98671 100644 --- a/rig-core/src/vector_store/in_memory_store.rs +++ b/rig-core/src/vector_store/in_memory_store.rs @@ -8,27 +8,26 @@ use ordered_float::OrderedFloat; use serde::{Deserialize, Serialize}; use super::{VectorStore, VectorStoreError, VectorStoreIndex}; -use crate::embeddings::{DocumentEmbeddings, Embedding, EmbeddingModel, EmbeddingsBuilder}; +use crate::embeddings::{Embedding, EmbeddingModel}; /// InMemoryVectorStore is a simple in-memory vector store that stores embeddings /// in-memory using a HashMap. -#[derive(Clone, Default, Deserialize, Serialize)] -pub struct InMemoryVectorStore { +#[derive(Clone, Default)] +pub struct InMemoryVectorStore { /// The embeddings are stored in a HashMap with the document ID as the key. - embeddings: HashMap, + embeddings: HashMap)>, } -impl InMemoryVectorStore { +impl InMemoryVectorStore { /// Implement vector search on InMemoryVectorStore. /// To be used by implementations of top_n and top_n_ids methods on VectorStoreIndex trait for InMemoryVectorStore. 
- fn vector_search(&self, prompt_embedding: &Embedding, n: usize) -> EmbeddingRanking { + fn vector_search(&self, prompt_embedding: &Embedding, n: usize) -> EmbeddingRanking { // Sort documents by best embedding distance - let mut docs: EmbeddingRanking = BinaryHeap::new(); + let mut docs = BinaryHeap::new(); - for (id, doc_embeddings) in self.embeddings.iter() { + for (id, (doc, embeddings)) in self.embeddings.iter() { // Get the best context for the document given the prompt - if let Some((distance, embed_doc)) = doc_embeddings - .embeddings + if let Some((distance, embed_doc)) = embeddings .iter() .map(|embedding| { ( @@ -38,12 +37,7 @@ impl InMemoryVectorStore { }) .min_by(|a, b| a.0.cmp(&b.0)) { - docs.push(Reverse(RankingItem( - distance, - id, - doc_embeddings, - embed_doc, - ))); + docs.push(Reverse(RankingItem(distance, id, doc, embed_doc))); }; // If the heap size exceeds n, pop the least old element. @@ -67,73 +61,51 @@ impl InMemoryVectorStore { /// RankingItem(distance, document_id, document, embed_doc) #[derive(Eq, PartialEq)] -struct RankingItem<'a>( - OrderedFloat, - &'a String, - &'a DocumentEmbeddings, - &'a String, -); - -impl Ord for RankingItem<'_> { +struct RankingItem<'a, D: Serialize>(OrderedFloat, &'a String, &'a D, &'a String); + +impl Ord for RankingItem<'_, D> { fn cmp(&self, other: &Self) -> std::cmp::Ordering { self.0.cmp(&other.0) } } -impl PartialOrd for RankingItem<'_> { +impl PartialOrd for RankingItem<'_, D> { fn partial_cmp(&self, other: &Self) -> Option { Some(self.cmp(other)) } } -type EmbeddingRanking<'a> = BinaryHeap>>; +type EmbeddingRanking<'a, D> = BinaryHeap>>; -impl VectorStore for InMemoryVectorStore { +impl VectorStore for InMemoryVectorStore { type Q = (); async fn add_documents( &mut self, - documents: Vec, + documents: Vec<(String, D, Vec)>, ) -> Result<(), VectorStoreError> { - for doc in documents { - self.embeddings.insert(doc.id.clone(), doc); + for (id, doc, embeddings) in documents { + 
self.embeddings.insert(id, (doc, embeddings)); } Ok(()) } - async fn get_document Deserialize<'a>>( - &self, - id: &str, - ) -> Result, VectorStoreError> { - Ok(self - .embeddings - .get(id) - .map(|document| serde_json::from_value(document.document.clone())) - .transpose()?) - } - - async fn get_document_embeddings( - &self, - id: &str, - ) -> Result, VectorStoreError> { - Ok(self.embeddings.get(id).cloned()) + async fn get_document_embeddings(&self, id: &str) -> Result, VectorStoreError> { + Ok(self.embeddings.get(id).cloned().map(|(doc, _)| doc)) } - async fn get_document_by_query( - &self, - _query: Self::Q, - ) -> Result, VectorStoreError> { + async fn get_document_by_query(&self, _query: Self::Q) -> Result, VectorStoreError> { Ok(None) } } -impl InMemoryVectorStore { - pub fn index(self, model: M) -> InMemoryVectorIndex { +impl InMemoryVectorStore { + pub fn index(self, model: M) -> InMemoryVectorIndex { InMemoryVectorIndex::new(model, self) } - pub fn iter(&self) -> impl Iterator { + pub fn iter(&self) -> impl Iterator))> { self.embeddings.iter() } @@ -144,54 +116,19 @@ impl InMemoryVectorStore { pub fn is_empty(&self) -> bool { self.embeddings.is_empty() } - - /// Uitilty method to create an InMemoryVectorStore from a list of embeddings. - pub async fn from_embeddings( - embeddings: Vec, - ) -> Result { - let mut store = Self::default(); - store.add_documents(embeddings).await?; - Ok(store) - } - - /// Create an InMemoryVectorStore from a list of documents. - /// The documents are serialized to JSON and embedded using the provided embedding model. - /// The resulting embeddings are stored in an InMemoryVectorStore created by the method. 
- pub async fn from_documents( - embedding_model: M, - documents: &[(String, T)], - ) -> Result { - let embeddings = documents - .iter() - .fold( - EmbeddingsBuilder::new(embedding_model), - |builder, (id, doc)| { - builder.json_document( - id, - serde_json::to_value(doc).expect("Document should be serializable"), - vec![serde_json::to_string(doc).expect("Document should be serializable")], - ) - }, - ) - .build() - .await?; - - let store = Self::from_embeddings(embeddings).await?; - Ok(store) - } } -pub struct InMemoryVectorIndex { +pub struct InMemoryVectorIndex { model: M, - pub store: InMemoryVectorStore, + pub store: InMemoryVectorStore, } -impl InMemoryVectorIndex { - pub fn new(model: M, store: InMemoryVectorStore) -> Self { +impl InMemoryVectorIndex { + pub fn new(model: M, store: InMemoryVectorStore) -> Self { Self { model, store } } - pub fn iter(&self) -> impl Iterator { + pub fn iter(&self) -> impl Iterator))> { self.store.iter() } @@ -202,49 +139,11 @@ impl InMemoryVectorIndex { pub fn is_empty(&self) -> bool { self.store.is_empty() } - - /// Create an InMemoryVectorIndex from a list of documents. - /// The documents are serialized to JSON and embedded using the provided embedding model. - /// The resulting embeddings are stored in an InMemoryVectorStore created by the method. - /// The InMemoryVectorIndex is then created from the store and the provided query model. 
- pub async fn from_documents( - embedding_model: M, - query_model: M, - documents: &[(String, T)], - ) -> Result { - let mut store = InMemoryVectorStore::default(); - - let embeddings = documents - .iter() - .fold( - EmbeddingsBuilder::new(embedding_model), - |builder, (id, doc)| { - builder.json_document( - id, - serde_json::to_value(doc).expect("Document should be serializable"), - vec![serde_json::to_string(doc).expect("Document should be serializable")], - ) - }, - ) - .build() - .await?; - - store.add_documents(embeddings).await?; - Ok(store.index(query_model)) - } - - /// Utility method to create an InMemoryVectorIndex from a list of embeddings - /// and an embedding model. - pub async fn from_embeddings( - query_model: M, - embeddings: Vec, - ) -> Result { - let store = InMemoryVectorStore::from_embeddings(embeddings).await?; - Ok(store.index(query_model)) - } } -impl VectorStoreIndex for InMemoryVectorIndex { +impl VectorStoreIndex + for InMemoryVectorIndex +{ async fn top_n Deserialize<'a>>( &self, query: &str, @@ -256,12 +155,14 @@ impl VectorStoreIndex for InMemoryVectorI // Return n best docs.into_iter() - .map(|Reverse(RankingItem(distance, _, doc, _))| { - let doc_value = serde_json::to_value(doc).map_err(VectorStoreError::JsonError)?; + .map(|Reverse(RankingItem(distance, id, doc, _))| { Ok(( distance.0, - doc.id.clone(), - serde_json::from_value(doc_value).map_err(VectorStoreError::JsonError)?, + id.clone(), + serde_json::from_str( + &serde_json::to_string(doc).map_err(VectorStoreError::JsonError)?, + ) + .map_err(VectorStoreError::JsonError)?, )) }) .collect::, _>>() @@ -278,7 +179,7 @@ impl VectorStoreIndex for InMemoryVectorI // Return n best docs.into_iter() - .map(|Reverse(RankingItem(distance, _, doc, _))| Ok((distance.0, doc.id.clone()))) + .map(|Reverse(RankingItem(distance, id, _, _))| Ok((distance.0, id.clone()))) .collect::, _>>() } } diff --git a/rig-core/src/vector_store/mod.rs b/rig-core/src/vector_store/mod.rs index 
b07d348a..8f89d5f1 100644 --- a/rig-core/src/vector_store/mod.rs +++ b/rig-core/src/vector_store/mod.rs @@ -1,8 +1,8 @@ use futures::future::BoxFuture; -use serde::Deserialize; +use serde::{Deserialize, Serialize}; use serde_json::Value; -use crate::embeddings::{DocumentEmbeddings, EmbeddingError}; +use crate::embeddings::{Embedding, EmbeddingError}; pub mod in_memory_store; @@ -20,33 +20,27 @@ pub enum VectorStoreError { } /// Trait for vector stores -pub trait VectorStore: Send + Sync { +pub trait VectorStore: Send + Sync { /// Query type for the vector store type Q; /// Add a list of documents to the vector store fn add_documents( &mut self, - documents: Vec, + documents: Vec<(String, D, Vec)>, ) -> impl std::future::Future> + Send; /// Get the embeddings of a document by its id fn get_document_embeddings( &self, id: &str, - ) -> impl std::future::Future, VectorStoreError>> + Send; - - /// Get the document by its id and deserialize it into the given type - fn get_document Deserialize<'a>>( - &self, - id: &str, - ) -> impl std::future::Future, VectorStoreError>> + Send; + ) -> impl std::future::Future, VectorStoreError>> + Send; /// Get the document by a query and deserialize it into the given type fn get_document_by_query( &self, query: Self::Q, - ) -> impl std::future::Future, VectorStoreError>> + Send; + ) -> impl std::future::Future, VectorStoreError>> + Send; } /// Trait for vector store indexes diff --git a/rig-mongodb/examples/vector_search_mongodb.rs b/rig-mongodb/examples/vector_search_mongodb.rs index 18e0873a..ece3e093 100644 --- a/rig-mongodb/examples/vector_search_mongodb.rs +++ b/rig-mongodb/examples/vector_search_mongodb.rs @@ -18,7 +18,7 @@ struct FakeDefinition { definition: String, } -// Shape of the document to be stored in MongoDB. +// Shape of the document to be stored in MongoDB, with embeddings. 
#[derive(Serialize, Debug)] struct Document { #[serde(rename = "_id")] From 6c7ab8d6f344a773add6aa6156091e9033e8725c Mon Sep 17 00:00:00 2001 From: Garance Date: Wed, 9 Oct 2024 13:39:13 -0400 Subject: [PATCH 10/47] revert: revert vector store to main --- rig-core/examples/rag.rs | 48 +---- rig-core/examples/vector_search.rs | 59 +----- rig-core/examples/vector_search_cohere.rs | 57 +----- rig-core/src/vector_store/in_memory_store.rs | 179 ++++++++++++++----- rig-core/src/vector_store/mod.rs | 18 +- 5 files changed, 172 insertions(+), 189 deletions(-) diff --git a/rig-core/examples/rag.rs b/rig-core/examples/rag.rs index 55a99e63..3abd8ee9 100644 --- a/rig-core/examples/rag.rs +++ b/rig-core/examples/rag.rs @@ -6,15 +6,6 @@ use rig::{ providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, vector_store::{in_memory_store::InMemoryVectorStore, VectorStore}, }; -use rig_derive::Embed; -use serde::Serialize; - -#[derive(Embed, Clone, Serialize, Eq, PartialEq, Default)] -struct FakeDefinition { - id: String, - #[embed] - definitions: Vec, -} #[tokio::main] async fn main() -> Result<(), anyhow::Error> { @@ -27,45 +18,14 @@ async fn main() -> Result<(), anyhow::Error> { // Create vector store, compute embeddings and load them in the store let mut vector_store = InMemoryVectorStore::default(); - let fake_definitions = vec![ - FakeDefinition { - id: "doc0".to_string(), - definitions: vec![ - "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets".to_string(), - "Definition of a *flurbo*: A unit of currency used in a bizarre or fantastical world, often associated with eccentric societies or sci-fi settings.".to_string() - ] - }, - FakeDefinition { - id: "doc1".to_string(), - definitions: vec![ - "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string(), - "Definition of a *glarb-glarb*: A mysterious, bubbling substance often found in swamps, alien planets, or 
under mysterious circumstances.".to_string() - ] - }, - FakeDefinition { - id: "doc2".to_string(), - definitions: vec![ - "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.".to_string() - ] - } - ]; - let embeddings = EmbeddingsBuilder::new(embedding_model.clone()) - .documents(fake_definitions) + .simple_document("doc0", "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets") + .simple_document("doc1", "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.") + .simple_document("doc2", "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.") .build() .await?; - vector_store - .add_documents( - embeddings - .into_iter() - .enumerate() - .map(|(i, (fake_definition, embeddings))| { - (format!("doc{i}"), fake_definition, embeddings) - }) - .collect(), - ) - .await?; + vector_store.add_documents(embeddings).await?; // Create vector store index let index = vector_store.index(embedding_model); diff --git a/rig-core/examples/vector_search.rs b/rig-core/examples/vector_search.rs index 095ca296..e0d68bef 100644 --- a/rig-core/examples/vector_search.rs +++ b/rig-core/examples/vector_search.rs @@ -1,19 +1,10 @@ use std::env; use rig::{ - embeddings::EmbeddingsBuilder, + embeddings::{DocumentEmbeddings, EmbeddingsBuilder}, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, - vector_store::{in_memory_store::InMemoryVectorStore, VectorStore, VectorStoreIndex}, + vector_store::{in_memory_store::InMemoryVectorIndex, VectorStoreIndex}, }; -use rig_derive::Embed; -use serde::{Deserialize, Serialize}; - -#[derive(Embed, Clone, Serialize, Default, Eq, PartialEq, Deserialize, Debug)] -struct FakeDefinition { - id: String, - #[embed] - definitions: Vec, -} #[tokio::main] async fn main() -> Result<(), anyhow::Error> { @@ -23,54 +14,20 @@ async fn main() -> Result<(), 
anyhow::Error> { let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); - let fake_definitions = vec![ - FakeDefinition { - id: "doc0".to_string(), - definitions: vec![ - "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets".to_string(), - "Definition of a *flurbo*: A unit of currency used in a bizarre or fantastical world, often associated with eccentric societies or sci-fi settings.".to_string() - ] - }, - FakeDefinition { - id: "doc1".to_string(), - definitions: vec![ - "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string(), - "Definition of a *glarb-glarb*: A mysterious, bubbling substance often found in swamps, alien planets, or under mysterious circumstances.".to_string() - ] - }, - FakeDefinition { - id: "doc2".to_string(), - definitions: vec![ - "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.".to_string() - ] - } - ]; - let embeddings = EmbeddingsBuilder::new(model.clone()) - .documents(fake_definitions) + .simple_document("doc0", "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets") + .simple_document("doc1", "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.") + .simple_document("doc2", "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.") .build() .await?; - let mut store = InMemoryVectorStore::default(); - store - .add_documents( - embeddings - .into_iter() - .enumerate() - .map(|(i, (fake_definition, embeddings))| { - (format!("doc{i}"), fake_definition, embeddings) - }) - .collect(), - ) - .await?; - - let index = store.index(model); + let index = InMemoryVectorIndex::from_embeddings(model, embeddings).await?; let results = index - .top_n::("What is a linglingdong?", 1) + .top_n::("What is a 
linglingdong?", 1) .await? .into_iter() - .map(|(score, id, doc)| (score, id, doc)) + .map(|(score, id, doc)| (score, id, doc.document)) .collect::>(); println!("Results: {:?}", results); diff --git a/rig-core/examples/vector_search_cohere.rs b/rig-core/examples/vector_search_cohere.rs index 6be094b4..a49ac231 100644 --- a/rig-core/examples/vector_search_cohere.rs +++ b/rig-core/examples/vector_search_cohere.rs @@ -1,19 +1,10 @@ use std::env; use rig::{ - embeddings::EmbeddingsBuilder, + embeddings::{DocumentEmbeddings, EmbeddingsBuilder}, providers::cohere::{Client, EMBED_ENGLISH_V3}, vector_store::{in_memory_store::InMemoryVectorStore, VectorStore, VectorStoreIndex}, }; -use rig_derive::Embed; -use serde::{Deserialize, Serialize}; - -#[derive(Embed, Clone, Serialize, Default, Eq, PartialEq, Deserialize, Debug)] -struct FakeDefinition { - id: String, - #[embed] - definitions: Vec, -} #[tokio::main] async fn main() -> Result<(), anyhow::Error> { @@ -24,54 +15,24 @@ async fn main() -> Result<(), anyhow::Error> { let document_model = cohere_client.embedding_model(EMBED_ENGLISH_V3, "search_document"); let search_model = cohere_client.embedding_model(EMBED_ENGLISH_V3, "search_query"); - let fake_definitions = vec![ - FakeDefinition { - id: "doc0".to_string(), - definitions: vec![ - "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets".to_string(), - "Definition of a *flurbo*: A unit of currency used in a bizarre or fantastical world, often associated with eccentric societies or sci-fi settings.".to_string() - ] - }, - FakeDefinition { - id: "doc1".to_string(), - definitions: vec![ - "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string(), - "Definition of a *glarb-glarb*: A mysterious, bubbling substance often found in swamps, alien planets, or under mysterious circumstances.".to_string() - ] - }, - FakeDefinition { - id: "doc2".to_string(), - 
definitions: vec![ - "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.".to_string() - ] - } - ]; + let mut vector_store = InMemoryVectorStore::default(); let embeddings = EmbeddingsBuilder::new(document_model) - .documents(fake_definitions) + .simple_document("doc0", "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets") + .simple_document("doc1", "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.") + .simple_document("doc2", "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.") .build() .await?; - let mut store = InMemoryVectorStore::default(); - store - .add_documents( - embeddings - .into_iter() - .enumerate() - .map(|(i, (fake_definition, embeddings))| { - (format!("doc{i}"), fake_definition, embeddings) - }) - .collect(), - ) - .await?; + vector_store.add_documents(embeddings).await?; - let index = store.index(search_model); + let index = vector_store.index(search_model); let results = index - .top_n::("What is a linglingdong?", 1) + .top_n::("What is a linglingdong?", 1) .await? 
.into_iter() - .map(|(score, id, doc)| (score, id, doc)) + .map(|(score, id, doc)| (score, id, doc.document)) .collect::>(); println!("Results: {:?}", results); diff --git a/rig-core/src/vector_store/in_memory_store.rs b/rig-core/src/vector_store/in_memory_store.rs index 5ed98671..a5db505f 100644 --- a/rig-core/src/vector_store/in_memory_store.rs +++ b/rig-core/src/vector_store/in_memory_store.rs @@ -8,26 +8,27 @@ use ordered_float::OrderedFloat; use serde::{Deserialize, Serialize}; use super::{VectorStore, VectorStoreError, VectorStoreIndex}; -use crate::embeddings::{Embedding, EmbeddingModel}; +use crate::embeddings::{DocumentEmbeddings, Embedding, EmbeddingModel, EmbeddingsBuilder}; /// InMemoryVectorStore is a simple in-memory vector store that stores embeddings /// in-memory using a HashMap. -#[derive(Clone, Default)] -pub struct InMemoryVectorStore { +#[derive(Clone, Default, Deserialize, Serialize)] +pub struct InMemoryVectorStore { /// The embeddings are stored in a HashMap with the document ID as the key. - embeddings: HashMap)>, + embeddings: HashMap, } -impl InMemoryVectorStore { +impl InMemoryVectorStore { /// Implement vector search on InMemoryVectorStore. /// To be used by implementations of top_n and top_n_ids methods on VectorStoreIndex trait for InMemoryVectorStore. 
- fn vector_search(&self, prompt_embedding: &Embedding, n: usize) -> EmbeddingRanking { + fn vector_search(&self, prompt_embedding: &Embedding, n: usize) -> EmbeddingRanking { // Sort documents by best embedding distance - let mut docs = BinaryHeap::new(); + let mut docs: EmbeddingRanking = BinaryHeap::new(); - for (id, (doc, embeddings)) in self.embeddings.iter() { + for (id, doc_embeddings) in self.embeddings.iter() { // Get the best context for the document given the prompt - if let Some((distance, embed_doc)) = embeddings + if let Some((distance, embed_doc)) = doc_embeddings + .embeddings .iter() .map(|embedding| { ( @@ -37,7 +38,12 @@ impl InMemoryVectorStore { }) .min_by(|a, b| a.0.cmp(&b.0)) { - docs.push(Reverse(RankingItem(distance, id, doc, embed_doc))); + docs.push(Reverse(RankingItem( + distance, + id, + doc_embeddings, + embed_doc, + ))); }; // If the heap size exceeds n, pop the least old element. @@ -61,51 +67,73 @@ impl InMemoryVectorStore { /// RankingItem(distance, document_id, document, embed_doc) #[derive(Eq, PartialEq)] -struct RankingItem<'a, D: Serialize>(OrderedFloat, &'a String, &'a D, &'a String); - -impl Ord for RankingItem<'_, D> { +struct RankingItem<'a>( + OrderedFloat, + &'a String, + &'a DocumentEmbeddings, + &'a String, +); + +impl Ord for RankingItem<'_> { fn cmp(&self, other: &Self) -> std::cmp::Ordering { self.0.cmp(&other.0) } } -impl PartialOrd for RankingItem<'_, D> { +impl PartialOrd for RankingItem<'_> { fn partial_cmp(&self, other: &Self) -> Option { Some(self.cmp(other)) } } -type EmbeddingRanking<'a, D> = BinaryHeap>>; +type EmbeddingRanking<'a> = BinaryHeap>>; -impl VectorStore for InMemoryVectorStore { +impl VectorStore for InMemoryVectorStore { type Q = (); async fn add_documents( &mut self, - documents: Vec<(String, D, Vec)>, + documents: Vec, ) -> Result<(), VectorStoreError> { - for (id, doc, embeddings) in documents { - self.embeddings.insert(id, (doc, embeddings)); + for doc in documents { + 
self.embeddings.insert(doc.id.clone(), doc); } Ok(()) } - async fn get_document_embeddings(&self, id: &str) -> Result, VectorStoreError> { - Ok(self.embeddings.get(id).cloned().map(|(doc, _)| doc)) + async fn get_document Deserialize<'a>>( + &self, + id: &str, + ) -> Result, VectorStoreError> { + Ok(self + .embeddings + .get(id) + .map(|document| serde_json::from_value(document.document.clone())) + .transpose()?) + } + + async fn get_document_embeddings( + &self, + id: &str, + ) -> Result, VectorStoreError> { + Ok(self.embeddings.get(id).cloned()) } - async fn get_document_by_query(&self, _query: Self::Q) -> Result, VectorStoreError> { + async fn get_document_by_query( + &self, + _query: Self::Q, + ) -> Result, VectorStoreError> { Ok(None) } } -impl InMemoryVectorStore { - pub fn index(self, model: M) -> InMemoryVectorIndex { +impl InMemoryVectorStore { + pub fn index(self, model: M) -> InMemoryVectorIndex { InMemoryVectorIndex::new(model, self) } - pub fn iter(&self) -> impl Iterator))> { + pub fn iter(&self) -> impl Iterator { self.embeddings.iter() } @@ -116,19 +144,54 @@ impl InMemoryVectorStore { pub fn is_empty(&self) -> bool { self.embeddings.is_empty() } + + /// Uitilty method to create an InMemoryVectorStore from a list of embeddings. + pub async fn from_embeddings( + embeddings: Vec, + ) -> Result { + let mut store = Self::default(); + store.add_documents(embeddings).await?; + Ok(store) + } + + /// Create an InMemoryVectorStore from a list of documents. + /// The documents are serialized to JSON and embedded using the provided embedding model. + /// The resulting embeddings are stored in an InMemoryVectorStore created by the method. 
+ pub async fn from_documents( + embedding_model: M, + documents: &[(String, T)], + ) -> Result { + let embeddings = documents + .iter() + .fold( + EmbeddingsBuilder::new(embedding_model), + |builder, (id, doc)| { + builder.json_document( + id, + serde_json::to_value(doc).expect("Document should be serializable"), + vec![serde_json::to_string(doc).expect("Document should be serializable")], + ) + }, + ) + .build() + .await?; + + let store = Self::from_embeddings(embeddings).await?; + Ok(store) + } } -pub struct InMemoryVectorIndex { +pub struct InMemoryVectorIndex { model: M, - pub store: InMemoryVectorStore, + pub store: InMemoryVectorStore, } -impl InMemoryVectorIndex { - pub fn new(model: M, store: InMemoryVectorStore) -> Self { +impl InMemoryVectorIndex { + pub fn new(model: M, store: InMemoryVectorStore) -> Self { Self { model, store } } - pub fn iter(&self) -> impl Iterator))> { + pub fn iter(&self) -> impl Iterator { self.store.iter() } @@ -139,11 +202,49 @@ impl InMemoryVectorIndex { pub fn is_empty(&self) -> bool { self.store.is_empty() } + + /// Create an InMemoryVectorIndex from a list of documents. + /// The documents are serialized to JSON and embedded using the provided embedding model. + /// The resulting embeddings are stored in an InMemoryVectorStore created by the method. + /// The InMemoryVectorIndex is then created from the store and the provided query model. 
+ pub async fn from_documents( + embedding_model: M, + query_model: M, + documents: &[(String, T)], + ) -> Result { + let mut store = InMemoryVectorStore::default(); + + let embeddings = documents + .iter() + .fold( + EmbeddingsBuilder::new(embedding_model), + |builder, (id, doc)| { + builder.json_document( + id, + serde_json::to_value(doc).expect("Document should be serializable"), + vec![serde_json::to_string(doc).expect("Document should be serializable")], + ) + }, + ) + .build() + .await?; + + store.add_documents(embeddings).await?; + Ok(store.index(query_model)) + } + + /// Utility method to create an InMemoryVectorIndex from a list of embeddings + /// and an embedding model. + pub async fn from_embeddings( + query_model: M, + embeddings: Vec, + ) -> Result { + let store = InMemoryVectorStore::from_embeddings(embeddings).await?; + Ok(store.index(query_model)) + } } -impl VectorStoreIndex - for InMemoryVectorIndex -{ +impl VectorStoreIndex for InMemoryVectorIndex { async fn top_n Deserialize<'a>>( &self, query: &str, @@ -155,14 +256,12 @@ impl Vec // Return n best docs.into_iter() - .map(|Reverse(RankingItem(distance, id, doc, _))| { + .map(|Reverse(RankingItem(distance, _, doc, _))| { + let doc_value = serde_json::to_value(doc).map_err(VectorStoreError::JsonError)?; Ok(( distance.0, - id.clone(), - serde_json::from_str( - &serde_json::to_string(doc).map_err(VectorStoreError::JsonError)?, - ) - .map_err(VectorStoreError::JsonError)?, + doc.id.clone(), + serde_json::from_value(doc_value).map_err(VectorStoreError::JsonError)?, )) }) .collect::, _>>() @@ -179,7 +278,7 @@ impl Vec // Return n best docs.into_iter() - .map(|Reverse(RankingItem(distance, id, _, _))| Ok((distance.0, id.clone()))) + .map(|Reverse(RankingItem(distance, _, doc, _))| Ok((distance.0, doc.id.clone()))) .collect::, _>>() } } diff --git a/rig-core/src/vector_store/mod.rs b/rig-core/src/vector_store/mod.rs index 8f89d5f1..b07d348a 100644 --- a/rig-core/src/vector_store/mod.rs +++ 
b/rig-core/src/vector_store/mod.rs @@ -1,8 +1,8 @@ use futures::future::BoxFuture; -use serde::{Deserialize, Serialize}; +use serde::Deserialize; use serde_json::Value; -use crate::embeddings::{Embedding, EmbeddingError}; +use crate::embeddings::{DocumentEmbeddings, EmbeddingError}; pub mod in_memory_store; @@ -20,27 +20,33 @@ pub enum VectorStoreError { } /// Trait for vector stores -pub trait VectorStore: Send + Sync { +pub trait VectorStore: Send + Sync { /// Query type for the vector store type Q; /// Add a list of documents to the vector store fn add_documents( &mut self, - documents: Vec<(String, D, Vec)>, + documents: Vec, ) -> impl std::future::Future> + Send; /// Get the embeddings of a document by its id fn get_document_embeddings( &self, id: &str, - ) -> impl std::future::Future, VectorStoreError>> + Send; + ) -> impl std::future::Future, VectorStoreError>> + Send; + + /// Get the document by its id and deserialize it into the given type + fn get_document Deserialize<'a>>( + &self, + id: &str, + ) -> impl std::future::Future, VectorStoreError>> + Send; /// Get the document by a query and deserialize it into the given type fn get_document_by_query( &self, query: Self::Q, - ) -> impl std::future::Future, VectorStoreError>> + Send; + ) -> impl std::future::Future, VectorStoreError>> + Send; } /// Trait for vector store indexes From bb712e3fdabfffbf4a4e6c35fb2eec951ff5c83f Mon Sep 17 00:00:00 2001 From: Garance Date: Wed, 9 Oct 2024 15:49:00 -0400 Subject: [PATCH 11/47] docs: update emebddings builder docstrings --- rig-core/src/embeddings.rs | 166 ++++++++++++++++++++++++++++++------- 1 file changed, 136 insertions(+), 30 deletions(-) diff --git a/rig-core/src/embeddings.rs b/rig-core/src/embeddings.rs index 458c6eb3..d68ae565 100644 --- a/rig-core/src/embeddings.rs +++ b/rig-core/src/embeddings.rs @@ -8,31 +8,66 @@ //! struct, which allows users to build collections of document embeddings using different embedding //! models and document sources. //! -//! 
The module also defines the [Embedding] struct, which represents a single document embedding, -//! and the [DocumentEmbeddings] struct, which represents a document along with its associated -//! embeddings. These structs are used to store and manipulate collections of document embeddings. +//! The module also defines the [Embedding] struct, which represents a single document embedding. +//! +//! The module also defines the [Embeddable] trait, which represents types that can be embedded. +//! Only types that implement the Embeddable trait can be used with the EmbeddingsBuilder. //! //! Finally, the module defines the [EmbeddingError] enum, which represents various errors that //! can occur during embedding generation or processing. //! //! # Example //! ```rust -//! use rig::providers::openai::{Client, self}; -//! use rig::embeddings::{EmbeddingModel, EmbeddingsBuilder}; +//! use std::env; +//! +//! use rig::{ +//! embeddings::EmbeddingsBuilder, +//! providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, +//! }; +//! use rig_derive::Embed; //! -//! // Initialize the OpenAI client -//! let openai = Client::new("your-openai-api-key"); +//! #[derive(Embed)] +//! struct FakeDefinition { +//! id: String, +//! word: String, +//! #[embed] +//! definitions: Vec, +//! } +//! // Create OpenAI client +//! let openai_api_key = env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set"); +//! let openai_client = Client::new(&openai_api_key); //! -//! // Create an instance of the `text-embedding-ada-002` model -//! let embedding_model = openai.embedding_model(openai::TEXT_EMBEDDING_ADA_002); +//! let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); //! -//! // Create an embeddings builder and add documents -//! let embeddings = EmbeddingsBuilder::new(embedding_model) -//! .simple_document("doc1", "This is the first document.") -//! .simple_document("doc2", "This is the second document.") -//! .build() -//! .await -//! .expect("Failed to build embeddings."); +//! 
let embeddings = EmbeddingsBuilder::new(model.clone()) +//! .documents(vec![ +//! FakeDefinition { +//! id: "doc0".to_string(), +//! word: "flurbo".to_string(), +//! definitions: vec![ +//! "A green alien that lives on cold planets.".to_string(), +//! "A fictional digital currency that originated in the animated series Rick and Morty.".to_string() +//! ] +//! }, +//! FakeDefinition { +//! id: "doc1".to_string(), +//! word: "glarb-glarb".to_string(), +//! definitions: vec![ +//! "An ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string(), +//! "A fictional creature found in the distant, swampy marshlands of the planet Glibbo in the Andromeda galaxy.".to_string() +//! ] +//! }, +//! FakeDefinition { +//! id: "doc2".to_string(), +//! word: "linglingdong".to_string(), +//! definitions: vec![ +//! "A term used by inhabitants of the sombrero galaxy to describe humans.".to_string(), +//! "A rare, mystical instrument crafted by the ancient monks of the Nebulon Mountain Ranges on the planet Quarm.".to_string() +//! ] +//! }, +//! ]) +//! .build() +//! .await?; //! //! // Use the generated embeddings //! // ... @@ -102,7 +137,7 @@ pub trait EmbeddingModel: Clone + Sync + Send { /// Struct that holds a single document and its embedding. #[derive(Clone, Default, Deserialize, Serialize, Debug)] pub struct Embedding { - /// The document that was embedded + /// The document that was embedded. Used for debugging. pub document: String, /// The embedding vector pub vec: Vec, @@ -159,6 +194,7 @@ impl, K: EmbeddingKind> EmbeddingsBui } } + /// Add a document that implements `Embeddable` to the builder. pub fn document(mut self, document: D) -> Self { let embed_targets = document.embeddable(); @@ -166,6 +202,7 @@ impl, K: EmbeddingKind> EmbeddingsBui self } + /// Add many documents that implement `Embeddable` to the builder. 
pub fn documents(mut self, documents: Vec) -> EmbeddingsBuilder { documents.into_iter().for_each(|doc| { let embed_targets = doc.embeddable(); @@ -180,7 +217,11 @@ impl, K: EmbeddingKind> EmbeddingsBui impl EmbeddingsBuilder { + /// Generate embeddings for all documents in the builder. + /// The method only applies when documents in the builder each contain multiple embedding targets. + /// Returns a vector of tuples, where the first element is the document and the second element is the vector of embeddings. pub async fn build(&self) -> Result)>, EmbeddingError> { + // Use this for reference later to merge a document back with its embeddings. let documents_map = self .documents .clone() @@ -189,22 +230,19 @@ impl .map(|(id, (document, _))| (id, document)) .collect::>(); - let embeddings = stream::iter(self.documents.clone().into_iter().enumerate()) - // Flatten the documents + let embeddings = stream::iter(self.documents.iter().enumerate()) + // Merge the embedding targets of each document into a single list of embedding targets. .flat_map(|(i, (_, embed_targets))| { - stream::iter( - embed_targets - .into_iter() - .map(move |target| (i, target.clone())), - ) + stream::iter(embed_targets.iter().map(move |target| (i, target.clone()))) }) - // Chunk them into N (the emebdding API limit per request) + // Chunk them into N (the emebdding API limit per request). .chunks(M::MAX_DOCUMENTS) - // Generate the embeddings + // Generate the embeddings for a chunk at a time. 
.map(|docs| async { - let (documents, embed_targets): (Vec<_>, Vec<_>) = docs.into_iter().unzip(); + let (document_indices, embed_targets): (Vec<_>, Vec<_>) = docs.into_iter().unzip(); + Ok::<_, EmbeddingError>( - documents + document_indices .into_iter() .zip(self.model.embed_documents(embed_targets).await?.into_iter()) .collect::>(), @@ -213,8 +251,6 @@ impl .boxed() // Parallelize the embeddings generation over 10 concurrent requests .buffer_unordered(max(1, 1024 / M::MAX_DOCUMENTS)) - // .try_collect::>() - // .await; .try_fold( HashMap::new(), |mut acc: HashMap<_, Vec<_>>, embeddings| async move { @@ -242,6 +278,9 @@ impl impl EmbeddingsBuilder { + /// Generate embeddings for all documents in the builder. + /// The method only applies when documents in the builder each contain a single embedding target. + /// Returns a vector of tuples, where the first element is the document and the second element is the embedding. pub async fn build(&self) -> Result, EmbeddingError> { let embeddings = stream::iter( self.documents @@ -274,6 +313,9 @@ impl } } +////////////////////////////////////////////////////// +/// Implementations of Embeddable for common types /// +////////////////////////////////////////////////////// impl Embeddable for String { type Kind = SingleEmbedding; @@ -282,6 +324,22 @@ impl Embeddable for String { } } +impl Embeddable for i8 { + type Kind = SingleEmbedding; + + fn embeddable(&self) -> Vec { + vec![self.to_string()] + } +} + +impl Embeddable for i16 { + type Kind = SingleEmbedding; + + fn embeddable(&self) -> Vec { + vec![self.to_string()] + } +} + impl Embeddable for i32 { type Kind = SingleEmbedding; @@ -290,6 +348,54 @@ impl Embeddable for i32 { } } +impl Embeddable for i64 { + type Kind = SingleEmbedding; + + fn embeddable(&self) -> Vec { + vec![self.to_string()] + } +} + +impl Embeddable for i128 { + type Kind = SingleEmbedding; + + fn embeddable(&self) -> Vec { + vec![self.to_string()] + } +} + +impl Embeddable for f32 { + type Kind = 
SingleEmbedding; + + fn embeddable(&self) -> Vec { + vec![self.to_string()] + } +} + +impl Embeddable for f64 { + type Kind = SingleEmbedding; + + fn embeddable(&self) -> Vec { + vec![self.to_string()] + } +} + +impl Embeddable for bool { + type Kind = SingleEmbedding; + + fn embeddable(&self) -> Vec { + vec![self.to_string()] + } +} + +impl Embeddable for char { + type Kind = SingleEmbedding; + + fn embeddable(&self) -> Vec { + vec![self.to_string()] + } +} + impl Embeddable for Vec { type Kind = ManyEmbedding; From efa2b65427fddcab60734fa023335c38022047bf Mon Sep 17 00:00:00 2001 From: Garance Date: Thu, 10 Oct 2024 14:41:57 -0400 Subject: [PATCH 12/47] refactor: derive macro --- Cargo.lock | 1 + rig-core/examples/rag.rs | 2 +- rig-core/examples/vector_search.rs | 2 +- rig-core/examples/vector_search_cohere.rs | 7 +- rig-core/rig-core-derive/Cargo.toml | 1 + rig-core/rig-core-derive/src/custom.rs | 98 ++++++++++ rig-core/rig-core-derive/src/embeddable.rs | 183 ++++++------------ rig-core/rig-core-derive/src/lib.rs | 5 + rig-core/src/embeddings.rs | 30 ++- rig-lancedb/examples/fixtures/lib.rs | 2 +- rig-mongodb/examples/vector_search_mongodb.rs | 2 +- 11 files changed, 197 insertions(+), 136 deletions(-) create mode 100644 rig-core/rig-core-derive/src/custom.rs diff --git a/Cargo.lock b/Cargo.lock index 311b2862..c47f71b8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4012,6 +4012,7 @@ name = "rig-derive" version = "0.1.0" dependencies = [ "indoc", + "proc-macro2", "quote", "syn 2.0.79", ] diff --git a/rig-core/examples/rag.rs b/rig-core/examples/rag.rs index d24e293f..82e9d4fd 100644 --- a/rig-core/examples/rag.rs +++ b/rig-core/examples/rag.rs @@ -2,7 +2,7 @@ use std::{env, vec}; use rig::{ completion::Prompt, - embeddings::EmbeddingsBuilder, + embeddings::{Embeddable, EmbeddingsBuilder, ManyEmbedding}, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, vector_store::in_memory_store::InMemoryVectorStore, }; diff --git a/rig-core/examples/vector_search.rs 
b/rig-core/examples/vector_search.rs index 8e23d274..523719ac 100644 --- a/rig-core/examples/vector_search.rs +++ b/rig-core/examples/vector_search.rs @@ -1,7 +1,7 @@ use std::env; use rig::{ - embeddings::EmbeddingsBuilder, + embeddings::{Embeddable, EmbeddingsBuilder, ManyEmbedding}, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, vector_store::{in_memory_store::InMemoryVectorStore, VectorStoreIndex}, }; diff --git a/rig-core/examples/vector_search_cohere.rs b/rig-core/examples/vector_search_cohere.rs index 94ca4e4b..f9f84175 100644 --- a/rig-core/examples/vector_search_cohere.rs +++ b/rig-core/examples/vector_search_cohere.rs @@ -1,7 +1,7 @@ use std::env; use rig::{ - embeddings::EmbeddingsBuilder, + embeddings::{Embeddable, EmbeddingsBuilder, ManyEmbedding}, providers::cohere::{Client, EMBED_ENGLISH_V3}, vector_store::{in_memory_store::InMemoryVectorStore, VectorStoreIndex}, }; @@ -69,7 +69,10 @@ async fn main() -> Result<(), anyhow::Error> { .index(search_model); let results = index - .top_n::("Which instrument is found in the Nebulon Mountain Ranges?", 1) + .top_n::( + "Which instrument is found in the Nebulon Mountain Ranges?", + 1, + ) .await? 
.into_iter() .map(|(score, id, doc)| (score, id, doc.word)) diff --git a/rig-core/rig-core-derive/Cargo.toml b/rig-core/rig-core-derive/Cargo.toml index 008f492b..1ab5e5ac 100644 --- a/rig-core/rig-core-derive/Cargo.toml +++ b/rig-core/rig-core-derive/Cargo.toml @@ -5,6 +5,7 @@ edition = "2021" [dependencies] indoc = "2.0.5" +proc-macro2 = { version = "1.0.87", features = ["proc-macro"] } quote = "1.0.37" syn = { version = "2.0.79", features = ["full"]} diff --git a/rig-core/rig-core-derive/src/custom.rs b/rig-core/rig-core-derive/src/custom.rs new file mode 100644 index 00000000..eb6478bd --- /dev/null +++ b/rig-core/rig-core-derive/src/custom.rs @@ -0,0 +1,98 @@ +use quote::ToTokens; +use syn::{meta::ParseNestedMeta, spanned::Spanned, ExprPath}; + +use crate::EMBED; + +const EMBED_WITH: &str = "embed_with"; + +pub(crate) trait CustomAttributeParser { + // Determine if field is tagged with an #[embed(embed_with = "...")] attribute. + fn is_custom(&self) -> syn::Result; + + // Get the "..." part of the #[embed(embed_with = "...")] attribute. + // Ex: If attribute is tagged with #[embed(embed_with = "my_embed")], returns "my_embed". + fn expand_tag(&self) -> syn::Result; +} + +impl CustomAttributeParser for syn::Attribute { + fn is_custom(&self) -> syn::Result { + // Check that the attribute is a list. + match &self.meta { + syn::Meta::List(meta) => { + if meta.tokens.is_empty() { + return Ok(false); + } + } + _ => return Ok(false), + }; + + // Check the first attribute tag (the first "embed") + if !self.path().is_ident(EMBED) { + return Ok(false); + } + + self.parse_nested_meta(|meta| { + // Parse the meta attribute as an expression. Need this to compile. 
+ meta.value()?.parse::()?; + + if meta.path.is_ident(EMBED_WITH) { + Ok(()) + } else { + let path = meta.path.to_token_stream().to_string().replace(' ', ""); + Err(syn::Error::new_spanned( + meta.path, + format_args!("unknown embedding field attribute `{}`", path), + )) + } + })?; + + Ok(true) + } + + fn expand_tag(&self) -> syn::Result { + let mut custom_func_path = None; + + self.parse_nested_meta(|meta| match function_path(&meta) { + Ok(path) => { + custom_func_path = Some(path); + Ok(()) + } + Err(e) => Err(e), + })?; + + Ok(custom_func_path.unwrap()) + } +} + +fn function_path(meta: &ParseNestedMeta<'_>) -> syn::Result { + // #[embed(embed_with = "...")] + let expr = meta.value()?.parse::().unwrap(); + let mut value = &expr; + while let syn::Expr::Group(e) = value { + value = &e.expr; + } + let string = if let syn::Expr::Lit(syn::ExprLit { + lit: syn::Lit::Str(lit_str), + .. + }) = value + { + let suffix = lit_str.suffix(); + if !suffix.is_empty() { + return Err(syn::Error::new( + lit_str.span(), + format!("unexpected suffix `{}` on string literal", suffix), + )); + } + lit_str.clone() + } else { + return Err(syn::Error::new( + value.span(), + format!( + "expected {} attribute to be a string: `{} = \"...\"`", + EMBED_WITH, EMBED_WITH + ), + )); + }; + + string.parse() +} diff --git a/rig-core/rig-core-derive/src/embeddable.rs b/rig-core/rig-core-derive/src/embeddable.rs index 8546f84a..cff035d1 100644 --- a/rig-core/rig-core-derive/src/embeddable.rs +++ b/rig-core/rig-core-derive/src/embeddable.rs @@ -1,55 +1,49 @@ -use proc_macro::TokenStream; +use proc_macro2::TokenStream; use quote::quote; -use syn::{ - meta::ParseNestedMeta, parse_quote, parse_str, punctuated::Punctuated, spanned::Spanned, - Attribute, DataStruct, ExprPath, Meta, Token, -}; +use syn::{parse_quote, parse_str, Attribute, DataStruct, Meta}; -const EMBED: &str = "embed"; -const EMBED_WITH: &str = "embed_with"; +use crate::{custom::CustomAttributeParser, EMBED}; const VEC_TYPE: &str = "Vec"; 
+const MANY_EMBEDDING: &str = "ManyEmbedding"; +const SINGLE_EMBEDDING: &str = "SingleEmbedding"; -pub fn expand_derive_embedding(input: &mut syn::DeriveInput) -> TokenStream { +pub fn expand_derive_embedding(input: &mut syn::DeriveInput) -> syn::Result { let name = &input.ident; - let (func_calls, embed_kind) = - match &input.data { - syn::Data::Struct(data_struct) => { - // Handles fields tagged with #[embed] - let mut function_calls = data_struct - .basic_embed_fields() - .map(|field| { - add_struct_bounds(&mut input.generics, &field.ty); + let (embed_targets, embed_kind) = match &input.data { + syn::Data::Struct(data_struct) => { + // Handles fields tagged with #[embed] + let mut inner_embed_targets = data_struct + .basic_embed_fields() + .map(|field| { + add_struct_bounds(&mut input.generics, &field.ty); - let field_name = field.ident; + let field_name = field.ident; - quote! { - self.#field_name.embeddable() - } - }) - .collect::>(); + quote! { + self.#field_name.embeddable() + } + }) + .collect::>(); - // Handles fields tagged with #[embed(embed_with = "...")] - function_calls.extend(data_struct.custom_embed_fields().unwrap().map( - |(field, _)| { - let field_name = field.ident; + // Handles fields tagged with #[embed(embed_with = "...")] + inner_embed_targets.extend(data_struct.custom_embed_fields()?.map(|(field, _)| { + let field_name = field.ident; - quote! { - embeddable(&self.#field_name) - } - }, - )); + quote! { + embeddable(&self.#field_name) + } + })); - (function_calls, data_struct.embed_kind().unwrap()) - } - _ => panic!("Embeddable can only be derived for structs"), - }; + (inner_embed_targets, data_struct.embed_kind()?) + } + _ => panic!("Embeddable trait can only be derived for structs"), + }; // Import the paths to the custom functions. let custom_func_paths = match &input.data { syn::Data::Struct(data_struct) => data_struct - .custom_embed_fields() - .unwrap() + .custom_embed_fields()? .map(|(_, custom_func_path)| { quote! 
{ use #custom_func_path::embeddable; @@ -59,11 +53,17 @@ pub fn expand_derive_embedding(input: &mut syn::DeriveInput) -> TokenStream { _ => vec![], }; + // If there are no fields tagged with #[embed] or #[embed(embed_with = "...")], return an empty TokenStream. + // ie. do not implement Embeddable trait for the struct. + if embed_targets.is_empty() { + return Ok(TokenStream::new()); + } + let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl(); let gen = quote! { - use rig::embeddings::Embeddable; - use rig::embeddings::#embed_kind; + // Note: we do NOT import the Embeddable trait here because if there are multiple structs in the same file + // that derive Embed, there will be import conflicts. #(#custom_func_paths);* @@ -72,14 +72,14 @@ pub fn expand_derive_embedding(input: &mut syn::DeriveInput) -> TokenStream { fn embeddable(&self) -> Vec { vec![ - #(#func_calls),* + #(#embed_targets),* ].into_iter().flatten().collect() } } }; eprintln!("Generated code:\n{}", gen); - gen.into() + Ok(gen) } // Adds bounds to where clause that force all fields tagged with #[embed] to implement the Embeddable trait. @@ -91,36 +91,35 @@ fn add_struct_bounds(generics: &mut syn::Generics, field_type: &syn::Type) { }); } -fn embed_kind(field: &syn::Field) -> Result { +fn embed_kind(field: &syn::Field) -> syn::Result { match &field.ty { syn::Type::Path(path) => { if path.path.segments.first().unwrap().ident == VEC_TYPE { - parse_str("ManyEmbedding") + parse_str(MANY_EMBEDDING) } else { - parse_str("SingleEmbedding") + parse_str(SINGLE_EMBEDDING) } } - _ => parse_str("SingleEmbedding"), + _ => parse_str(SINGLE_EMBEDDING), } } -trait AttributeParser { +trait StructParser { /// Finds and returns fields with simple #[embed] attribute tags only. fn basic_embed_fields(&self) -> impl Iterator; /// Finds and returns fields with #[embed(embed_with = "...")] attribute tags only. /// Also returns the attribute in question. 
- fn custom_embed_fields( - &self, - ) -> Result, syn::Error>; + fn custom_embed_fields(&self) + -> syn::Result>; /// If the total number of fields tagged with #[embed] or #[embed(embed_with = "...")] is 1, /// returns the kind of embedding that field should be. /// If the total number of fields tagged with #[embed] or #[embed(embed_with = "...")] is greater than 1, /// return ManyEmbedding. - fn embed_kind(&self) -> Result { + fn embed_kind(&self) -> syn::Result { let fields = self .basic_embed_fields() - .chain(self.custom_embed_fields().unwrap().map(|(f, _)| f)) + .chain(self.custom_embed_fields()?.map(|(f, _)| f)) .collect::>(); if fields.len() == 1 { @@ -131,7 +130,7 @@ trait AttributeParser { } } -impl AttributeParser for DataStruct { +impl StructParser for DataStruct { fn basic_embed_fields(&self) -> impl Iterator { self.fields.clone().into_iter().filter(|field| { field @@ -150,42 +149,7 @@ impl AttributeParser for DataStruct { fn custom_embed_fields( &self, - ) -> Result, syn::Error> { - // Determine if field is tagged with #[embed(embed_with = "...")] attribute. - fn is_custom_embed(attribute: &syn::Attribute) -> Result { - let is_custom_embed = match attribute.meta { - Meta::List(_) => attribute - .parse_args_with(Punctuated::::parse_terminated)? - .into_iter() - .any(|meta| meta.path().is_ident(EMBED_WITH)), - _ => false, - }; - - Ok(attribute.path().is_ident(EMBED) && is_custom_embed) - } - - // Get the "..." part of the #[embed(embed_with = "...")] attribute. - // Ex: If attribute is tagged with #[embed(embed_with = "my_embed")], returns "my_embed". 
- fn expand_tag(attribute: &syn::Attribute) -> Result { - let mut custom_func_path = None; - - attribute.parse_nested_meta(|meta| { - custom_func_path = Some(meta.function_path()?); - Ok(()) - })?; - - match custom_func_path { - Some(path) => Ok(path), - None => Err(syn::Error::new( - attribute.span(), - format!( - "expected {} attribute to have format: `#[embed(embed_with = \"...\")]`", - EMBED_WITH - ), - )), - } - } - + ) -> syn::Result> { Ok(self .fields .clone() @@ -196,8 +160,8 @@ impl AttributeParser for DataStruct { .clone() .into_iter() .map(|attribute| { - if is_custom_embed(&attribute)? { - Ok::<_, syn::Error>(Some((field.clone(), expand_tag(&attribute)?))) + if attribute.is_custom()? { + Ok::<_, syn::Error>(Some((field.clone(), attribute.expand_tag()?))) } else { Ok(None) } @@ -210,42 +174,3 @@ impl AttributeParser for DataStruct { .flatten()) } } - -trait CustomFunction { - fn function_path(&self) -> Result; -} - -impl CustomFunction for ParseNestedMeta<'_> { - fn function_path(&self) -> Result { - // #[embed(embed_with = "...")] - let expr = self.value().unwrap().parse::().unwrap(); - let mut value = &expr; - while let syn::Expr::Group(e) = value { - value = &e.expr; - } - let string = if let syn::Expr::Lit(syn::ExprLit { - lit: syn::Lit::Str(lit_str), - .. 
- }) = value - { - let suffix = lit_str.suffix(); - if !suffix.is_empty() { - return Err(syn::Error::new( - lit_str.span(), - format!("unexpected suffix `{}` on string literal", suffix), - )); - } - lit_str.clone() - } else { - return Err(syn::Error::new( - value.span(), - format!( - "expected {} attribute to be a string: `{} = \"...\"`", - EMBED_WITH, EMBED_WITH - ), - )); - }; - - string.parse() - } -} diff --git a/rig-core/rig-core-derive/src/lib.rs b/rig-core/rig-core-derive/src/lib.rs index 19f6a845..ad07592e 100644 --- a/rig-core/rig-core-derive/src/lib.rs +++ b/rig-core/rig-core-derive/src/lib.rs @@ -2,8 +2,11 @@ extern crate proc_macro; use proc_macro::TokenStream; use syn::{parse_macro_input, DeriveInput}; +mod custom; mod embeddable; +pub(crate) const EMBED: &str = "embed"; + // https://doc.rust-lang.org/book/ch19-06-macros.html#how-to-write-a-custom-derive-macro // https://doc.rust-lang.org/reference/procedural-macros.html @@ -12,4 +15,6 @@ pub fn derive_embedding_trait(item: TokenStream) -> TokenStream { let mut input = parse_macro_input!(item as DeriveInput); embeddable::expand_derive_embedding(&mut input) + .unwrap_or_else(syn::Error::into_compile_error) + .into() } diff --git a/rig-core/src/embeddings.rs b/rig-core/src/embeddings.rs index d68ae565..8e943982 100644 --- a/rig-core/src/embeddings.rs +++ b/rig-core/src/embeddings.rs @@ -240,7 +240,7 @@ impl // Generate the embeddings for a chunk at a time. 
.map(|docs| async { let (document_indices, embed_targets): (Vec<_>, Vec<_>) = docs.into_iter().unzip(); - + Ok::<_, EmbeddingError>( document_indices .into_iter() @@ -403,3 +403,31 @@ impl Embeddable for Vec { self.iter().flat_map(|i| i.embeddable()).collect() } } + +#[cfg(test)] +mod tests { + use super::{Embeddable, SingleEmbedding}; + + use rig_derive::Embed; + use serde::Serialize; + + // #[derive(Serialize)] + // struct FakeDefinition2 { + // id: String, + // #[serde(test = "")] + // definition: String, + // } + + #[derive(Embed)] + struct FakeDefinition { + id: String, + #[embed(something = "a")] + definition: String, + } + + #[test] + fn test_missing_embed_fields() {} + + #[test] + fn test_empty_custom_function() {} +} diff --git a/rig-lancedb/examples/fixtures/lib.rs b/rig-lancedb/examples/fixtures/lib.rs index e8ebead9..747e23a8 100644 --- a/rig-lancedb/examples/fixtures/lib.rs +++ b/rig-lancedb/examples/fixtures/lib.rs @@ -2,7 +2,7 @@ use std::sync::Arc; use arrow_array::{types::Float64Type, ArrayRef, FixedSizeListArray, RecordBatch, StringArray}; use lancedb::arrow::arrow_schema::{DataType, Field, Fields, Schema}; -use rig::embeddings::Embedding; +use rig::embeddings::{Embedding, Embeddable, SingleEmbedding}; use rig_derive::Embed; use serde::Deserialize; diff --git a/rig-mongodb/examples/vector_search_mongodb.rs b/rig-mongodb/examples/vector_search_mongodb.rs index d4593b47..f740a29d 100644 --- a/rig-mongodb/examples/vector_search_mongodb.rs +++ b/rig-mongodb/examples/vector_search_mongodb.rs @@ -5,7 +5,7 @@ use serde::{Deserialize, Serialize}; use std::env; use rig::{ - embeddings::EmbeddingsBuilder, providers::openai::Client, vector_store::VectorStoreIndex, + embeddings::{EmbeddingsBuilder, Embeddable, SingleEmbedding}, providers::openai::Client, vector_store::VectorStoreIndex, }; use rig_mongodb::{MongoDbVectorStore, SearchParams}; From 83251640c1ef23ee97c5e7336fd690589075d284 Mon Sep 17 00:00:00 2001 From: Garance Date: Thu, 10 Oct 2024 20:08:45 
-0400 Subject: [PATCH 13/47] feat: add error type on embeddable trait --- rig-core/rig-core-derive/src/custom.rs | 68 ++++++------- rig-core/rig-core-derive/src/embeddable.rs | 54 ++++++---- rig-core/src/embeddings.rs | 99 +++++++++++-------- .../examples/vector_search_local_ann.rs | 2 +- 4 files changed, 127 insertions(+), 96 deletions(-) diff --git a/rig-core/rig-core-derive/src/custom.rs b/rig-core/rig-core-derive/src/custom.rs index eb6478bd..77f321f6 100644 --- a/rig-core/rig-core-derive/src/custom.rs +++ b/rig-core/rig-core-derive/src/custom.rs @@ -1,5 +1,5 @@ use quote::ToTokens; -use syn::{meta::ParseNestedMeta, spanned::Spanned, ExprPath}; +use syn::{meta::ParseNestedMeta, ExprPath}; use crate::EMBED; @@ -50,6 +50,39 @@ impl CustomAttributeParser for syn::Attribute { } fn expand_tag(&self) -> syn::Result { + fn function_path(meta: &ParseNestedMeta<'_>) -> syn::Result { + // #[embed(embed_with = "...")] + let expr = meta.value()?.parse::().unwrap(); + let mut value = &expr; + while let syn::Expr::Group(e) = value { + value = &e.expr; + } + let string = if let syn::Expr::Lit(syn::ExprLit { + lit: syn::Lit::Str(lit_str), + .. 
+ }) = value + { + let suffix = lit_str.suffix(); + if !suffix.is_empty() { + return Err(syn::Error::new_spanned( + lit_str, + format!("unexpected suffix `{}` on string literal", suffix), + )); + } + lit_str.clone() + } else { + return Err(syn::Error::new_spanned( + value, + format!( + "expected {} attribute to be a string: `{} = \"...\"`", + EMBED_WITH, EMBED_WITH + ), + )); + }; + + string.parse() + } + let mut custom_func_path = None; self.parse_nested_meta(|meta| match function_path(&meta) { @@ -63,36 +96,3 @@ impl CustomAttributeParser for syn::Attribute { Ok(custom_func_path.unwrap()) } } - -fn function_path(meta: &ParseNestedMeta<'_>) -> syn::Result { - // #[embed(embed_with = "...")] - let expr = meta.value()?.parse::().unwrap(); - let mut value = &expr; - while let syn::Expr::Group(e) = value { - value = &e.expr; - } - let string = if let syn::Expr::Lit(syn::ExprLit { - lit: syn::Lit::Str(lit_str), - .. - }) = value - { - let suffix = lit_str.suffix(); - if !suffix.is_empty() { - return Err(syn::Error::new( - lit_str.span(), - format!("unexpected suffix `{}` on string literal", suffix), - )); - } - lit_str.clone() - } else { - return Err(syn::Error::new( - value.span(), - format!( - "expected {} attribute to be a string: `{} = \"...\"`", - EMBED_WITH, EMBED_WITH - ), - )); - }; - - string.parse() -} diff --git a/rig-core/rig-core-derive/src/embeddable.rs b/rig-core/rig-core-derive/src/embeddable.rs index cff035d1..f5c6e781 100644 --- a/rig-core/rig-core-derive/src/embeddable.rs +++ b/rig-core/rig-core-derive/src/embeddable.rs @@ -10,10 +10,10 @@ const SINGLE_EMBEDDING: &str = "SingleEmbedding"; pub fn expand_derive_embedding(input: &mut syn::DeriveInput) -> syn::Result { let name = &input.ident; - let (embed_targets, embed_kind) = match &input.data { + let (embed_targets, custom_embed_targets, embed_kind) = match &input.data { syn::Data::Struct(data_struct) => { // Handles fields tagged with #[embed] - let mut inner_embed_targets = data_struct + let 
embed_targets = data_struct .basic_embed_fields() .map(|field| { add_struct_bounds(&mut input.generics, &field.ty); @@ -21,25 +21,38 @@ pub fn expand_derive_embedding(input: &mut syn::DeriveInput) -> syn::Result>(); // Handles fields tagged with #[embed(embed_with = "...")] - inner_embed_targets.extend(data_struct.custom_embed_fields()?.map(|(field, _)| { - let field_name = field.ident; + let custom_embed_targets = data_struct + .custom_embed_fields()? + .map(|(field, _)| { + let field_name = field.ident; - quote! { - embeddable(&self.#field_name) - } - })); + quote! { + self.#field_name + } + }) + .collect::>(); - (inner_embed_targets, data_struct.embed_kind()?) + ( + embed_targets, + custom_embed_targets, + data_struct.embed_kind()?, + ) } _ => panic!("Embeddable trait can only be derived for structs"), }; + // If there are no fields tagged with #[embed] or #[embed(embed_with = "...")], return an empty TokenStream. + // ie. do not implement Embeddable trait for the struct. + if embed_targets.is_empty() && custom_embed_targets.is_empty() { + return Ok(TokenStream::new()); + } + // Import the paths to the custom functions. let custom_func_paths = match &input.data { syn::Data::Struct(data_struct) => data_struct @@ -53,12 +66,6 @@ pub fn expand_derive_embedding(input: &mut syn::DeriveInput) -> syn::Result vec![], }; - // If there are no fields tagged with #[embed] or #[embed(embed_with = "...")], return an empty TokenStream. - // ie. do not implement Embeddable trait for the struct. - if embed_targets.is_empty() { - return Ok(TokenStream::new()); - } - let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl(); let gen = quote! { @@ -69,11 +76,18 @@ pub fn expand_derive_embedding(input: &mut syn::DeriveInput) -> syn::Result Result, Self::Error> { + vec![#(#embed_targets.clone()),*].embeddable() + + // let custom_embed_targets = vec![#( embeddable( #embed_targets ); ),*] + // .iter() + // .collect::, _>>()? 
+ // .into_iter() + // .flatten(); - fn embeddable(&self) -> Vec { - vec![ - #(#embed_targets),* - ].into_iter().flatten().collect() + // Ok(embed_targets.chain(custom_embed_targets).collect()) } } }; diff --git a/rig-core/src/embeddings.rs b/rig-core/src/embeddings.rs index 8e943982..7e25a69a 100644 --- a/rig-core/src/embeddings.rs +++ b/rig-core/src/embeddings.rs @@ -172,9 +172,17 @@ impl EmbeddingKind for SingleEmbedding {} pub struct ManyEmbedding; impl EmbeddingKind for ManyEmbedding {} +#[derive(Debug, thiserror::Error)] +pub enum EmbeddingGenerationError { + #[error("SerdeError: {0}")] + SerdeError(#[from] serde_json::Error), +} + pub trait Embeddable { type Kind: EmbeddingKind; - fn embeddable(&self) -> Vec; + type Error: std::error::Error; + + fn embeddable(&self) -> Result, Self::Error>; } /// Builder for creating a collection of embeddings @@ -195,22 +203,22 @@ impl, K: EmbeddingKind> EmbeddingsBui } /// Add a document that implements `Embeddable` to the builder. - pub fn document(mut self, document: D) -> Self { - let embed_targets = document.embeddable(); + pub fn document(mut self, document: D) -> Result { + let embed_targets = document.embeddable()?; self.documents.push((document, embed_targets)); - self + Ok(self) } /// Add many documents that implement `Embeddable` to the builder. 
- pub fn documents(mut self, documents: Vec) -> EmbeddingsBuilder { - documents.into_iter().for_each(|doc| { - let embed_targets = doc.embeddable(); + pub fn documents(mut self, documents: Vec) -> Result { + for doc in documents.into_iter() { + let embed_targets = doc.embeddable()?; self.documents.push((doc, embed_targets)); - }); + } - self + Ok(self) } } @@ -318,110 +326,119 @@ impl ////////////////////////////////////////////////////// impl Embeddable for String { type Kind = SingleEmbedding; + type Error = EmbeddingGenerationError; - fn embeddable(&self) -> Vec { - vec![self.clone()] + fn embeddable(&self) -> Result, Self::Error> { + Ok(vec![self.clone()]) } } impl Embeddable for i8 { type Kind = SingleEmbedding; + type Error = EmbeddingGenerationError; - fn embeddable(&self) -> Vec { - vec![self.to_string()] + fn embeddable(&self) -> Result, Self::Error> { + Ok(vec![self.to_string()]) } } impl Embeddable for i16 { type Kind = SingleEmbedding; + type Error = EmbeddingGenerationError; - fn embeddable(&self) -> Vec { - vec![self.to_string()] + fn embeddable(&self) -> Result, Self::Error> { + Ok(vec![self.to_string()]) } } impl Embeddable for i32 { type Kind = SingleEmbedding; + type Error = EmbeddingGenerationError; - fn embeddable(&self) -> Vec { - vec![self.to_string()] + fn embeddable(&self) -> Result, Self::Error> { + Ok(vec![self.to_string()]) } } impl Embeddable for i64 { type Kind = SingleEmbedding; + type Error = EmbeddingGenerationError; - fn embeddable(&self) -> Vec { - vec![self.to_string()] + fn embeddable(&self) -> Result, Self::Error> { + Ok(vec![self.to_string()]) } } impl Embeddable for i128 { type Kind = SingleEmbedding; + type Error = EmbeddingGenerationError; - fn embeddable(&self) -> Vec { - vec![self.to_string()] + fn embeddable(&self) -> Result, Self::Error> { + Ok(vec![self.to_string()]) } } impl Embeddable for f32 { type Kind = SingleEmbedding; + type Error = EmbeddingGenerationError; - fn embeddable(&self) -> Vec { - 
vec![self.to_string()] + fn embeddable(&self) -> Result, Self::Error> { + Ok(vec![self.to_string()]) } } impl Embeddable for f64 { type Kind = SingleEmbedding; + type Error = EmbeddingGenerationError; - fn embeddable(&self) -> Vec { - vec![self.to_string()] + fn embeddable(&self) -> Result, Self::Error> { + Ok(vec![self.to_string()]) } } impl Embeddable for bool { type Kind = SingleEmbedding; + type Error = EmbeddingGenerationError; - fn embeddable(&self) -> Vec { - vec![self.to_string()] + fn embeddable(&self) -> Result, Self::Error> { + Ok(vec![self.to_string()]) } } impl Embeddable for char { type Kind = SingleEmbedding; + type Error = EmbeddingGenerationError; - fn embeddable(&self) -> Vec { - vec![self.to_string()] + fn embeddable(&self) -> Result, Self::Error> { + Ok(vec![self.to_string()]) } } impl Embeddable for Vec { type Kind = ManyEmbedding; + type Error = T::Error; - fn embeddable(&self) -> Vec { - self.iter().flat_map(|i| i.embeddable()).collect() + fn embeddable(&self) -> Result, Self::Error> { + Ok(self + .iter() + .map(|i| i.embeddable()) + .collect::, _>>()? + .into_iter() + .flatten() + .collect()) } } #[cfg(test)] mod tests { - use super::{Embeddable, SingleEmbedding}; + use super::{Embeddable, SingleEmbedding, EmbeddingGenerationError}; use rig_derive::Embed; - use serde::Serialize; - - // #[derive(Serialize)] - // struct FakeDefinition2 { - // id: String, - // #[serde(test = "")] - // definition: String, - // } #[derive(Embed)] struct FakeDefinition { id: String, - #[embed(something = "a")] + #[embed] definition: String, } diff --git a/rig-lancedb/examples/vector_search_local_ann.rs b/rig-lancedb/examples/vector_search_local_ann.rs index af82ad49..7a9c8cc1 100644 --- a/rig-lancedb/examples/vector_search_local_ann.rs +++ b/rig-lancedb/examples/vector_search_local_ann.rs @@ -27,7 +27,7 @@ async fn main() -> Result<(), anyhow::Error> { // Generate embeddings for the test data. 
let embeddings = EmbeddingsBuilder::new(model.clone()) - .documents(fake_definitions()) + .documents(fake_definitions())? // Note: need at least 256 rows in order to create an index so copy the definition 256 times for testing purposes. .documents( (0..256) From de022c49c5030f5dc9f0ef8cb861f8ee7c9a577c Mon Sep 17 00:00:00 2001 From: Garance Date: Fri, 11 Oct 2024 11:09:41 -0400 Subject: [PATCH 14/47] refactor: move embeddings to its own module and seperate embeddable --- rig-core/examples/rag.rs | 8 +- rig-core/examples/vector_search.rs | 8 +- rig-core/examples/vector_search_cohere.rs | 8 +- rig-core/rig-core-derive/src/lib.rs | 2 +- .../embeddable.rs} | 130 +++--------------- rig-core/src/embeddings/embedding.rs | 99 +++++++++++++ rig-core/src/embeddings/mod.rs | 7 + rig-core/src/lib.rs | 4 + rig-core/src/providers/cohere.rs | 12 +- rig-core/src/providers/openai.rs | 18 ++- rig-core/src/vector_store/in_memory_store.rs | 2 +- rig-core/src/vector_store/mod.rs | 2 +- rig-lancedb/examples/fixtures/lib.rs | 9 +- .../examples/vector_search_local_ann.rs | 4 +- .../examples/vector_search_local_enn.rs | 4 +- rig-lancedb/examples/vector_search_s3_ann.rs | 6 +- rig-lancedb/src/lib.rs | 2 +- rig-mongodb/examples/vector_search_mongodb.rs | 10 +- rig-mongodb/src/lib.rs | 2 +- 19 files changed, 182 insertions(+), 155 deletions(-) rename rig-core/src/{embeddings.rs => embeddings/embeddable.rs} (73%) create mode 100644 rig-core/src/embeddings/embedding.rs create mode 100644 rig-core/src/embeddings/mod.rs diff --git a/rig-core/examples/rag.rs b/rig-core/examples/rag.rs index 82e9d4fd..9ff579b0 100644 --- a/rig-core/examples/rag.rs +++ b/rig-core/examples/rag.rs @@ -2,16 +2,16 @@ use std::{env, vec}; use rig::{ completion::Prompt, - embeddings::{Embeddable, EmbeddingsBuilder, ManyEmbedding}, + embeddings::embeddable::{EmbeddingGenerationError, EmbeddingsBuilder, ManyEmbedding}, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, vector_store::in_memory_store::InMemoryVectorStore, + 
Embeddable, }; -use rig_derive::Embed; use serde::Serialize; // Shape of data that needs to be RAG'ed. // The definition field will be used to generate embeddings. -#[derive(Embed, Clone, Debug, Serialize, Eq, PartialEq, Default)] +#[derive(Embeddable, Clone, Debug, Serialize, Eq, PartialEq, Default)] struct FakeDefinition { id: String, #[embed] @@ -49,7 +49,7 @@ async fn main() -> Result<(), anyhow::Error> { "Definition of a *linglingdong*: A rare, mystical instrument crafted by the ancient monks of the Nebulon Mountain Ranges on the planet Quarm.".to_string() ] }, - ]) + ])? .build() .await?; diff --git a/rig-core/examples/vector_search.rs b/rig-core/examples/vector_search.rs index 523719ac..fe79a3f1 100644 --- a/rig-core/examples/vector_search.rs +++ b/rig-core/examples/vector_search.rs @@ -1,16 +1,16 @@ use std::env; use rig::{ - embeddings::{Embeddable, EmbeddingsBuilder, ManyEmbedding}, + embeddings::embeddable::{EmbeddingGenerationError, EmbeddingsBuilder, ManyEmbedding}, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, vector_store::{in_memory_store::InMemoryVectorStore, VectorStoreIndex}, + Embeddable, }; -use rig_derive::Embed; use serde::{Deserialize, Serialize}; // Shape of data that needs to be RAG'ed. // The definition field will be used to generate embeddings. -#[derive(Embed, Clone, Deserialize, Debug, Serialize, Eq, PartialEq, Default)] +#[derive(Embeddable, Clone, Deserialize, Debug, Serialize, Eq, PartialEq, Default)] struct FakeDefinition { id: String, word: String, @@ -52,7 +52,7 @@ async fn main() -> Result<(), anyhow::Error> { "A rare, mystical instrument crafted by the ancient monks of the Nebulon Mountain Ranges on the planet Quarm.".to_string() ] }, - ]) + ])? 
.build() .await?; diff --git a/rig-core/examples/vector_search_cohere.rs b/rig-core/examples/vector_search_cohere.rs index f9f84175..9432df39 100644 --- a/rig-core/examples/vector_search_cohere.rs +++ b/rig-core/examples/vector_search_cohere.rs @@ -1,16 +1,16 @@ use std::env; use rig::{ - embeddings::{Embeddable, EmbeddingsBuilder, ManyEmbedding}, + embeddings::embeddable::{EmbeddingGenerationError, EmbeddingsBuilder, ManyEmbedding}, providers::cohere::{Client, EMBED_ENGLISH_V3}, vector_store::{in_memory_store::InMemoryVectorStore, VectorStoreIndex}, + Embeddable, }; -use rig_derive::Embed; use serde::{Deserialize, Serialize}; // Shape of data that needs to be RAG'ed. // The definition field will be used to generate embeddings. -#[derive(Embed, Clone, Deserialize, Debug, Serialize, Eq, PartialEq, Default)] +#[derive(Embeddable, Clone, Deserialize, Debug, Serialize, Eq, PartialEq, Default)] struct FakeDefinition { id: String, word: String, @@ -53,7 +53,7 @@ async fn main() -> Result<(), anyhow::Error> { "A rare, mystical instrument crafted by the ancient monks of the Nebulon Mountain Ranges on the planet Quarm.".to_string() ] }, - ]) + ])? 
.build() .await?; diff --git a/rig-core/rig-core-derive/src/lib.rs b/rig-core/rig-core-derive/src/lib.rs index ad07592e..fba92f1c 100644 --- a/rig-core/rig-core-derive/src/lib.rs +++ b/rig-core/rig-core-derive/src/lib.rs @@ -10,7 +10,7 @@ pub(crate) const EMBED: &str = "embed"; // https://doc.rust-lang.org/book/ch19-06-macros.html#how-to-write-a-custom-derive-macro // https://doc.rust-lang.org/reference/procedural-macros.html -#[proc_macro_derive(Embed, attributes(embed))] +#[proc_macro_derive(Embeddable, attributes(embed))] pub fn derive_embedding_trait(item: TokenStream) -> TokenStream { let mut input = parse_macro_input!(item as DeriveInput); diff --git a/rig-core/src/embeddings.rs b/rig-core/src/embeddings/embeddable.rs similarity index 73% rename from rig-core/src/embeddings.rs rename to rig-core/src/embeddings/embeddable.rs index 7e25a69a..08e8b316 100644 --- a/rig-core/src/embeddings.rs +++ b/rig-core/src/embeddings/embeddable.rs @@ -1,20 +1,6 @@ -//! This module provides functionality for working with embeddings and embedding models. -//! Embeddings are numerical representations of documents or other objects, typically used in -//! natural language processing (NLP) tasks such as text classification, information retrieval, -//! and document similarity. -//! -//! The module defines the [EmbeddingModel] trait, which represents an embedding model that can -//! generate embeddings for documents. It also provides an implementation of the [EmbeddingsBuilder] -//! struct, which allows users to build collections of document embeddings using different embedding -//! models and document sources. -//! -//! The module also defines the [Embedding] struct, which represents a single document embedding. -//! -//! The module also defines the [Embeddable] trait, which represents types that can be embedded. -//! Only types that implement the Embeddable trait can be used with the EmbeddingsBuilder. -//! -//! 
Finally, the module defines the [EmbeddingError] enum, which represents various errors that -//! can occur during embedding generation or processing. +//! The module defines the [Embeddable] trait, which must be implemented for types that can be embedded. +//! The module defines the [EmbeddingsBuilder] struct which accumulates objects to be embedded and generates the embeddings for each object when built. +//! Only types that implement the [Embeddable] trait can be added to the [EmbeddingsBuilder]. //! //! # Example //! ```rust @@ -73,102 +59,18 @@ //! // ... //! ``` -use std::{cmp::max, collections::HashMap, marker::PhantomData}; - +use super::embedding::{Embedding, EmbeddingError, EmbeddingModel}; use futures::{stream, StreamExt, TryStreamExt}; -use serde::{Deserialize, Serialize}; - -#[derive(Debug, thiserror::Error)] -pub enum EmbeddingError { - /// Http error (e.g.: connection error, timeout, etc.) - #[error("HttpError: {0}")] - HttpError(#[from] reqwest::Error), - - /// Json error (e.g.: serialization, deserialization) - #[error("JsonError: {0}")] - JsonError(#[from] serde_json::Error), - - /// Error processing the document for embedding - #[error("DocumentError: {0}")] - DocumentError(String), - - /// Error parsing the completion response - #[error("ResponseError: {0}")] - ResponseError(String), - - /// Error returned by the embedding model provider - #[error("ProviderError: {0}")] - ProviderError(String), -} - -/// Trait for embedding models that can generate embeddings for documents. -pub trait EmbeddingModel: Clone + Sync + Send { - /// The maximum number of documents that can be embedded in a single request. - const MAX_DOCUMENTS: usize; - - /// The number of dimensions in the embedding vector. - fn ndims(&self) -> usize; - - /// Embed a single document - fn embed_document( - &self, - document: &str, - ) -> impl std::future::Future> + Send - where - Self: Sync, - { - async { - Ok(self - .embed_documents(vec![document.to_string()]) - .await? 
- .first() - .cloned() - .expect("One embedding should be present")) - } - } - - /// Embed multiple documents in a single request - fn embed_documents( - &self, - documents: Vec, - ) -> impl std::future::Future, EmbeddingError>> + Send; -} - -/// Struct that holds a single document and its embedding. -#[derive(Clone, Default, Deserialize, Serialize, Debug)] -pub struct Embedding { - /// The document that was embedded. Used for debugging. - pub document: String, - /// The embedding vector - pub vec: Vec, -} - -impl PartialEq for Embedding { - fn eq(&self, other: &Self) -> bool { - self.document == other.document - } -} - -impl Eq for Embedding {} - -impl Embedding { - pub fn distance(&self, other: &Self) -> f64 { - let dot_product: f64 = self - .vec - .iter() - .zip(other.vec.iter()) - .map(|(x, y)| x * y) - .sum(); - - let product_of_lengths = (self.vec.len() * other.vec.len()) as f64; - - dot_product / product_of_lengths - } -} +use std::{cmp::max, collections::HashMap, marker::PhantomData}; +/// The associated type `Kind` on the trait `Embeddable` must implement this trait. pub trait EmbeddingKind {} + +/// Used for structs that contain a single embedding target. pub struct SingleEmbedding; impl EmbeddingKind for SingleEmbedding {} + +/// Used for structs that contain many embedding targets. pub struct ManyEmbedding; impl EmbeddingKind for ManyEmbedding {} @@ -178,6 +80,9 @@ pub enum EmbeddingGenerationError { SerdeError(#[from] serde_json::Error), } +/// Trait for types that can be embedded. +/// The `embeddable` method returns a list of strings for which embeddings will be generated by the embeddings builder. +/// If the type `Kind` is `SingleEmbedding`, the list of strings contains a single item, otherwise, the list can contain many items. 
pub trait Embeddable { type Kind: EmbeddingKind; type Error: std::error::Error; @@ -185,7 +90,7 @@ pub trait Embeddable { fn embeddable(&self) -> Result, Self::Error>; } -/// Builder for creating a collection of embeddings +/// Builder for creating a collection of embeddings. pub struct EmbeddingsBuilder { kind: PhantomData, model: M, @@ -431,11 +336,10 @@ impl Embeddable for Vec { #[cfg(test)] mod tests { - use super::{Embeddable, SingleEmbedding, EmbeddingGenerationError}; - - use rig_derive::Embed; + use super::{Embeddable, EmbeddingGenerationError, SingleEmbedding}; + use rig_derive::Embeddable; - #[derive(Embed)] + #[derive(Embeddable)] struct FakeDefinition { id: String, #[embed] diff --git a/rig-core/src/embeddings/embedding.rs b/rig-core/src/embeddings/embedding.rs new file mode 100644 index 00000000..ff284a05 --- /dev/null +++ b/rig-core/src/embeddings/embedding.rs @@ -0,0 +1,99 @@ +//! The module defines the [EmbeddingModel] trait, which represents an embedding model that can +//! generate embeddings for documents. It also provides an implementation of the [EmbeddingsBuilder] +//! struct, which allows users to build collections of document embeddings using different embedding +//! models and document sources. +//! +//! The module also defines the [Embedding] struct, which represents a single document embedding. +//! +//! Finally, the module defines the [EmbeddingError] enum, which represents various errors that +//! can occur during embedding generation or processing. + +use serde::{Deserialize, Serialize}; + +#[derive(Debug, thiserror::Error)] +pub enum EmbeddingError { + /// Http error (e.g.: connection error, timeout, etc.) 
+ #[error("HttpError: {0}")] + HttpError(#[from] reqwest::Error), + + /// Json error (e.g.: serialization, deserialization) + #[error("JsonError: {0}")] + JsonError(#[from] serde_json::Error), + + /// Error processing the document for embedding + #[error("DocumentError: {0}")] + DocumentError(String), + + /// Error parsing the completion response + #[error("ResponseError: {0}")] + ResponseError(String), + + /// Error returned by the embedding model provider + #[error("ProviderError: {0}")] + ProviderError(String), +} + +/// Trait for embedding models that can generate embeddings for documents. +pub trait EmbeddingModel: Clone + Sync + Send { + /// The maximum number of documents that can be embedded in a single request. + const MAX_DOCUMENTS: usize; + + /// The number of dimensions in the embedding vector. + fn ndims(&self) -> usize; + + /// Embed a single document + fn embed_document( + &self, + document: &str, + ) -> impl std::future::Future> + Send + where + Self: Sync, + { + async { + Ok(self + .embed_documents(vec![document.to_string()]) + .await? + .first() + .cloned() + .expect("One embedding should be present")) + } + } + + /// Embed multiple documents in a single request + fn embed_documents( + &self, + documents: Vec, + ) -> impl std::future::Future, EmbeddingError>> + Send; +} + +/// Struct that holds a single document and its embedding. +#[derive(Clone, Default, Deserialize, Serialize, Debug)] +pub struct Embedding { + /// The document that was embedded. Used for debugging. 
+ pub document: String, + /// The embedding vector + pub vec: Vec, +} + +impl PartialEq for Embedding { + fn eq(&self, other: &Self) -> bool { + self.document == other.document + } +} + +impl Eq for Embedding {} + +impl Embedding { + pub fn distance(&self, other: &Self) -> f64 { + let dot_product: f64 = self + .vec + .iter() + .zip(other.vec.iter()) + .map(|(x, y)| x * y) + .sum(); + + let product_of_lengths = (self.vec.len() * other.vec.len()) as f64; + + dot_product / product_of_lengths + } +} diff --git a/rig-core/src/embeddings/mod.rs b/rig-core/src/embeddings/mod.rs new file mode 100644 index 00000000..33526769 --- /dev/null +++ b/rig-core/src/embeddings/mod.rs @@ -0,0 +1,7 @@ +//! This module provides functionality for working with embeddings. +//! Embeddings are numerical representations of documents or other objects, typically used in +//! natural language processing (NLP) tasks such as text classification, information retrieval, +//! and document similarity. + +pub mod embeddable; +pub mod embedding; diff --git a/rig-core/src/lib.rs b/rig-core/src/lib.rs index 86c25209..b7f0615e 100644 --- a/rig-core/src/lib.rs +++ b/rig-core/src/lib.rs @@ -75,3 +75,7 @@ pub mod json_utils; pub mod providers; pub mod tool; pub mod vector_store; + +// Export Embeddable trait and Embeddable together. 
+pub use embeddings::embeddable::Embeddable; +pub use rig_derive::Embeddable; diff --git a/rig-core/src/providers/cohere.rs b/rig-core/src/providers/cohere.rs index c93709c7..33464939 100644 --- a/rig-core/src/providers/cohere.rs +++ b/rig-core/src/providers/cohere.rs @@ -13,7 +13,11 @@ use std::collections::HashMap; use crate::{ agent::AgentBuilder, completion::{self, CompletionError}, - embeddings::{self, Embeddable, EmbeddingError, EmbeddingsBuilder}, + embeddings::{ + self, + embeddable::{Embeddable, EmbeddingsBuilder}, + embedding::EmbeddingError, + }, extractor::ExtractorBuilder, json_utils, }; @@ -187,7 +191,7 @@ pub struct EmbeddingModel { ndims: usize, } -impl embeddings::EmbeddingModel for EmbeddingModel { +impl embeddings::embedding::EmbeddingModel for EmbeddingModel { const MAX_DOCUMENTS: usize = 96; fn ndims(&self) -> usize { @@ -197,7 +201,7 @@ impl embeddings::EmbeddingModel for EmbeddingModel { async fn embed_documents( &self, documents: Vec, - ) -> Result, EmbeddingError> { + ) -> Result, EmbeddingError> { let response = self .client .post("/v1/embed") @@ -226,7 +230,7 @@ impl embeddings::EmbeddingModel for EmbeddingModel { .embeddings .into_iter() .zip(documents.into_iter()) - .map(|(embedding, document)| embeddings::Embedding { + .map(|(embedding, document)| embeddings::embedding::Embedding { document, vec: embedding, }) diff --git a/rig-core/src/providers/openai.rs b/rig-core/src/providers/openai.rs index 6bd1711f..d23a46bf 100644 --- a/rig-core/src/providers/openai.rs +++ b/rig-core/src/providers/openai.rs @@ -11,7 +11,11 @@ use crate::{ agent::AgentBuilder, completion::{self, CompletionError, CompletionRequest}, - embeddings::{self, Embeddable, EmbeddingError}, + embeddings::{ + self, + embeddable::{Embeddable, EmbeddingsBuilder}, + embedding::{Embedding, EmbeddingError}, + }, extractor::ExtractorBuilder, json_utils, }; @@ -121,11 +125,11 @@ impl Client { /// .await /// .expect("Failed to embed documents"); /// ``` - pub fn embeddings( + pub 
fn embeddings( &self, model: &str, - ) -> embeddings::EmbeddingsBuilder { - embeddings::EmbeddingsBuilder::new(self.embedding_model(model)) + ) -> EmbeddingsBuilder { + EmbeddingsBuilder::new(self.embedding_model(model)) } /// Create a completion model with the given name. @@ -235,7 +239,7 @@ pub struct EmbeddingModel { ndims: usize, } -impl embeddings::EmbeddingModel for EmbeddingModel { +impl embeddings::embedding::EmbeddingModel for EmbeddingModel { const MAX_DOCUMENTS: usize = 1024; fn ndims(&self) -> usize { @@ -245,7 +249,7 @@ impl embeddings::EmbeddingModel for EmbeddingModel { async fn embed_documents( &self, documents: Vec, - ) -> Result, EmbeddingError> { + ) -> Result, EmbeddingError> { let response = self .client .post("/v1/embeddings") @@ -271,7 +275,7 @@ impl embeddings::EmbeddingModel for EmbeddingModel { .data .into_iter() .zip(documents.into_iter()) - .map(|(embedding, document)| embeddings::Embedding { + .map(|(embedding, document)| Embedding { document, vec: embedding.embedding, }) diff --git a/rig-core/src/vector_store/in_memory_store.rs b/rig-core/src/vector_store/in_memory_store.rs index ba2e754b..e8d8bb70 100644 --- a/rig-core/src/vector_store/in_memory_store.rs +++ b/rig-core/src/vector_store/in_memory_store.rs @@ -8,7 +8,7 @@ use ordered_float::OrderedFloat; use serde::{Deserialize, Serialize}; use super::{VectorStoreError, VectorStoreIndex}; -use crate::embeddings::{Embedding, EmbeddingModel}; +use crate::embeddings::embedding::{Embedding, EmbeddingModel}; /// InMemoryVectorStore is a simple in-memory vector store that stores embeddings /// in-memory using a HashMap. 
diff --git a/rig-core/src/vector_store/mod.rs b/rig-core/src/vector_store/mod.rs index 396b5514..6f112b81 100644 --- a/rig-core/src/vector_store/mod.rs +++ b/rig-core/src/vector_store/mod.rs @@ -2,7 +2,7 @@ use futures::future::BoxFuture; use serde::Deserialize; use serde_json::Value; -use crate::embeddings::EmbeddingError; +use crate::embeddings::embedding::EmbeddingError; pub mod in_memory_store; diff --git a/rig-lancedb/examples/fixtures/lib.rs b/rig-lancedb/examples/fixtures/lib.rs index 747e23a8..bf7f2a33 100644 --- a/rig-lancedb/examples/fixtures/lib.rs +++ b/rig-lancedb/examples/fixtures/lib.rs @@ -2,11 +2,14 @@ use std::sync::Arc; use arrow_array::{types::Float64Type, ArrayRef, FixedSizeListArray, RecordBatch, StringArray}; use lancedb::arrow::arrow_schema::{DataType, Field, Fields, Schema}; -use rig::embeddings::{Embedding, Embeddable, SingleEmbedding}; -use rig_derive::Embed; +use rig::embeddings::{ + embeddable::{EmbeddingGenerationError, SingleEmbedding}, + embedding::Embedding, +}; +use rig::Embeddable; use serde::Deserialize; -#[derive(Embed, Clone, Deserialize, Debug)] +#[derive(Embeddable, Clone, Deserialize, Debug)] pub struct FakeDefinition { id: String, #[embed] diff --git a/rig-lancedb/examples/vector_search_local_ann.rs b/rig-lancedb/examples/vector_search_local_ann.rs index 7a9c8cc1..e29a2b77 100644 --- a/rig-lancedb/examples/vector_search_local_ann.rs +++ b/rig-lancedb/examples/vector_search_local_ann.rs @@ -5,7 +5,7 @@ use fixture::{as_record_batch, fake_definition, fake_definitions, schema, FakeDe use lancedb::index::vector::IvfPqIndexBuilder; use rig::vector_store::VectorStoreIndex; use rig::{ - embeddings::{EmbeddingModel, EmbeddingsBuilder}, + embeddings::{embeddable::EmbeddingsBuilder, embedding::EmbeddingModel}, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, }; use rig_lancedb::{LanceDbVectorStore, SearchParams}; @@ -33,7 +33,7 @@ async fn main() -> Result<(), anyhow::Error> { (0..256) .map(|i| fake_definition(format!("doc{}", 
i))) .collect(), - ) + )? .build() .await?; diff --git a/rig-lancedb/examples/vector_search_local_enn.rs b/rig-lancedb/examples/vector_search_local_enn.rs index 1bf69481..7ae8c757 100644 --- a/rig-lancedb/examples/vector_search_local_enn.rs +++ b/rig-lancedb/examples/vector_search_local_enn.rs @@ -3,7 +3,7 @@ use std::{env, sync::Arc}; use arrow_array::RecordBatchIterator; use fixture::{as_record_batch, fake_definitions, schema}; use rig::{ - embeddings::{EmbeddingModel, EmbeddingsBuilder}, + embeddings::{embeddable::EmbeddingsBuilder, embedding::EmbeddingModel}, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, vector_store::VectorStoreIndexDyn, }; @@ -23,7 +23,7 @@ async fn main() -> Result<(), anyhow::Error> { // Generate embeddings for the test data. let embeddings = EmbeddingsBuilder::new(model.clone()) - .documents(fake_definitions()) + .documents(fake_definitions())? .build() .await?; diff --git a/rig-lancedb/examples/vector_search_s3_ann.rs b/rig-lancedb/examples/vector_search_s3_ann.rs index 17c4cd7f..0381429a 100644 --- a/rig-lancedb/examples/vector_search_s3_ann.rs +++ b/rig-lancedb/examples/vector_search_s3_ann.rs @@ -4,7 +4,7 @@ use arrow_array::RecordBatchIterator; use fixture::{as_record_batch, fake_definition, fake_definitions, schema, FakeDefinition}; use lancedb::{index::vector::IvfPqIndexBuilder, DistanceType}; use rig::{ - embeddings::{EmbeddingModel, EmbeddingsBuilder}, + embeddings::{embeddable::EmbeddingsBuilder, embedding::EmbeddingModel}, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, vector_store::VectorStoreIndex, }; @@ -33,13 +33,13 @@ async fn main() -> Result<(), anyhow::Error> { // Generate embeddings for the test data. let embeddings = EmbeddingsBuilder::new(model.clone()) - .documents(fake_definitions()) + .documents(fake_definitions())? // Note: need at least 256 rows in order to create an index so copy the definition 256 times for testing purposes. 
.documents( (0..256) .map(|i| fake_definition(format!("doc{}", i))) .collect(), - ) + )? .build() .await?; diff --git a/rig-lancedb/src/lib.rs b/rig-lancedb/src/lib.rs index 1e8b344a..edcc51e5 100644 --- a/rig-lancedb/src/lib.rs +++ b/rig-lancedb/src/lib.rs @@ -3,7 +3,7 @@ use lancedb::{ DistanceType, }; use rig::{ - embeddings::EmbeddingModel, + embeddings::embedding::EmbeddingModel, vector_store::{VectorStoreError, VectorStoreIndex}, }; use serde::Deserialize; diff --git a/rig-mongodb/examples/vector_search_mongodb.rs b/rig-mongodb/examples/vector_search_mongodb.rs index f740a29d..dab94eb9 100644 --- a/rig-mongodb/examples/vector_search_mongodb.rs +++ b/rig-mongodb/examples/vector_search_mongodb.rs @@ -1,17 +1,19 @@ use mongodb::{bson::doc, options::ClientOptions, Client as MongoClient, Collection}; use rig::providers::openai::TEXT_EMBEDDING_ADA_002; -use rig_derive::Embed; use serde::{Deserialize, Serialize}; use std::env; +use rig::Embeddable; use rig::{ - embeddings::{EmbeddingsBuilder, Embeddable, SingleEmbedding}, providers::openai::Client, vector_store::VectorStoreIndex, + embeddings::embeddable::{EmbeddingGenerationError, EmbeddingsBuilder, SingleEmbedding}, + providers::openai::Client, + vector_store::VectorStoreIndex, }; use rig_mongodb::{MongoDbVectorStore, SearchParams}; // Shape of data that needs to be RAG'ed. // The definition field will be used to generate embeddings. -#[derive(Embed, Clone, Deserialize, Debug)] +#[derive(Embeddable, Clone, Deserialize, Debug)] struct FakeDefinition { id: String, #[embed] @@ -67,7 +69,7 @@ async fn main() -> Result<(), anyhow::Error> { ]; let embeddings = EmbeddingsBuilder::new(model.clone()) - .documents(fake_definitions) + .documents(fake_definitions)? 
.build() .await?; diff --git a/rig-mongodb/src/lib.rs b/rig-mongodb/src/lib.rs index 5ce33105..4778e454 100644 --- a/rig-mongodb/src/lib.rs +++ b/rig-mongodb/src/lib.rs @@ -2,7 +2,7 @@ use futures::StreamExt; use mongodb::bson::{self, doc}; use rig::{ - embeddings::{Embedding, EmbeddingModel}, + embeddings::embedding::{Embedding, EmbeddingModel}, vector_store::{VectorStoreError, VectorStoreIndex}, }; use serde::Deserialize; From 220d9fc392542e4a698e69d27d79c36dcff94d14 Mon Sep 17 00:00:00 2001 From: Garance Date: Fri, 11 Oct 2024 17:01:31 -0400 Subject: [PATCH 15/47] refactor: split up macro into more files, fix all imports --- rig-core/examples/rag.rs | 2 +- rig-core/examples/vector_search.rs | 2 +- rig-core/examples/vector_search_cohere.rs | 2 +- rig-core/rig-core-derive/src/basic.rs | 29 ++ rig-core/rig-core-derive/src/custom.rs | 31 +- rig-core/rig-core-derive/src/embeddable.rs | 258 ++++++-------- rig-core/rig-core-derive/src/lib.rs | 1 + rig-core/src/embeddings/builder.rs | 206 +++++++++++ rig-core/src/embeddings/embeddable.rs | 335 +++++++----------- rig-core/src/embeddings/mod.rs | 1 + rig-core/src/providers/cohere.rs | 4 +- rig-core/src/providers/openai.rs | 3 +- rig-lancedb/examples/fixtures/lib.rs | 5 +- .../examples/vector_search_local_ann.rs | 2 +- .../examples/vector_search_local_enn.rs | 2 +- rig-lancedb/examples/vector_search_s3_ann.rs | 2 +- rig-mongodb/examples/vector_search_mongodb.rs | 28 +- 17 files changed, 528 insertions(+), 385 deletions(-) create mode 100644 rig-core/rig-core-derive/src/basic.rs create mode 100644 rig-core/src/embeddings/builder.rs diff --git a/rig-core/examples/rag.rs b/rig-core/examples/rag.rs index 9ff579b0..43270a7b 100644 --- a/rig-core/examples/rag.rs +++ b/rig-core/examples/rag.rs @@ -2,7 +2,7 @@ use std::{env, vec}; use rig::{ completion::Prompt, - embeddings::embeddable::{EmbeddingGenerationError, EmbeddingsBuilder, ManyEmbedding}, + embeddings::builder::EmbeddingsBuilder, providers::openai::{Client, 
TEXT_EMBEDDING_ADA_002}, vector_store::in_memory_store::InMemoryVectorStore, Embeddable, diff --git a/rig-core/examples/vector_search.rs b/rig-core/examples/vector_search.rs index fe79a3f1..a97ef8b0 100644 --- a/rig-core/examples/vector_search.rs +++ b/rig-core/examples/vector_search.rs @@ -1,7 +1,7 @@ use std::env; use rig::{ - embeddings::embeddable::{EmbeddingGenerationError, EmbeddingsBuilder, ManyEmbedding}, + embeddings::builder::EmbeddingsBuilder, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, vector_store::{in_memory_store::InMemoryVectorStore, VectorStoreIndex}, Embeddable, diff --git a/rig-core/examples/vector_search_cohere.rs b/rig-core/examples/vector_search_cohere.rs index 9432df39..16ddb775 100644 --- a/rig-core/examples/vector_search_cohere.rs +++ b/rig-core/examples/vector_search_cohere.rs @@ -1,7 +1,7 @@ use std::env; use rig::{ - embeddings::embeddable::{EmbeddingGenerationError, EmbeddingsBuilder, ManyEmbedding}, + embeddings::builder::EmbeddingsBuilder, providers::cohere::{Client, EMBED_ENGLISH_V3}, vector_store::{in_memory_store::InMemoryVectorStore, VectorStoreIndex}, Embeddable, diff --git a/rig-core/rig-core-derive/src/basic.rs b/rig-core/rig-core-derive/src/basic.rs new file mode 100644 index 00000000..7ac5bb47 --- /dev/null +++ b/rig-core/rig-core-derive/src/basic.rs @@ -0,0 +1,29 @@ +use syn::{parse_quote, Attribute, DataStruct, Meta}; + +use crate::EMBED; + +/// Finds and returns fields with simple #[embed] attribute tags only. +pub(crate) fn basic_embed_fields(data_struct: &DataStruct) -> impl Iterator { + data_struct.fields.clone().into_iter().filter(|field| { + field + .attrs + .clone() + .into_iter() + .any(|attribute| match attribute { + Attribute { + meta: Meta::Path(path), + .. + } => path.is_ident(EMBED), + _ => false, + }) + }) +} + +// Adds bounds to where clause that force all fields tagged with #[embed] to implement the Embeddable trait. 
+pub(crate) fn add_struct_bounds(generics: &mut syn::Generics, field_type: &syn::Type) { + let where_clause = generics.make_where_clause(); + + where_clause.predicates.push(parse_quote! { + #field_type: Embeddable + }); +} diff --git a/rig-core/rig-core-derive/src/custom.rs b/rig-core/rig-core-derive/src/custom.rs index 77f321f6..8926aebf 100644 --- a/rig-core/rig-core-derive/src/custom.rs +++ b/rig-core/rig-core-derive/src/custom.rs @@ -5,7 +5,36 @@ use crate::EMBED; const EMBED_WITH: &str = "embed_with"; -pub(crate) trait CustomAttributeParser { +/// Finds and returns fields with #[embed(embed_with = "...")] attribute tags only. +/// Also returns the attribute in question. +pub(crate) fn custom_embed_fields( + data_struct: &syn::DataStruct, +) -> syn::Result> { + Ok(data_struct + .fields + .clone() + .into_iter() + .map(|field| { + field + .attrs + .clone() + .into_iter() + .map(|attribute| { + if attribute.is_custom()? { + Ok::<_, syn::Error>(Some((field.clone(), attribute.expand_tag()?))) + } else { + Ok(None) + } + }) + .collect::, _>>() + }) + .collect::, _>>()? + .into_iter() + .flatten() + .flatten()) +} + +trait CustomAttributeParser { // Determine if field is tagged with an #[embed(embed_with = "...")] attribute. 
fn is_custom(&self) -> syn::Result; diff --git a/rig-core/rig-core-derive/src/embeddable.rs b/rig-core/rig-core-derive/src/embeddable.rs index f5c6e781..914945ec 100644 --- a/rig-core/rig-core-derive/src/embeddable.rs +++ b/rig-core/rig-core-derive/src/embeddable.rs @@ -1,8 +1,11 @@ use proc_macro2::TokenStream; use quote::quote; -use syn::{parse_quote, parse_str, Attribute, DataStruct, Meta}; +use syn::{parse_str, DataStruct}; -use crate::{custom::CustomAttributeParser, EMBED}; +use crate::{ + basic::{add_struct_bounds, basic_embed_fields}, + custom::custom_embed_fields, +}; const VEC_TYPE: &str = "Vec"; const MANY_EMBEDDING: &str = "ManyEmbedding"; const SINGLE_EMBEDDING: &str = "SingleEmbedding"; @@ -10,41 +13,74 @@ const SINGLE_EMBEDDING: &str = "SingleEmbedding"; pub fn expand_derive_embedding(input: &mut syn::DeriveInput) -> syn::Result { let name = &input.ident; - let (embed_targets, custom_embed_targets, embed_kind) = match &input.data { - syn::Data::Struct(data_struct) => { - // Handles fields tagged with #[embed] - let embed_targets = data_struct - .basic_embed_fields() - .map(|field| { - add_struct_bounds(&mut input.generics, &field.ty); - - let field_name = field.ident; - - quote! { - self.#field_name - } - }) - .collect::>(); - - // Handles fields tagged with #[embed(embed_with = "...")] - let custom_embed_targets = data_struct - .custom_embed_fields()? - .map(|(field, _)| { - let field_name = field.ident; - - quote! { - self.#field_name - } - }) - .collect::>(); - - ( - embed_targets, - custom_embed_targets, - data_struct.embed_kind()?, - ) + // Handles fields tagged with #[embed] + let embed_targets = match &input.data { + syn::Data::Struct(data_struct) => basic_embed_fields(data_struct) + .map(|field| { + add_struct_bounds(&mut input.generics, &field.ty); + + let field_name = field.ident; + + quote! 
{ + self.#field_name + } + }) + .collect::>(), + _ => { + return Err(syn::Error::new_spanned( + name, + "Embeddable derive macro should only be used on structs", + )) + } + }; + + let embed_targets_quote = if !embed_targets.is_empty() { + quote! { + vec![#(#embed_targets.embeddable()),*] + .into_iter() + .collect::, _>>()? + .into_iter() + .flatten() + .collect::>() + } + } else { + quote! { + vec![] + } + }; + + // Handles fields tagged with #[embed(embed_with = "...")] + let custom_embed_targets = match &input.data { + syn::Data::Struct(data_struct) => custom_embed_fields(data_struct)? + .map(|(field, custom_func_path)| { + let field_name = field.ident; + + quote! { + #custom_func_path(self.#field_name.clone()) + } + }) + .collect::>(), + _ => { + return Err(syn::Error::new_spanned( + name, + "Embeddable derive macro should only be used on structs", + )) + } + }; + + let custom_embed_targets_quote = if !custom_embed_targets.is_empty() { + quote! { + vec![#(#custom_embed_targets),*] + .into_iter() + .collect::, _>>()? + .into_iter() + .flatten() + .collect::>() + } + } else { + quote! { + vec![] } - _ => panic!("Embeddable trait can only be derived for structs"), }; // If there are no fields tagged with #[embed] or #[embed(embed_with = "...")], return an empty TokenStream. @@ -53,41 +89,34 @@ pub fn expand_derive_embedding(input: &mut syn::DeriveInput) -> syn::Result data_struct - .custom_embed_fields()? - .map(|(_, custom_func_path)| { - quote! { - use #custom_func_path::embeddable; - } - }) - .collect::>(), - _ => vec![], + // Determine whether the Embeddable::Kind should be SinleEmbedding or ManyEmbedding. + let embed_kind = match &input.data { + syn::Data::Struct(data_struct) => embed_kind(data_struct)?, + _ => { + return Err(syn::Error::new_spanned( + name, + "Embeddable derive macro should only be used on structs", + )) + } }; let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl(); let gen = quote! 
{ - // Note: we do NOT import the Embeddable trait here because if there are multiple structs in the same file - // that derive Embed, there will be import conflicts. - - #(#custom_func_paths);* + // Note: Embeddable trait is imported with the macro. impl #impl_generics Embeddable for #name #ty_generics #where_clause { - type Kind = #embed_kind; - type Error = EmbeddingGenerationError; + type Kind = rig::embeddings::embeddable::#embed_kind; + type Error = rig::embeddings::embeddable::EmbeddableError; fn embeddable(&self) -> Result, Self::Error> { - vec![#(#embed_targets.clone()),*].embeddable() + let mut embed_targets = #embed_targets_quote; + + let custom_embed_targets = #custom_embed_targets_quote; - // let custom_embed_targets = vec![#( embeddable( #embed_targets ); ),*] - // .iter() - // .collect::, _>>()? - // .into_iter() - // .flatten(); + embed_targets.extend(custom_embed_targets); - // Ok(embed_targets.chain(custom_embed_targets).collect()) + Ok(embed_targets) } } }; @@ -96,95 +125,30 @@ pub fn expand_derive_embedding(input: &mut syn::DeriveInput) -> syn::Result syn::Result { - match &field.ty { - syn::Type::Path(path) => { - if path.path.segments.first().unwrap().ident == VEC_TYPE { - parse_str(MANY_EMBEDDING) - } else { - parse_str(SINGLE_EMBEDDING) +/// If the total number of fields tagged with #[embed] or #[embed(embed_with = "...")] is 1, +/// returns the kind of embedding that field should be. +/// If the total number of fields tagged with #[embed] or #[embed(embed_with = "...")] is greater than 1, +/// return ManyEmbedding. 
+fn embed_kind(data_struct: &DataStruct) -> syn::Result { + fn embed_kind(field: &syn::Field) -> syn::Result { + match &field.ty { + syn::Type::Path(path) => { + if path.path.segments.first().unwrap().ident == VEC_TYPE { + parse_str(MANY_EMBEDDING) + } else { + parse_str(SINGLE_EMBEDDING) + } } + _ => parse_str(SINGLE_EMBEDDING), } - _ => parse_str(SINGLE_EMBEDDING), - } -} - -trait StructParser { - /// Finds and returns fields with simple #[embed] attribute tags only. - fn basic_embed_fields(&self) -> impl Iterator; - /// Finds and returns fields with #[embed(embed_with = "...")] attribute tags only. - /// Also returns the attribute in question. - fn custom_embed_fields(&self) - -> syn::Result>; - - /// If the total number of fields tagged with #[embed] or #[embed(embed_with = "...")] is 1, - /// returns the kind of embedding that field should be. - /// If the total number of fields tagged with #[embed] or #[embed(embed_with = "...")] is greater than 1, - /// return ManyEmbedding. - fn embed_kind(&self) -> syn::Result { - let fields = self - .basic_embed_fields() - .chain(self.custom_embed_fields()?.map(|(f, _)| f)) - .collect::>(); - - if fields.len() == 1 { - fields.iter().map(embed_kind).next().unwrap() - } else { - parse_str("ManyEmbedding") - } - } -} - -impl StructParser for DataStruct { - fn basic_embed_fields(&self) -> impl Iterator { - self.fields.clone().into_iter().filter(|field| { - field - .attrs - .clone() - .into_iter() - .any(|attribute| match attribute { - Attribute { - meta: Meta::Path(path), - .. - } => path.is_ident(EMBED), - _ => false, - }) - }) } - - fn custom_embed_fields( - &self, - ) -> syn::Result> { - Ok(self - .fields - .clone() - .into_iter() - .map(|field| { - field - .attrs - .clone() - .into_iter() - .map(|attribute| { - if attribute.is_custom()? { - Ok::<_, syn::Error>(Some((field.clone(), attribute.expand_tag()?))) - } else { - Ok(None) - } - }) - .collect::, _>>() - }) - .collect::, _>>()? 
- .into_iter() - .flatten() - .flatten()) + let fields = basic_embed_fields(data_struct) + .chain(custom_embed_fields(data_struct)?.map(|(f, _)| f)) + .collect::>(); + + if fields.len() == 1 { + fields.iter().map(embed_kind).next().unwrap() + } else { + parse_str(MANY_EMBEDDING) } } diff --git a/rig-core/rig-core-derive/src/lib.rs b/rig-core/rig-core-derive/src/lib.rs index fba92f1c..d28a0d78 100644 --- a/rig-core/rig-core-derive/src/lib.rs +++ b/rig-core/rig-core-derive/src/lib.rs @@ -2,6 +2,7 @@ extern crate proc_macro; use proc_macro::TokenStream; use syn::{parse_macro_input, DeriveInput}; +mod basic; mod custom; mod embeddable; diff --git a/rig-core/src/embeddings/builder.rs b/rig-core/src/embeddings/builder.rs new file mode 100644 index 00000000..01a99fd4 --- /dev/null +++ b/rig-core/src/embeddings/builder.rs @@ -0,0 +1,206 @@ +//! The module defines the [EmbeddingsBuilder] struct which accumulates objects to be embedded and generates the embeddings for each object when built. +//! Only types that implement the [Embeddable] trait can be added to the [EmbeddingsBuilder]. +//! +//! # Example +//! ```rust +//! use std::env; +//! +//! use rig::{ +//! embeddings::EmbeddingsBuilder, +//! providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, +//! }; +//! use rig_derive::Embed; +//! +//! #[derive(Embed)] +//! struct FakeDefinition { +//! id: String, +//! word: String, +//! #[embed] +//! definitions: Vec, +//! } +//! // Create OpenAI client +//! let openai_api_key = env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set"); +//! let openai_client = Client::new(&openai_api_key); +//! +//! let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); +//! +//! let embeddings = EmbeddingsBuilder::new(model.clone()) +//! .documents(vec![ +//! FakeDefinition { +//! id: "doc0".to_string(), +//! word: "flurbo".to_string(), +//! definitions: vec![ +//! "A green alien that lives on cold planets.".to_string(), +//! 
"A fictional digital currency that originated in the animated series Rick and Morty.".to_string() +//! ] +//! }, +//! FakeDefinition { +//! id: "doc1".to_string(), +//! word: "glarb-glarb".to_string(), +//! definitions: vec![ +//! "An ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string(), +//! "A fictional creature found in the distant, swampy marshlands of the planet Glibbo in the Andromeda galaxy.".to_string() +//! ] +//! }, +//! FakeDefinition { +//! id: "doc2".to_string(), +//! word: "linglingdong".to_string(), +//! definitions: vec![ +//! "A term used by inhabitants of the sombrero galaxy to describe humans.".to_string(), +//! "A rare, mystical instrument crafted by the ancient monks of the Nebulon Mountain Ranges on the planet Quarm.".to_string() +//! ] +//! }, +//! ]) +//! .build() +//! .await?; +//! +//! // Use the generated embeddings +//! // ... +//! ``` + +use std::{cmp::max, collections::HashMap, marker::PhantomData}; + +use futures::{stream, StreamExt, TryStreamExt}; + +use crate::Embeddable; + +use super::{ + embeddable::{EmbeddingKind, ManyEmbedding, SingleEmbedding}, + embedding::{Embedding, EmbeddingError, EmbeddingModel}, +}; + +/// Builder for creating a collection of embeddings. +pub struct EmbeddingsBuilder { + kind: PhantomData, + model: M, + documents: Vec<(D, Vec)>, +} + +impl, K: EmbeddingKind> EmbeddingsBuilder { + /// Create a new embedding builder with the given embedding model + pub fn new(model: M) -> Self { + Self { + kind: PhantomData, + model, + documents: vec![], + } + } + + /// Add a document that implements `Embeddable` to the builder. + pub fn document(mut self, document: D) -> Result { + let embed_targets = document.embeddable()?; + + self.documents.push((document, embed_targets)); + Ok(self) + } + + /// Add many documents that implement `Embeddable` to the builder. 
+ pub fn documents(mut self, documents: Vec) -> Result { + for doc in documents.into_iter() { + let embed_targets = doc.embeddable()?; + + self.documents.push((doc, embed_targets)); + } + + Ok(self) + } +} + +impl + EmbeddingsBuilder +{ + /// Generate embeddings for all documents in the builder. + /// The method only applies when documents in the builder each contain multiple embedding targets. + /// Returns a vector of tuples, where the first element is the document and the second element is the vector of embeddings. + pub async fn build(&self) -> Result)>, EmbeddingError> { + // Use this for reference later to merge a document back with its embeddings. + let documents_map = self + .documents + .clone() + .into_iter() + .enumerate() + .map(|(id, (document, _))| (id, document)) + .collect::>(); + + let embeddings = stream::iter(self.documents.iter().enumerate()) + // Merge the embedding targets of each document into a single list of embedding targets. + .flat_map(|(i, (_, embed_targets))| { + stream::iter(embed_targets.iter().map(move |target| (i, target.clone()))) + }) + // Chunk them into N (the emebdding API limit per request). + .chunks(M::MAX_DOCUMENTS) + // Generate the embeddings for a chunk at a time. + .map(|docs| async { + let (document_indices, embed_targets): (Vec<_>, Vec<_>) = docs.into_iter().unzip(); + + Ok::<_, EmbeddingError>( + document_indices + .into_iter() + .zip(self.model.embed_documents(embed_targets).await?.into_iter()) + .collect::>(), + ) + }) + .boxed() + // Parallelize the embeddings generation over 10 concurrent requests + .buffer_unordered(max(1, 1024 / M::MAX_DOCUMENTS)) + .try_fold( + HashMap::new(), + |mut acc: HashMap<_, Vec<_>>, embeddings| async move { + embeddings.into_iter().for_each(|(i, embedding)| { + acc.entry(i).or_default().push(embedding); + }); + + Ok(acc) + }, + ) + .await? 
+ .iter() + .fold(vec![], |mut acc, (i, embeddings_vec)| { + acc.push(( + documents_map.get(i).cloned().unwrap(), + embeddings_vec.clone(), + )); + acc + }); + + Ok(embeddings) + } +} + +impl + EmbeddingsBuilder +{ + /// Generate embeddings for all documents in the builder. + /// The method only applies when documents in the builder each contain a single embedding target. + /// Returns a vector of tuples, where the first element is the document and the second element is the embedding. + pub async fn build(&self) -> Result, EmbeddingError> { + let embeddings = stream::iter( + self.documents + .clone() + .into_iter() + .map(|(document, embed_target)| (document, embed_target.first().cloned().unwrap())), + ) + // Chunk them into N (the emebdding API limit per request) + .chunks(M::MAX_DOCUMENTS) + // Generate the embeddings + .map(|docs| async { + let (documents, embed_targets): (Vec<_>, Vec<_>) = docs.into_iter().unzip(); + Ok::<_, EmbeddingError>( + documents + .into_iter() + .zip(self.model.embed_documents(embed_targets).await?.into_iter()) + .collect::>(), + ) + }) + .boxed() + // Parallelize the embeddings generation over 10 concurrent requests + .buffer_unordered(max(1, 1024 / M::MAX_DOCUMENTS)) + .try_fold(vec![], |mut acc, embeddings| async move { + acc.extend(embeddings); + Ok(acc) + }) + .await?; + + Ok(embeddings) + } +} diff --git a/rig-core/src/embeddings/embeddable.rs b/rig-core/src/embeddings/embeddable.rs index 08e8b316..80e5b689 100644 --- a/rig-core/src/embeddings/embeddable.rs +++ b/rig-core/src/embeddings/embeddable.rs @@ -1,67 +1,4 @@ //! The module defines the [Embeddable] trait, which must be implemented for types that can be embedded. -//! The module defines the [EmbeddingsBuilder] struct which accumulates objects to be embedded and generates the embeddings for each object when built. -//! Only types that implement the [Embeddable] trait can be added to the [EmbeddingsBuilder]. -//! -//! # Example -//! ```rust -//! use std::env; -//! -//! 
use rig::{ -//! embeddings::EmbeddingsBuilder, -//! providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, -//! }; -//! use rig_derive::Embed; -//! -//! #[derive(Embed)] -//! struct FakeDefinition { -//! id: String, -//! word: String, -//! #[embed] -//! definitions: Vec, -//! } -//! // Create OpenAI client -//! let openai_api_key = env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set"); -//! let openai_client = Client::new(&openai_api_key); -//! -//! let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); -//! -//! let embeddings = EmbeddingsBuilder::new(model.clone()) -//! .documents(vec![ -//! FakeDefinition { -//! id: "doc0".to_string(), -//! word: "flurbo".to_string(), -//! definitions: vec![ -//! "A green alien that lives on cold planets.".to_string(), -//! "A fictional digital currency that originated in the animated series Rick and Morty.".to_string() -//! ] -//! }, -//! FakeDefinition { -//! id: "doc1".to_string(), -//! word: "glarb-glarb".to_string(), -//! definitions: vec![ -//! "An ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string(), -//! "A fictional creature found in the distant, swampy marshlands of the planet Glibbo in the Andromeda galaxy.".to_string() -//! ] -//! }, -//! FakeDefinition { -//! id: "doc2".to_string(), -//! word: "linglingdong".to_string(), -//! definitions: vec![ -//! "A term used by inhabitants of the sombrero galaxy to describe humans.".to_string(), -//! "A rare, mystical instrument crafted by the ancient monks of the Nebulon Mountain Ranges on the planet Quarm.".to_string() -//! ] -//! }, -//! ]) -//! .build() -//! .await?; -//! -//! // Use the generated embeddings -//! // ... -//! ``` - -use super::embedding::{Embedding, EmbeddingError, EmbeddingModel}; -use futures::{stream, StreamExt, TryStreamExt}; -use std::{cmp::max, collections::HashMap, marker::PhantomData}; /// The associated type `Kind` on the trait `Embeddable` must implement this trait. 
pub trait EmbeddingKind {} @@ -75,7 +12,7 @@ pub struct ManyEmbedding; impl EmbeddingKind for ManyEmbedding {} #[derive(Debug, thiserror::Error)] -pub enum EmbeddingGenerationError { +pub enum EmbeddableError { #[error("SerdeError: {0}")] SerdeError(#[from] serde_json::Error), } @@ -90,148 +27,12 @@ pub trait Embeddable { fn embeddable(&self) -> Result, Self::Error>; } -/// Builder for creating a collection of embeddings. -pub struct EmbeddingsBuilder { - kind: PhantomData, - model: M, - documents: Vec<(D, Vec)>, -} - -impl, K: EmbeddingKind> EmbeddingsBuilder { - /// Create a new embedding builder with the given embedding model - pub fn new(model: M) -> Self { - Self { - kind: PhantomData, - model, - documents: vec![], - } - } - - /// Add a document that implements `Embeddable` to the builder. - pub fn document(mut self, document: D) -> Result { - let embed_targets = document.embeddable()?; - - self.documents.push((document, embed_targets)); - Ok(self) - } - - /// Add many documents that implement `Embeddable` to the builder. - pub fn documents(mut self, documents: Vec) -> Result { - for doc in documents.into_iter() { - let embed_targets = doc.embeddable()?; - - self.documents.push((doc, embed_targets)); - } - - Ok(self) - } -} - -impl - EmbeddingsBuilder -{ - /// Generate embeddings for all documents in the builder. - /// The method only applies when documents in the builder each contain multiple embedding targets. - /// Returns a vector of tuples, where the first element is the document and the second element is the vector of embeddings. - pub async fn build(&self) -> Result)>, EmbeddingError> { - // Use this for reference later to merge a document back with its embeddings. 
- let documents_map = self - .documents - .clone() - .into_iter() - .enumerate() - .map(|(id, (document, _))| (id, document)) - .collect::>(); - - let embeddings = stream::iter(self.documents.iter().enumerate()) - // Merge the embedding targets of each document into a single list of embedding targets. - .flat_map(|(i, (_, embed_targets))| { - stream::iter(embed_targets.iter().map(move |target| (i, target.clone()))) - }) - // Chunk them into N (the emebdding API limit per request). - .chunks(M::MAX_DOCUMENTS) - // Generate the embeddings for a chunk at a time. - .map(|docs| async { - let (document_indices, embed_targets): (Vec<_>, Vec<_>) = docs.into_iter().unzip(); - - Ok::<_, EmbeddingError>( - document_indices - .into_iter() - .zip(self.model.embed_documents(embed_targets).await?.into_iter()) - .collect::>(), - ) - }) - .boxed() - // Parallelize the embeddings generation over 10 concurrent requests - .buffer_unordered(max(1, 1024 / M::MAX_DOCUMENTS)) - .try_fold( - HashMap::new(), - |mut acc: HashMap<_, Vec<_>>, embeddings| async move { - embeddings.into_iter().for_each(|(i, embedding)| { - acc.entry(i).or_default().push(embedding); - }); - - Ok(acc) - }, - ) - .await? - .iter() - .fold(vec![], |mut acc, (i, embeddings_vec)| { - acc.push(( - documents_map.get(i).cloned().unwrap(), - embeddings_vec.clone(), - )); - acc - }); - - Ok(embeddings) - } -} - -impl - EmbeddingsBuilder -{ - /// Generate embeddings for all documents in the builder. - /// The method only applies when documents in the builder each contain a single embedding target. - /// Returns a vector of tuples, where the first element is the document and the second element is the embedding. 
- pub async fn build(&self) -> Result, EmbeddingError> { - let embeddings = stream::iter( - self.documents - .clone() - .into_iter() - .map(|(document, embed_target)| (document, embed_target.first().cloned().unwrap())), - ) - // Chunk them into N (the emebdding API limit per request) - .chunks(M::MAX_DOCUMENTS) - // Generate the embeddings - .map(|docs| async { - let (documents, embed_targets): (Vec<_>, Vec<_>) = docs.into_iter().unzip(); - Ok::<_, EmbeddingError>( - documents - .into_iter() - .zip(self.model.embed_documents(embed_targets).await?.into_iter()) - .collect::>(), - ) - }) - .boxed() - // Parallelize the embeddings generation over 10 concurrent requests - .buffer_unordered(max(1, 1024 / M::MAX_DOCUMENTS)) - .try_fold(vec![], |mut acc, embeddings| async move { - acc.extend(embeddings); - Ok(acc) - }) - .await?; - - Ok(embeddings) - } -} - ////////////////////////////////////////////////////// /// Implementations of Embeddable for common types /// ////////////////////////////////////////////////////// impl Embeddable for String { type Kind = SingleEmbedding; - type Error = EmbeddingGenerationError; + type Error = EmbeddableError; fn embeddable(&self) -> Result, Self::Error> { Ok(vec![self.clone()]) @@ -240,7 +41,7 @@ impl Embeddable for String { impl Embeddable for i8 { type Kind = SingleEmbedding; - type Error = EmbeddingGenerationError; + type Error = EmbeddableError; fn embeddable(&self) -> Result, Self::Error> { Ok(vec![self.to_string()]) @@ -249,7 +50,7 @@ impl Embeddable for i8 { impl Embeddable for i16 { type Kind = SingleEmbedding; - type Error = EmbeddingGenerationError; + type Error = EmbeddableError; fn embeddable(&self) -> Result, Self::Error> { Ok(vec![self.to_string()]) @@ -258,7 +59,7 @@ impl Embeddable for i16 { impl Embeddable for i32 { type Kind = SingleEmbedding; - type Error = EmbeddingGenerationError; + type Error = EmbeddableError; fn embeddable(&self) -> Result, Self::Error> { Ok(vec![self.to_string()]) @@ -267,7 +68,7 @@ impl 
Embeddable for i32 { impl Embeddable for i64 { type Kind = SingleEmbedding; - type Error = EmbeddingGenerationError; + type Error = EmbeddableError; fn embeddable(&self) -> Result, Self::Error> { Ok(vec![self.to_string()]) @@ -276,7 +77,7 @@ impl Embeddable for i64 { impl Embeddable for i128 { type Kind = SingleEmbedding; - type Error = EmbeddingGenerationError; + type Error = EmbeddableError; fn embeddable(&self) -> Result, Self::Error> { Ok(vec![self.to_string()]) @@ -285,7 +86,7 @@ impl Embeddable for i128 { impl Embeddable for f32 { type Kind = SingleEmbedding; - type Error = EmbeddingGenerationError; + type Error = EmbeddableError; fn embeddable(&self) -> Result, Self::Error> { Ok(vec![self.to_string()]) @@ -294,7 +95,7 @@ impl Embeddable for f32 { impl Embeddable for f64 { type Kind = SingleEmbedding; - type Error = EmbeddingGenerationError; + type Error = EmbeddableError; fn embeddable(&self) -> Result, Self::Error> { Ok(vec![self.to_string()]) @@ -303,7 +104,7 @@ impl Embeddable for f64 { impl Embeddable for bool { type Kind = SingleEmbedding; - type Error = EmbeddingGenerationError; + type Error = EmbeddableError; fn embeddable(&self) -> Result, Self::Error> { Ok(vec![self.to_string()]) @@ -312,7 +113,7 @@ impl Embeddable for bool { impl Embeddable for char { type Kind = SingleEmbedding; - type Error = EmbeddingGenerationError; + type Error = EmbeddableError; fn embeddable(&self) -> Result, Self::Error> { Ok(vec![self.to_string()]) @@ -336,19 +137,127 @@ impl Embeddable for Vec { #[cfg(test)] mod tests { - use super::{Embeddable, EmbeddingGenerationError, SingleEmbedding}; + use crate as rig; + use rig::embeddings::embeddable::{Embeddable, EmbeddableError}; use rig_derive::Embeddable; + use serde::Serialize; + + fn serialize(definition: Definition) -> Result, EmbeddableError> { + Ok(vec![ + serde_json::to_string(&definition).map_err(EmbeddableError::SerdeError)? 
+ ]) + } #[derive(Embeddable)] struct FakeDefinition { id: String, + word: String, + #[embed(embed_with = "serialize")] + definition: Definition, + } + + #[derive(Serialize, Clone)] + struct Definition { + word: String, + link: String, + speech: String, + } + + #[test] + fn test_custom_embed() { + let fake_definition = FakeDefinition { + id: "doc1".to_string(), + word: "house".to_string(), + definition: Definition { + speech: "noun".to_string(), + word: "a building in which people live; residence for human beings.".to_string(), + link: "https://www.dictionary.com/browse/house".to_string(), + }, + }; + + assert_eq!( + fake_definition.embeddable().unwrap(), + vec!["{\"word\":\"a building in which people live; residence for human beings.\",\"link\":\"https://www.dictionary.com/browse/house\",\"speech\":\"noun\"}".to_string()] + ) + } + + #[derive(Embeddable)] + struct FakeDefinition2 { + id: String, + word: String, #[embed] definition: String, } #[test] - fn test_missing_embed_fields() {} + fn test_simple_embed() { + let fake_definition = FakeDefinition2 { + id: "doc1".to_string(), + word: "house".to_string(), + definition: "a building in which people live; residence for human beings.".to_string(), + }; + + assert_eq!( + fake_definition.embeddable().unwrap(), + vec!["a building in which people live; residence for human beings.".to_string()] + ); + + assert!(false) + } + + #[derive(Embeddable)] + struct Company { + id: String, + company: String, + #[embed] + employee_ages: Vec, + } + + #[test] + fn test_multiple_embed() { + let company = Company { + id: "doc1".to_string(), + company: "Google".to_string(), + employee_ages: vec![25, 30, 35, 40], + }; + + assert_eq!( + company.embeddable().unwrap(), + vec![ + "25".to_string(), + "30".to_string(), + "35".to_string(), + "40".to_string() + ] + ); + } + + #[derive(Embeddable)] + struct Company2 { + id: String, + #[embed] + company: String, + #[embed] + employee_ages: Vec, + } #[test] - fn test_empty_custom_function() {} + fn 
test_many_embed() { + let company = Company2 { + id: "doc1".to_string(), + company: "Google".to_string(), + employee_ages: vec![25, 30, 35, 40], + }; + + assert_eq!( + company.embeddable().unwrap(), + vec![ + "Google".to_string(), + "25".to_string(), + "30".to_string(), + "35".to_string(), + "40".to_string() + ] + ); + } } diff --git a/rig-core/src/embeddings/mod.rs b/rig-core/src/embeddings/mod.rs index 33526769..37e720cb 100644 --- a/rig-core/src/embeddings/mod.rs +++ b/rig-core/src/embeddings/mod.rs @@ -3,5 +3,6 @@ //! natural language processing (NLP) tasks such as text classification, information retrieval, //! and document similarity. +pub mod builder; pub mod embeddable; pub mod embedding; diff --git a/rig-core/src/providers/cohere.rs b/rig-core/src/providers/cohere.rs index 33464939..fa798205 100644 --- a/rig-core/src/providers/cohere.rs +++ b/rig-core/src/providers/cohere.rs @@ -14,9 +14,7 @@ use crate::{ agent::AgentBuilder, completion::{self, CompletionError}, embeddings::{ - self, - embeddable::{Embeddable, EmbeddingsBuilder}, - embedding::EmbeddingError, + self, builder::EmbeddingsBuilder, embeddable::Embeddable, embedding::EmbeddingError, }, extractor::ExtractorBuilder, json_utils, diff --git a/rig-core/src/providers/openai.rs b/rig-core/src/providers/openai.rs index d23a46bf..d21bf8fb 100644 --- a/rig-core/src/providers/openai.rs +++ b/rig-core/src/providers/openai.rs @@ -13,7 +13,8 @@ use crate::{ completion::{self, CompletionError, CompletionRequest}, embeddings::{ self, - embeddable::{Embeddable, EmbeddingsBuilder}, + builder::EmbeddingsBuilder, + embeddable::Embeddable, embedding::{Embedding, EmbeddingError}, }, extractor::ExtractorBuilder, diff --git a/rig-lancedb/examples/fixtures/lib.rs b/rig-lancedb/examples/fixtures/lib.rs index bf7f2a33..956ace1c 100644 --- a/rig-lancedb/examples/fixtures/lib.rs +++ b/rig-lancedb/examples/fixtures/lib.rs @@ -2,10 +2,7 @@ use std::sync::Arc; use arrow_array::{types::Float64Type, ArrayRef, FixedSizeListArray, 
RecordBatch, StringArray}; use lancedb::arrow::arrow_schema::{DataType, Field, Fields, Schema}; -use rig::embeddings::{ - embeddable::{EmbeddingGenerationError, SingleEmbedding}, - embedding::Embedding, -}; +use rig::embeddings::embedding::Embedding; use rig::Embeddable; use serde::Deserialize; diff --git a/rig-lancedb/examples/vector_search_local_ann.rs b/rig-lancedb/examples/vector_search_local_ann.rs index e29a2b77..1042e34c 100644 --- a/rig-lancedb/examples/vector_search_local_ann.rs +++ b/rig-lancedb/examples/vector_search_local_ann.rs @@ -5,7 +5,7 @@ use fixture::{as_record_batch, fake_definition, fake_definitions, schema, FakeDe use lancedb::index::vector::IvfPqIndexBuilder; use rig::vector_store::VectorStoreIndex; use rig::{ - embeddings::{embeddable::EmbeddingsBuilder, embedding::EmbeddingModel}, + embeddings::{builder::EmbeddingsBuilder, embedding::EmbeddingModel}, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, }; use rig_lancedb::{LanceDbVectorStore, SearchParams}; diff --git a/rig-lancedb/examples/vector_search_local_enn.rs b/rig-lancedb/examples/vector_search_local_enn.rs index 7ae8c757..630acc1a 100644 --- a/rig-lancedb/examples/vector_search_local_enn.rs +++ b/rig-lancedb/examples/vector_search_local_enn.rs @@ -3,7 +3,7 @@ use std::{env, sync::Arc}; use arrow_array::RecordBatchIterator; use fixture::{as_record_batch, fake_definitions, schema}; use rig::{ - embeddings::{embeddable::EmbeddingsBuilder, embedding::EmbeddingModel}, + embeddings::{builder::EmbeddingsBuilder, embedding::EmbeddingModel}, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, vector_store::VectorStoreIndexDyn, }; diff --git a/rig-lancedb/examples/vector_search_s3_ann.rs b/rig-lancedb/examples/vector_search_s3_ann.rs index 0381429a..8d65e37a 100644 --- a/rig-lancedb/examples/vector_search_s3_ann.rs +++ b/rig-lancedb/examples/vector_search_s3_ann.rs @@ -4,7 +4,7 @@ use arrow_array::RecordBatchIterator; use fixture::{as_record_batch, fake_definition, fake_definitions, 
schema, FakeDefinition}; use lancedb::{index::vector::IvfPqIndexBuilder, DistanceType}; use rig::{ - embeddings::{embeddable::EmbeddingsBuilder, embedding::EmbeddingModel}, + embeddings::{builder::EmbeddingsBuilder, embedding::EmbeddingModel}, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, vector_store::VectorStoreIndex, }; diff --git a/rig-mongodb/examples/vector_search_mongodb.rs b/rig-mongodb/examples/vector_search_mongodb.rs index dab94eb9..e087c3d6 100644 --- a/rig-mongodb/examples/vector_search_mongodb.rs +++ b/rig-mongodb/examples/vector_search_mongodb.rs @@ -1,12 +1,12 @@ use mongodb::{bson::doc, options::ClientOptions, Client as MongoClient, Collection}; +use rig::embeddings::embeddable::EmbeddableError; use rig::providers::openai::TEXT_EMBEDDING_ADA_002; use serde::{Deserialize, Serialize}; use std::env; use rig::Embeddable; use rig::{ - embeddings::embeddable::{EmbeddingGenerationError, EmbeddingsBuilder, SingleEmbedding}, - providers::openai::Client, + embeddings::builder::EmbeddingsBuilder, providers::openai::Client, vector_store::VectorStoreIndex, }; use rig_mongodb::{MongoDbVectorStore, SearchParams}; @@ -20,6 +20,12 @@ struct FakeDefinition { definition: String, } +#[derive(Clone, Deserialize, Debug, Serialize)] +struct Link { + word: String, + link: String, +} + // Shape of the document to be stored in MongoDB, with embeddings. 
#[derive(Serialize, Debug)] struct Document { @@ -56,15 +62,15 @@ async fn main() -> Result<(), anyhow::Error> { let fake_definitions = vec![ FakeDefinition { id: "doc0".to_string(), - definition: "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets".to_string() + definition: "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets".to_string(), }, FakeDefinition { id: "doc1".to_string(), - definition: "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string() + definition: "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string(), }, FakeDefinition { id: "doc2".to_string(), - definition: "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.".to_string() + definition: "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.".to_string(), } ]; @@ -75,11 +81,13 @@ async fn main() -> Result<(), anyhow::Error> { let mongo_documents = embeddings .iter() - .map(|(FakeDefinition { id, definition }, embedding)| Document { - id: id.clone(), - definition: definition.clone(), - embedding: embedding.vec.clone(), - }) + .map( + |(FakeDefinition { id, definition, .. 
}, embedding)| Document { + id: id.clone(), + definition: definition.clone(), + embedding: embedding.vec.clone(), + }, + ) .collect::>(); match collection.insert_many(mongo_documents, None).await { From 5a8c3612fbda2af47ffce55a2186ca46d3f396c9 Mon Sep 17 00:00:00 2001 From: Garance Date: Tue, 15 Oct 2024 14:28:03 -0400 Subject: [PATCH 16/47] feat: handle tools with embeddingsbuilder --- rig-core/examples/calculator_chatbot.rs | 17 +- rig-core/examples/rag_dynamic_tools.rs | 13 +- rig-core/rig-core-derive/src/custom.rs | 2 +- rig-core/rig-core-derive/src/embeddable.rs | 175 ++++++++++-------- rig-core/src/embeddings/builder.rs | 6 +- rig-core/src/embeddings/embeddable.rs | 46 +++-- rig-core/src/embeddings/mod.rs | 1 + rig-core/src/embeddings/tool.rs | 25 +++ rig-core/src/tool.rs | 18 +- rig-core/src/vector_store/in_memory_store.rs | 2 +- rig-mongodb/examples/vector_search_mongodb.rs | 1 - 11 files changed, 174 insertions(+), 132 deletions(-) create mode 100644 rig-core/src/embeddings/tool.rs diff --git a/rig-core/examples/calculator_chatbot.rs b/rig-core/examples/calculator_chatbot.rs index fb168a08..0b994265 100644 --- a/rig-core/examples/calculator_chatbot.rs +++ b/rig-core/examples/calculator_chatbot.rs @@ -2,7 +2,7 @@ use anyhow::Result; use rig::{ cli_chatbot::cli_chatbot, completion::ToolDefinition, - embeddings::{DocumentEmbeddings, EmbeddingsBuilder}, + embeddings::builder::EmbeddingsBuilder, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, tool::{Tool, ToolEmbedding, ToolSet}, vector_store::in_memory_store::InMemoryVectorStore, @@ -25,7 +25,7 @@ struct MathError; #[error("Init error")] struct InitError; -#[derive(Deserialize, Serialize)] +#[derive(Deserialize, Serialize, Clone)] struct Add; impl Tool for Add { const NAME: &'static str = "add"; @@ -77,7 +77,7 @@ impl ToolEmbedding for Add { fn context(&self) -> Self::Context {} } -#[derive(Deserialize, Serialize)] +#[derive(Deserialize, Serialize, Clone)] struct Subtract; impl Tool for Subtract { const 
NAME: &'static str = "subtract"; @@ -247,7 +247,7 @@ async fn main() -> Result<(), anyhow::Error> { let embedding_model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); let embeddings = EmbeddingsBuilder::new(embedding_model.clone()) - .tools(&toolset)? + .documents(toolset.embedabble_tools()?)? .build() .await?; @@ -255,13 +255,8 @@ async fn main() -> Result<(), anyhow::Error> { .add_documents( embeddings .into_iter() - .map( - |DocumentEmbeddings { - id, - document, - embeddings, - }| { (id, document, embeddings) }, - ) + .enumerate() + .map(|(i, (tool, embedding))| (i.to_string(), tool, vec![embedding])) .collect(), )? .index(embedding_model); diff --git a/rig-core/examples/rag_dynamic_tools.rs b/rig-core/examples/rag_dynamic_tools.rs index cdf6b65e..51c56ca8 100644 --- a/rig-core/examples/rag_dynamic_tools.rs +++ b/rig-core/examples/rag_dynamic_tools.rs @@ -1,7 +1,7 @@ use anyhow::Result; use rig::{ completion::{Prompt, ToolDefinition}, - embeddings::{DocumentEmbeddings, EmbeddingsBuilder}, + embeddings::builder::EmbeddingsBuilder, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, tool::{Tool, ToolEmbedding, ToolSet}, vector_store::in_memory_store::InMemoryVectorStore, @@ -156,7 +156,7 @@ async fn main() -> Result<(), anyhow::Error> { .build(); let embeddings = EmbeddingsBuilder::new(embedding_model.clone()) - .tools(&toolset)? + .documents(toolset.embedabble_tools()?)? .build() .await?; @@ -164,13 +164,8 @@ async fn main() -> Result<(), anyhow::Error> { .add_documents( embeddings .into_iter() - .map( - |DocumentEmbeddings { - id, - document, - embeddings, - }| { (id, document, embeddings) }, - ) + .enumerate() + .map(|(i, (tool, embedding))| (i.to_string(), tool, vec![embedding])) .collect(), )? 
.index(embedding_model); diff --git a/rig-core/rig-core-derive/src/custom.rs b/rig-core/rig-core-derive/src/custom.rs index 8926aebf..194be085 100644 --- a/rig-core/rig-core-derive/src/custom.rs +++ b/rig-core/rig-core-derive/src/custom.rs @@ -6,7 +6,7 @@ use crate::EMBED; const EMBED_WITH: &str = "embed_with"; /// Finds and returns fields with #[embed(embed_with = "...")] attribute tags only. -/// Also returns the attribute in question. +/// Also returns the "..." part of the tag (ie. the custom function). pub(crate) fn custom_embed_fields( data_struct: &syn::DataStruct, ) -> syn::Result> { diff --git a/rig-core/rig-core-derive/src/embeddable.rs b/rig-core/rig-core-derive/src/embeddable.rs index 914945ec..cd5a201f 100644 --- a/rig-core/rig-core-derive/src/embeddable.rs +++ b/rig-core/rig-core-derive/src/embeddable.rs @@ -10,113 +10,58 @@ const VEC_TYPE: &str = "Vec"; const MANY_EMBEDDING: &str = "ManyEmbedding"; const SINGLE_EMBEDDING: &str = "SingleEmbedding"; -pub fn expand_derive_embedding(input: &mut syn::DeriveInput) -> syn::Result { - let name = &input.ident; +pub(crate) fn expand_derive_embedding(input: &mut syn::DeriveInput) -> syn::Result { + let data = &input.data; + let generics = &mut input.generics; - // Handles fields tagged with #[embed] - let embed_targets = match &input.data { - syn::Data::Struct(data_struct) => basic_embed_fields(data_struct) - .map(|field| { - add_struct_bounds(&mut input.generics, &field.ty); - - let field_name = field.ident; + let (target_stream, embed_kind) = match data { + syn::Data::Struct(data_struct) => { + let basic_targets = data_struct.basic(generics); + let custom_targets = data_struct.custom()?; + // Determine whether the Embeddable::Kind should be SinleEmbedding or ManyEmbedding + ( quote! 
{ - self.#field_name - } - }) - .collect::>(), - _ => { - return Err(syn::Error::new_spanned( - name, - "Embeddable derive macro should only be used on structs", - )) + let mut embed_targets = #basic_targets; + embed_targets.extend(#custom_targets) + }, + embed_kind(data_struct)?, + ) } - }; - - let embed_targets_quote = if !embed_targets.is_empty() { - quote! { - vec![#(#embed_targets.embeddable()),*] - .into_iter() - .collect::, _>>()? - .into_iter() - .flatten() - .collect::>() - } - } else { - quote! { - vec![] - } - }; - - // Handles fields tagged with #[embed(embed_with = "...")] - let custom_embed_targets = match &input.data { - syn::Data::Struct(data_struct) => custom_embed_fields(data_struct)? - .map(|(field, custom_func_path)| { - let field_name = field.ident; - - quote! { - #custom_func_path(self.#field_name.clone()) - } - }) - .collect::>(), _ => { return Err(syn::Error::new_spanned( - name, + input, "Embeddable derive macro should only be used on structs", )) } }; - let custom_embed_targets_quote = if !custom_embed_targets.is_empty() { - quote! { - vec![#(#custom_embed_targets),*] - .into_iter() - .collect::, _>>()? - .into_iter() - .flatten() - .collect::>() - } - } else { - quote! { - vec![] - } - }; - // If there are no fields tagged with #[embed] or #[embed(embed_with = "...")], return an empty TokenStream. // ie. do not implement Embeddable trait for the struct. - if embed_targets.is_empty() && custom_embed_targets.is_empty() { + if target_stream.is_empty() { return Ok(TokenStream::new()); } - // Determine whether the Embeddable::Kind should be SinleEmbedding or ManyEmbedding. - let embed_kind = match &input.data { - syn::Data::Struct(data_struct) => embed_kind(data_struct)?, - _ => { - return Err(syn::Error::new_spanned( - name, - "Embeddable derive macro should only be used on structs", - )) - } - }; - let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl(); + let name = &input.ident; + let gen = quote! 
{ // Note: Embeddable trait is imported with the macro. impl #impl_generics Embeddable for #name #ty_generics #where_clause { type Kind = rig::embeddings::embeddable::#embed_kind; - type Error = rig::embeddings::embeddable::EmbeddableError; - fn embeddable(&self) -> Result, Self::Error> { - let mut embed_targets = #embed_targets_quote; + fn embeddable(&self) -> Result, rig::embeddings::embeddable::EmbeddableError> { + #target_stream; - let custom_embed_targets = #custom_embed_targets_quote; + let targets = embed_targets.into_iter() + .collect::, _>>()? + .into_iter() + .flatten() + .collect::>(); - embed_targets.extend(custom_embed_targets); - - Ok(embed_targets) + Ok(targets) } } }; @@ -152,3 +97,71 @@ fn embed_kind(data_struct: &DataStruct) -> syn::Result { parse_str(MANY_EMBEDDING) } } + +trait StructParser { + // Handles fields tagged with #[embed] + fn basic(&self, generics: &mut syn::Generics) -> TokenStream; + + // Handles fields tagged with #[embed(embed_with = "...")] + fn custom(&self) -> syn::Result; +} + +impl StructParser for DataStruct { + fn basic(&self, generics: &mut syn::Generics) -> TokenStream { + let embed_targets = basic_embed_fields(self) + // Iterate over every field tagged with #[embed] + .map(|field| { + add_struct_bounds(generics, &field.ty); + + let field_name = field.ident; + + quote! { + self.#field_name + } + }) + .collect::>(); + + if !embed_targets.is_empty() { + quote! { + vec![#(#embed_targets.embeddable()),*] + // .into_iter() + // .collect::, _>>()? + // .into_iter() + // .flatten() + // .collect::>() + } + } else { + quote! { + vec![] + } + } + } + + fn custom(&self) -> syn::Result { + let embed_targets = custom_embed_fields(self)? + // Iterate over every field tagged with #[embed(embed_with = "...")] + .map(|(field, custom_func_path)| { + let field_name = field.ident; + + quote! { + #custom_func_path(self.#field_name.clone()) + } + }) + .collect::>(); + + Ok(if !embed_targets.is_empty() { + quote! 
{ + vec![#(#embed_targets),*] + // .into_iter() + // .collect::, _>>()? + // .into_iter() + // .flatten() + // .collect::>() + } + } else { + quote! { + vec![] + } + }) + } +} diff --git a/rig-core/src/embeddings/builder.rs b/rig-core/src/embeddings/builder.rs index 01a99fd4..1a699833 100644 --- a/rig-core/src/embeddings/builder.rs +++ b/rig-core/src/embeddings/builder.rs @@ -65,7 +65,7 @@ use futures::{stream, StreamExt, TryStreamExt}; use crate::Embeddable; use super::{ - embeddable::{EmbeddingKind, ManyEmbedding, SingleEmbedding}, + embeddable::{EmbeddableError, EmbeddingKind, ManyEmbedding, SingleEmbedding}, embedding::{Embedding, EmbeddingError, EmbeddingModel}, }; @@ -87,7 +87,7 @@ impl, K: EmbeddingKind> EmbeddingsBui } /// Add a document that implements `Embeddable` to the builder. - pub fn document(mut self, document: D) -> Result { + pub fn document(mut self, document: D) -> Result { let embed_targets = document.embeddable()?; self.documents.push((document, embed_targets)); @@ -95,7 +95,7 @@ impl, K: EmbeddingKind> EmbeddingsBui } /// Add many documents that implement `Embeddable` to the builder. - pub fn documents(mut self, documents: Vec) -> Result { + pub fn documents(mut self, documents: Vec) -> Result { for doc in documents.into_iter() { let embed_targets = doc.embeddable()?; diff --git a/rig-core/src/embeddings/embeddable.rs b/rig-core/src/embeddings/embeddable.rs index 80e5b689..6b1996e3 100644 --- a/rig-core/src/embeddings/embeddable.rs +++ b/rig-core/src/embeddings/embeddable.rs @@ -22,9 +22,8 @@ pub enum EmbeddableError { /// If the type `Kind` is `SingleEmbedding`, the list of strings contains a single item, otherwise, the list can contain many items. 
pub trait Embeddable { type Kind: EmbeddingKind; - type Error: std::error::Error; - fn embeddable(&self) -> Result, Self::Error>; + fn embeddable(&self) -> Result, EmbeddableError>; } ////////////////////////////////////////////////////// @@ -32,99 +31,98 @@ pub trait Embeddable { ////////////////////////////////////////////////////// impl Embeddable for String { type Kind = SingleEmbedding; - type Error = EmbeddableError; - fn embeddable(&self) -> Result, Self::Error> { + fn embeddable(&self) -> Result, EmbeddableError> { Ok(vec![self.clone()]) } } impl Embeddable for i8 { type Kind = SingleEmbedding; - type Error = EmbeddableError; - fn embeddable(&self) -> Result, Self::Error> { + fn embeddable(&self) -> Result, EmbeddableError> { Ok(vec![self.to_string()]) } } impl Embeddable for i16 { type Kind = SingleEmbedding; - type Error = EmbeddableError; - fn embeddable(&self) -> Result, Self::Error> { + fn embeddable(&self) -> Result, EmbeddableError> { Ok(vec![self.to_string()]) } } impl Embeddable for i32 { type Kind = SingleEmbedding; - type Error = EmbeddableError; - fn embeddable(&self) -> Result, Self::Error> { + fn embeddable(&self) -> Result, EmbeddableError> { Ok(vec![self.to_string()]) } } impl Embeddable for i64 { type Kind = SingleEmbedding; - type Error = EmbeddableError; - fn embeddable(&self) -> Result, Self::Error> { + fn embeddable(&self) -> Result, EmbeddableError> { Ok(vec![self.to_string()]) } } impl Embeddable for i128 { type Kind = SingleEmbedding; - type Error = EmbeddableError; - fn embeddable(&self) -> Result, Self::Error> { + fn embeddable(&self) -> Result, EmbeddableError> { Ok(vec![self.to_string()]) } } impl Embeddable for f32 { type Kind = SingleEmbedding; - type Error = EmbeddableError; - fn embeddable(&self) -> Result, Self::Error> { + fn embeddable(&self) -> Result, EmbeddableError> { Ok(vec![self.to_string()]) } } impl Embeddable for f64 { type Kind = SingleEmbedding; - type Error = EmbeddableError; - fn embeddable(&self) -> Result, 
Self::Error> { + fn embeddable(&self) -> Result, EmbeddableError> { Ok(vec![self.to_string()]) } } impl Embeddable for bool { type Kind = SingleEmbedding; - type Error = EmbeddableError; - fn embeddable(&self) -> Result, Self::Error> { + fn embeddable(&self) -> Result, EmbeddableError> { Ok(vec![self.to_string()]) } } impl Embeddable for char { type Kind = SingleEmbedding; - type Error = EmbeddableError; - fn embeddable(&self) -> Result, Self::Error> { + fn embeddable(&self) -> Result, EmbeddableError> { Ok(vec![self.to_string()]) } } +impl Embeddable for serde_json::Value { + type Kind = SingleEmbedding; + + fn embeddable(&self) -> Result, EmbeddableError> { + Ok(vec![ + serde_json::to_string(self).map_err(EmbeddableError::SerdeError)? + ]) + } +} + impl Embeddable for Vec { type Kind = ManyEmbedding; - type Error = T::Error; - fn embeddable(&self) -> Result, Self::Error> { + fn embeddable(&self) -> Result, EmbeddableError> { Ok(self .iter() .map(|i| i.embeddable()) diff --git a/rig-core/src/embeddings/mod.rs b/rig-core/src/embeddings/mod.rs index 37e720cb..d590b7d0 100644 --- a/rig-core/src/embeddings/mod.rs +++ b/rig-core/src/embeddings/mod.rs @@ -6,3 +6,4 @@ pub mod builder; pub mod embeddable; pub mod embedding; +pub mod tool; diff --git a/rig-core/src/embeddings/tool.rs b/rig-core/src/embeddings/tool.rs new file mode 100644 index 00000000..c369fd94 --- /dev/null +++ b/rig-core/src/embeddings/tool.rs @@ -0,0 +1,25 @@ +use crate::{self as rig, tool::ToolEmbeddingDyn}; +use rig::embeddings::embeddable::Embeddable; +use rig_derive::Embeddable; +use serde::Serialize; + +use super::embeddable::EmbeddableError; + +/// Used by EmbeddingsBuilder to embed anything that implements ToolEmbedding. +#[derive(Embeddable, Clone, Serialize, Default, Eq, PartialEq)] +pub struct EmbeddableTool { + name: String, + #[embed] + context: serde_json::Value, +} + +impl EmbeddableTool { + /// Convert item that implements ToolEmbedding to an EmbeddableTool. 
+ pub fn try_from(tool: &Box) -> Result { + Ok(EmbeddableTool { + name: tool.name(), + context: serde_json::to_value(tool.context().map_err(EmbeddableError::SerdeError)?) + .map_err(EmbeddableError::SerdeError)?, + }) + } +} diff --git a/rig-core/src/tool.rs b/rig-core/src/tool.rs index 98394ecf..181751f5 100644 --- a/rig-core/src/tool.rs +++ b/rig-core/src/tool.rs @@ -3,7 +3,10 @@ use std::{collections::HashMap, pin::Pin}; use futures::Future; use serde::{Deserialize, Serialize}; -use crate::completion::{self, ToolDefinition}; +use crate::{ + completion::{self, ToolDefinition}, + embeddings::{embeddable::EmbeddableError, tool::EmbeddableTool}, +}; #[derive(Debug, thiserror::Error)] pub enum ToolError { @@ -323,6 +326,19 @@ impl ToolSet { } Ok(docs) } + + pub fn embedabble_tools(&self) -> Result, EmbeddableError> { + self.tools + .values() + .filter_map(|tool_type| { + if let ToolType::Embedding(tool) = tool_type { + Some(EmbeddableTool::try_from(tool)) + } else { + None + } + }) + .collect::, _>>() + } } #[derive(Default)] diff --git a/rig-core/src/vector_store/in_memory_store.rs b/rig-core/src/vector_store/in_memory_store.rs index f1e6ad0a..9bbc85f1 100644 --- a/rig-core/src/vector_store/in_memory_store.rs +++ b/rig-core/src/vector_store/in_memory_store.rs @@ -189,7 +189,7 @@ impl Vec mod tests { use std::cmp::Reverse; - use crate::embeddings::Embedding; + use crate::embeddings::embedding::Embedding; use super::{InMemoryVectorStore, RankingItem}; diff --git a/rig-mongodb/examples/vector_search_mongodb.rs b/rig-mongodb/examples/vector_search_mongodb.rs index e087c3d6..a816c9c0 100644 --- a/rig-mongodb/examples/vector_search_mongodb.rs +++ b/rig-mongodb/examples/vector_search_mongodb.rs @@ -1,5 +1,4 @@ use mongodb::{bson::doc, options::ClientOptions, Client as MongoClient, Collection}; -use rig::embeddings::embeddable::EmbeddableError; use rig::providers::openai::TEXT_EMBEDDING_ADA_002; use serde::{Deserialize, Serialize}; use std::env; From 
dc89e54f9e0bc7c98da3dbdca014691f42b764ca Mon Sep 17 00:00:00 2001 From: Garance Date: Tue, 15 Oct 2024 15:27:37 -0400 Subject: [PATCH 17/47] bug(macro): fix error when embed tags missing --- rig-core/rig-core-derive/src/custom.rs | 29 ++++---- rig-core/rig-core-derive/src/embeddable.rs | 82 ++++++++++++---------- rig-core/src/embeddings/tool.rs | 2 +- rig-core/src/tool.rs | 2 +- 4 files changed, 61 insertions(+), 54 deletions(-) diff --git a/rig-core/rig-core-derive/src/custom.rs b/rig-core/rig-core-derive/src/custom.rs index 194be085..a795026c 100644 --- a/rig-core/rig-core-derive/src/custom.rs +++ b/rig-core/rig-core-derive/src/custom.rs @@ -9,29 +9,32 @@ const EMBED_WITH: &str = "embed_with"; /// Also returns the "..." part of the tag (ie. the custom function). pub(crate) fn custom_embed_fields( data_struct: &syn::DataStruct, -) -> syn::Result> { - Ok(data_struct +) -> syn::Result> { + data_struct .fields .clone() .into_iter() - .map(|field| { + .filter_map(|field| { field .attrs .clone() .into_iter() - .map(|attribute| { - if attribute.is_custom()? { - Ok::<_, syn::Error>(Some((field.clone(), attribute.expand_tag()?))) - } else { - Ok(None) + .filter_map(|attribute| { + match attribute.is_custom() { + Ok(true) => { + match attribute.expand_tag() { + Ok(path) => Some(Ok((field.clone(), path))), + Err(e) => Some(Err(e)), + } + }, + Ok(false) => None, + Err(e) => Some(Err(e)) } }) - .collect::, _>>() + .next() }) - .collect::, _>>()? 
- .into_iter() - .flatten() - .flatten()) + .collect::, _>>() + } trait CustomAttributeParser { diff --git a/rig-core/rig-core-derive/src/embeddable.rs b/rig-core/rig-core-derive/src/embeddable.rs index cd5a201f..d44bb7bf 100644 --- a/rig-core/rig-core-derive/src/embeddable.rs +++ b/rig-core/rig-core-derive/src/embeddable.rs @@ -11,15 +11,25 @@ const MANY_EMBEDDING: &str = "ManyEmbedding"; const SINGLE_EMBEDDING: &str = "SingleEmbedding"; pub(crate) fn expand_derive_embedding(input: &mut syn::DeriveInput) -> syn::Result { + let name = &input.ident; let data = &input.data; let generics = &mut input.generics; let (target_stream, embed_kind) = match data { syn::Data::Struct(data_struct) => { - let basic_targets = data_struct.basic(generics); - let custom_targets = data_struct.custom()?; + let (basic_targets, basic_target_size) = data_struct.basic(generics); + let (custom_targets, custom_target_size) = data_struct.custom()?; + + // If there are no fields tagged with #[embed] or #[embed(embed_with = "...")], return an empty TokenStream. + // ie. do not implement Embeddable trait for the struct. + if basic_target_size + custom_target_size == 0 { + return Err(syn::Error::new_spanned( + name, + "Add at least one field tagged with #[embed] or #[embed(embed_with = \"...\")].", + )) + } - // Determine whether the Embeddable::Kind should be SinleEmbedding or ManyEmbedding + // Determine whether the Embeddable::Kind should be SingleEmbedding or ManyEmbedding ( quote! { let mut embed_targets = #basic_targets; @@ -36,16 +46,8 @@ pub(crate) fn expand_derive_embedding(input: &mut syn::DeriveInput) -> syn::Resu } }; - // If there are no fields tagged with #[embed] or #[embed(embed_with = "...")], return an empty TokenStream. - // ie. do not implement Embeddable trait for the struct. - if target_stream.is_empty() { - return Ok(TokenStream::new()); - } - let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl(); - let name = &input.ident; - let gen = quote! 
{ // Note: Embeddable trait is imported with the macro. @@ -88,7 +90,7 @@ fn embed_kind(data_struct: &DataStruct) -> syn::Result { } } let fields = basic_embed_fields(data_struct) - .chain(custom_embed_fields(data_struct)?.map(|(f, _)| f)) + .chain(custom_embed_fields(data_struct)?.into_iter().map(|(f, _)| f)) .collect::>(); if fields.len() == 1 { @@ -100,14 +102,14 @@ fn embed_kind(data_struct: &DataStruct) -> syn::Result { trait StructParser { // Handles fields tagged with #[embed] - fn basic(&self, generics: &mut syn::Generics) -> TokenStream; + fn basic(&self, generics: &mut syn::Generics) -> (TokenStream, usize); // Handles fields tagged with #[embed(embed_with = "...")] - fn custom(&self) -> syn::Result; + fn custom(&self) -> syn::Result<(TokenStream, usize)>; } impl StructParser for DataStruct { - fn basic(&self, generics: &mut syn::Generics) -> TokenStream { + fn basic(&self, generics: &mut syn::Generics) -> (TokenStream, usize) { let embed_targets = basic_embed_fields(self) // Iterate over every field tagged with #[embed] .map(|field| { @@ -122,25 +124,26 @@ impl StructParser for DataStruct { .collect::>(); if !embed_targets.is_empty() { - quote! { - vec![#(#embed_targets.embeddable()),*] - // .into_iter() - // .collect::, _>>()? - // .into_iter() - // .flatten() - // .collect::>() - } + ( + quote! { + vec![#(#embed_targets.embeddable()),*] + }, + embed_targets.len() + ) } else { - quote! { - vec![] - } + ( + quote! { + vec![] + }, + 0 + ) } } - fn custom(&self) -> syn::Result { + fn custom(&self) -> syn::Result<(TokenStream, usize)> { let embed_targets = custom_embed_fields(self)? // Iterate over every field tagged with #[embed(embed_with = "...")] - .map(|(field, custom_func_path)| { + .into_iter().map(|(field, custom_func_path)| { let field_name = field.ident; quote! { @@ -150,18 +153,19 @@ impl StructParser for DataStruct { .collect::>(); Ok(if !embed_targets.is_empty() { - quote! { - vec![#(#embed_targets),*] - // .into_iter() - // .collect::, _>>()? 
- // .into_iter() - // .flatten() - // .collect::>() - } + ( + quote! { + vec![#(#embed_targets),*] + }, + embed_targets.len() + ) } else { - quote! { - vec![] - } + ( + quote! { + vec![] + }, + 0 + ) }) } } diff --git a/rig-core/src/embeddings/tool.rs b/rig-core/src/embeddings/tool.rs index c369fd94..5e8ecd35 100644 --- a/rig-core/src/embeddings/tool.rs +++ b/rig-core/src/embeddings/tool.rs @@ -15,7 +15,7 @@ pub struct EmbeddableTool { impl EmbeddableTool { /// Convert item that implements ToolEmbedding to an EmbeddableTool. - pub fn try_from(tool: &Box) -> Result { + pub fn try_from(tool: &dyn ToolEmbeddingDyn) -> Result { Ok(EmbeddableTool { name: tool.name(), context: serde_json::to_value(tool.context().map_err(EmbeddableError::SerdeError)?) diff --git a/rig-core/src/tool.rs b/rig-core/src/tool.rs index 181751f5..d6f05ffc 100644 --- a/rig-core/src/tool.rs +++ b/rig-core/src/tool.rs @@ -332,7 +332,7 @@ impl ToolSet { .values() .filter_map(|tool_type| { if let ToolType::Embedding(tool) = tool_type { - Some(EmbeddableTool::try_from(tool)) + Some(EmbeddableTool::try_from(&**tool)) } else { None } From ae66d082487b5b7b2a0bb7a456ae86f3433666e9 Mon Sep 17 00:00:00 2001 From: Garance Date: Tue, 15 Oct 2024 15:28:35 -0400 Subject: [PATCH 18/47] style: cargo fmt --- rig-core/rig-core-derive/src/custom.rs | 19 +++++++------------ rig-core/rig-core-derive/src/embeddable.rs | 19 ++++++++++++------- 2 files changed, 19 insertions(+), 19 deletions(-) diff --git a/rig-core/rig-core-derive/src/custom.rs b/rig-core/rig-core-derive/src/custom.rs index a795026c..de754372 100644 --- a/rig-core/rig-core-derive/src/custom.rs +++ b/rig-core/rig-core-derive/src/custom.rs @@ -19,22 +19,17 @@ pub(crate) fn custom_embed_fields( .attrs .clone() .into_iter() - .filter_map(|attribute| { - match attribute.is_custom() { - Ok(true) => { - match attribute.expand_tag() { - Ok(path) => Some(Ok((field.clone(), path))), - Err(e) => Some(Err(e)), - } - }, - Ok(false) => None, - Err(e) => Some(Err(e)) 
- } + .filter_map(|attribute| match attribute.is_custom() { + Ok(true) => match attribute.expand_tag() { + Ok(path) => Some(Ok((field.clone(), path))), + Err(e) => Some(Err(e)), + }, + Ok(false) => None, + Err(e) => Some(Err(e)), }) .next() }) .collect::, _>>() - } trait CustomAttributeParser { diff --git a/rig-core/rig-core-derive/src/embeddable.rs b/rig-core/rig-core-derive/src/embeddable.rs index d44bb7bf..88503360 100644 --- a/rig-core/rig-core-derive/src/embeddable.rs +++ b/rig-core/rig-core-derive/src/embeddable.rs @@ -26,7 +26,7 @@ pub(crate) fn expand_derive_embedding(input: &mut syn::DeriveInput) -> syn::Resu return Err(syn::Error::new_spanned( name, "Add at least one field tagged with #[embed] or #[embed(embed_with = \"...\")].", - )) + )); } // Determine whether the Embeddable::Kind should be SingleEmbedding or ManyEmbedding @@ -90,7 +90,11 @@ fn embed_kind(data_struct: &DataStruct) -> syn::Result { } } let fields = basic_embed_fields(data_struct) - .chain(custom_embed_fields(data_struct)?.into_iter().map(|(f, _)| f)) + .chain( + custom_embed_fields(data_struct)? + .into_iter() + .map(|(f, _)| f), + ) .collect::>(); if fields.len() == 1 { @@ -128,14 +132,14 @@ impl StructParser for DataStruct { quote! { vec![#(#embed_targets.embeddable()),*] }, - embed_targets.len() + embed_targets.len(), ) } else { ( quote! { vec![] }, - 0 + 0, ) } } @@ -143,7 +147,8 @@ impl StructParser for DataStruct { fn custom(&self) -> syn::Result<(TokenStream, usize)> { let embed_targets = custom_embed_fields(self)? // Iterate over every field tagged with #[embed(embed_with = "...")] - .into_iter().map(|(field, custom_func_path)| { + .into_iter() + .map(|(field, custom_func_path)| { let field_name = field.ident; quote! { @@ -157,14 +162,14 @@ impl StructParser for DataStruct { quote! { vec![#(#embed_targets),*] }, - embed_targets.len() + embed_targets.len(), ) } else { ( quote! 
{ vec![] }, - 0 + 0, ) }) } From 4305952beeff3b62039b8e9374ff2855d9142c7e Mon Sep 17 00:00:00 2001 From: Garance Date: Tue, 15 Oct 2024 15:34:11 -0400 Subject: [PATCH 19/47] fix(tests): clippy --- rig-core/src/embeddings/embeddable.rs | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/rig-core/src/embeddings/embeddable.rs b/rig-core/src/embeddings/embeddable.rs index 6b1996e3..5faa8b2d 100644 --- a/rig-core/src/embeddings/embeddable.rs +++ b/rig-core/src/embeddings/embeddable.rs @@ -173,6 +173,11 @@ mod tests { }, }; + println!( + "FakeDefinition: {}, {}", + fake_definition.id, fake_definition.word + ); + assert_eq!( fake_definition.embeddable().unwrap(), vec!["{\"word\":\"a building in which people live; residence for human beings.\",\"link\":\"https://www.dictionary.com/browse/house\",\"speech\":\"noun\"}".to_string()] @@ -188,13 +193,18 @@ mod tests { } #[test] - fn test_simple_embed() { + fn test_single_embed() { let fake_definition = FakeDefinition2 { id: "doc1".to_string(), word: "house".to_string(), definition: "a building in which people live; residence for human beings.".to_string(), }; + println!( + "FakeDefinition2: {}, {}", + fake_definition.id, fake_definition.word + ); + assert_eq!( fake_definition.embeddable().unwrap(), vec!["a building in which people live; residence for human beings.".to_string()] @@ -219,6 +229,8 @@ mod tests { employee_ages: vec![25, 30, 35, 40], }; + println!("Company: {}, {}", company.id, company.company); + assert_eq!( company.embeddable().unwrap(), vec![ @@ -247,6 +259,8 @@ mod tests { employee_ages: vec![25, 30, 35, 40], }; + println!("Company2: {}", company.id); + assert_eq!( company.embeddable().unwrap(), vec![ From 24e3b9867382f53aab7cb6c87d2ace46a05c23ce Mon Sep 17 00:00:00 2001 From: Garance Date: Tue, 15 Oct 2024 16:04:24 -0400 Subject: [PATCH 20/47] docs&revert: revert embeddable trait error type, add docstrings --- rig-core/rig-core-derive/src/embeddable.rs | 3 +- 
rig-core/src/embeddings/builder.rs | 85 +++++++++++----------- rig-core/src/embeddings/embeddable.rs | 60 +++++++++++---- rig-core/src/tool.rs | 3 + 4 files changed, 96 insertions(+), 55 deletions(-) diff --git a/rig-core/rig-core-derive/src/embeddable.rs b/rig-core/rig-core-derive/src/embeddable.rs index 88503360..e3fe7f94 100644 --- a/rig-core/rig-core-derive/src/embeddable.rs +++ b/rig-core/rig-core-derive/src/embeddable.rs @@ -53,8 +53,9 @@ pub(crate) fn expand_derive_embedding(input: &mut syn::DeriveInput) -> syn::Resu impl #impl_generics Embeddable for #name #ty_generics #where_clause { type Kind = rig::embeddings::embeddable::#embed_kind; + type Error = rig::embeddings::embeddable::EmbeddableError; - fn embeddable(&self) -> Result, rig::embeddings::embeddable::EmbeddableError> { + fn embeddable(&self) -> Result, Self::Error> { #target_stream; let targets = embed_targets.into_iter() diff --git a/rig-core/src/embeddings/builder.rs b/rig-core/src/embeddings/builder.rs index 1a699833..435ecc2f 100644 --- a/rig-core/src/embeddings/builder.rs +++ b/rig-core/src/embeddings/builder.rs @@ -6,18 +6,23 @@ //! use std::env; //! //! use rig::{ -//! embeddings::EmbeddingsBuilder, -//! providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, +//! embeddings::builder::EmbeddingsBuilder, +//! providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, +//! vector_store::{in_memory_store::InMemoryVectorStore, VectorStoreIndex}, +//! Embeddable, //! }; -//! use rig_derive::Embed; +//! use serde::{Deserialize, Serialize}; //! -//! #[derive(Embed)] +//! // Shape of data that needs to be RAG'ed. +//! // The definition field will be used to generate embeddings. +//! #[derive(Embeddable, Clone, Deserialize, Debug, Serialize, Eq, PartialEq, Default)] //! struct FakeDefinition { -//! id: String, -//! word: String, -//! #[embed] -//! definitions: Vec, +//! id: String, +//! word: String, +//! #[embed] +//! definitions: Vec, //! } +//! //! // Create OpenAI client //! 
let openai_api_key = env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set"); //! let openai_client = Client::new(&openai_api_key); @@ -25,34 +30,34 @@ //! let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); //! //! let embeddings = EmbeddingsBuilder::new(model.clone()) -//! .documents(vec![ -//! FakeDefinition { -//! id: "doc0".to_string(), -//! word: "flurbo".to_string(), -//! definitions: vec![ -//! "A green alien that lives on cold planets.".to_string(), -//! "A fictional digital currency that originated in the animated series Rick and Morty.".to_string() -//! ] -//! }, -//! FakeDefinition { -//! id: "doc1".to_string(), -//! word: "glarb-glarb".to_string(), -//! definitions: vec![ -//! "An ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string(), -//! "A fictional creature found in the distant, swampy marshlands of the planet Glibbo in the Andromeda galaxy.".to_string() -//! ] -//! }, -//! FakeDefinition { -//! id: "doc2".to_string(), -//! word: "linglingdong".to_string(), -//! definitions: vec![ -//! "A term used by inhabitants of the sombrero galaxy to describe humans.".to_string(), -//! "A rare, mystical instrument crafted by the ancient monks of the Nebulon Mountain Ranges on the planet Quarm.".to_string() -//! ] -//! }, -//! ]) -//! .build() -//! .await?; +//! .documents(vec![ +//! FakeDefinition { +//! id: "doc0".to_string(), +//! word: "flurbo".to_string(), +//! definitions: vec![ +//! "A green alien that lives on cold planets.".to_string(), +//! "A fictional digital currency that originated in the animated series Rick and Morty.".to_string() +//! ] +//! }, +//! FakeDefinition { +//! id: "doc1".to_string(), +//! word: "glarb-glarb".to_string(), +//! definitions: vec![ +//! "An ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string(), +//! 
"A fictional creature found in the distant, swampy marshlands of the planet Glibbo in the Andromeda galaxy.".to_string() +//! ] +//! }, +//! FakeDefinition { +//! id: "doc2".to_string(), +//! word: "linglingdong".to_string(), +//! definitions: vec![ +//! "A term used by inhabitants of the sombrero galaxy to describe humans.".to_string(), +//! "A rare, mystical instrument crafted by the ancient monks of the Nebulon Mountain Ranges on the planet Quarm.".to_string() +//! ] +//! }, +//! ])? +//! .build() +//! .await?; //! //! // Use the generated embeddings //! // ... @@ -62,10 +67,8 @@ use std::{cmp::max, collections::HashMap, marker::PhantomData}; use futures::{stream, StreamExt, TryStreamExt}; -use crate::Embeddable; - use super::{ - embeddable::{EmbeddableError, EmbeddingKind, ManyEmbedding, SingleEmbedding}, + embeddable::{Embeddable, EmbeddableError, EmbeddingKind, ManyEmbedding, SingleEmbedding}, embedding::{Embedding, EmbeddingError, EmbeddingModel}, }; @@ -87,7 +90,7 @@ impl, K: EmbeddingKind> EmbeddingsBui } /// Add a document that implements `Embeddable` to the builder. - pub fn document(mut self, document: D) -> Result { + pub fn document(mut self, document: D) -> Result { let embed_targets = document.embeddable()?; self.documents.push((document, embed_targets)); @@ -95,7 +98,7 @@ impl, K: EmbeddingKind> EmbeddingsBui } /// Add many documents that implement `Embeddable` to the builder. - pub fn documents(mut self, documents: Vec) -> Result { + pub fn documents(mut self, documents: Vec) -> Result { for doc in documents.into_iter() { let embed_targets = doc.embeddable()?; diff --git a/rig-core/src/embeddings/embeddable.rs b/rig-core/src/embeddings/embeddable.rs index 5faa8b2d..b46ed28c 100644 --- a/rig-core/src/embeddings/embeddable.rs +++ b/rig-core/src/embeddings/embeddable.rs @@ -1,4 +1,22 @@ //! The module defines the [Embeddable] trait, which must be implemented for types that can be embedded. +//! //! # Example +//! ```rust +//! use std::env; +//! +//! 
use rig::Embeddable; +//! use serde::{Deserialize, Serialize}; +//! +//! #[derive(Embeddable)] +//! struct FakeDefinition { +//! id: String, +//! word: String, +//! #[embed] +//! definitions: Vec, +//! } +//! +//! // Do something with FakeDefinition +//! // ... +//! ``` /// The associated type `Kind` on the trait `Embeddable` must implement this trait. pub trait EmbeddingKind {} @@ -11,6 +29,8 @@ impl EmbeddingKind for SingleEmbedding {} pub struct ManyEmbedding; impl EmbeddingKind for ManyEmbedding {} +/// Error type used for when the `embeddable` method fails. +/// Used by default implementations of `Embeddable` for common types. #[derive(Debug, thiserror::Error)] pub enum EmbeddableError { #[error("SerdeError: {0}")] @@ -20,10 +40,12 @@ pub enum EmbeddableError { /// Trait for types that can be embedded. /// The `embeddable` method returns a list of strings for which embeddings will be generated by the embeddings builder. /// If the type `Kind` is `SingleEmbedding`, the list of strings contains a single item, otherwise, the list can contain many items. +/// If there is an error generating the list of strings, the method should return an error that implements `std::error::Error`. 
pub trait Embeddable { type Kind: EmbeddingKind; + type Error: std::error::Error; - fn embeddable(&self) -> Result, EmbeddableError>; + fn embeddable(&self) -> Result, Self::Error>; } ////////////////////////////////////////////////////// @@ -31,88 +53,99 @@ pub trait Embeddable { ////////////////////////////////////////////////////// impl Embeddable for String { type Kind = SingleEmbedding; + type Error = EmbeddableError; - fn embeddable(&self) -> Result, EmbeddableError> { + fn embeddable(&self) -> Result, Self::Error> { Ok(vec![self.clone()]) } } impl Embeddable for i8 { type Kind = SingleEmbedding; + type Error = EmbeddableError; - fn embeddable(&self) -> Result, EmbeddableError> { + fn embeddable(&self) -> Result, Self::Error> { Ok(vec![self.to_string()]) } } impl Embeddable for i16 { type Kind = SingleEmbedding; + type Error = EmbeddableError; - fn embeddable(&self) -> Result, EmbeddableError> { + fn embeddable(&self) -> Result, Self::Error> { Ok(vec![self.to_string()]) } } impl Embeddable for i32 { type Kind = SingleEmbedding; + type Error = EmbeddableError; - fn embeddable(&self) -> Result, EmbeddableError> { + fn embeddable(&self) -> Result, Self::Error> { Ok(vec![self.to_string()]) } } impl Embeddable for i64 { type Kind = SingleEmbedding; + type Error = EmbeddableError; - fn embeddable(&self) -> Result, EmbeddableError> { + fn embeddable(&self) -> Result, Self::Error> { Ok(vec![self.to_string()]) } } impl Embeddable for i128 { type Kind = SingleEmbedding; + type Error = EmbeddableError; - fn embeddable(&self) -> Result, EmbeddableError> { + fn embeddable(&self) -> Result, Self::Error> { Ok(vec![self.to_string()]) } } impl Embeddable for f32 { type Kind = SingleEmbedding; + type Error = EmbeddableError; - fn embeddable(&self) -> Result, EmbeddableError> { + fn embeddable(&self) -> Result, Self::Error> { Ok(vec![self.to_string()]) } } impl Embeddable for f64 { type Kind = SingleEmbedding; + type Error = EmbeddableError; - fn embeddable(&self) -> Result, 
EmbeddableError> { + fn embeddable(&self) -> Result, Self::Error> { Ok(vec![self.to_string()]) } } impl Embeddable for bool { type Kind = SingleEmbedding; + type Error = EmbeddableError; - fn embeddable(&self) -> Result, EmbeddableError> { + fn embeddable(&self) -> Result, Self::Error> { Ok(vec![self.to_string()]) } } impl Embeddable for char { type Kind = SingleEmbedding; + type Error = EmbeddableError; - fn embeddable(&self) -> Result, EmbeddableError> { + fn embeddable(&self) -> Result, Self::Error> { Ok(vec![self.to_string()]) } } impl Embeddable for serde_json::Value { type Kind = SingleEmbedding; + type Error = EmbeddableError; - fn embeddable(&self) -> Result, EmbeddableError> { + fn embeddable(&self) -> Result, Self::Error> { Ok(vec![ serde_json::to_string(self).map_err(EmbeddableError::SerdeError)? ]) @@ -121,8 +154,9 @@ impl Embeddable for serde_json::Value { impl Embeddable for Vec { type Kind = ManyEmbedding; + type Error = T::Error; - fn embeddable(&self) -> Result, EmbeddableError> { + fn embeddable(&self) -> Result, Self::Error> { Ok(self .iter() .map(|i| i.embeddable()) diff --git a/rig-core/src/tool.rs b/rig-core/src/tool.rs index d6f05ffc..e92896b8 100644 --- a/rig-core/src/tool.rs +++ b/rig-core/src/tool.rs @@ -327,6 +327,9 @@ impl ToolSet { Ok(docs) } + /// Convert tools in self to objects of type EmbeddableTool. + /// This is necessary because when adding tools to the EmbeddingBuilder because all + /// documents added to the builder must all be of the same type. 
pub fn embedabble_tools(&self) -> Result, EmbeddableError> { self.tools .values() From a7dbf6cd27be89b100144e261e40f38a9834fbb8 Mon Sep 17 00:00:00 2001 From: Garance Date: Tue, 15 Oct 2024 16:06:13 -0400 Subject: [PATCH 21/47] style: cargo clippy --- rig-core/rig-core-derive/src/embeddable.rs | 1 - rig-core/src/embeddings/builder.rs | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/rig-core/rig-core-derive/src/embeddable.rs b/rig-core/rig-core-derive/src/embeddable.rs index e3fe7f94..7c3c883e 100644 --- a/rig-core/rig-core-derive/src/embeddable.rs +++ b/rig-core/rig-core-derive/src/embeddable.rs @@ -68,7 +68,6 @@ pub(crate) fn expand_derive_embedding(input: &mut syn::DeriveInput) -> syn::Resu } } }; - eprintln!("Generated code:\n{}", gen); Ok(gen) } diff --git a/rig-core/src/embeddings/builder.rs b/rig-core/src/embeddings/builder.rs index 435ecc2f..ea525e8b 100644 --- a/rig-core/src/embeddings/builder.rs +++ b/rig-core/src/embeddings/builder.rs @@ -68,7 +68,7 @@ use std::{cmp::max, collections::HashMap, marker::PhantomData}; use futures::{stream, StreamExt, TryStreamExt}; use super::{ - embeddable::{Embeddable, EmbeddableError, EmbeddingKind, ManyEmbedding, SingleEmbedding}, + embeddable::{Embeddable, EmbeddingKind, ManyEmbedding, SingleEmbedding}, embedding::{Embedding, EmbeddingError, EmbeddingModel}, }; From 886ebcb6ba9c1b1c2211dad33edc8a669966b615 Mon Sep 17 00:00:00 2001 From: Garance Date: Tue, 15 Oct 2024 16:11:11 -0400 Subject: [PATCH 22/47] clippy(lancedb): fix unused function error --- rig-lancedb/examples/fixtures/lib.rs | 11 ++--------- rig-lancedb/examples/vector_search_local_ann.rs | 7 +++++-- rig-lancedb/examples/vector_search_s3_ann.rs | 7 +++++-- 3 files changed, 12 insertions(+), 13 deletions(-) diff --git a/rig-lancedb/examples/fixtures/lib.rs b/rig-lancedb/examples/fixtures/lib.rs index 956ace1c..415422f0 100644 --- a/rig-lancedb/examples/fixtures/lib.rs +++ b/rig-lancedb/examples/fixtures/lib.rs @@ -8,9 +8,9 @@ use 
serde::Deserialize; #[derive(Embeddable, Clone, Deserialize, Debug)] pub struct FakeDefinition { - id: String, + pub id: String, #[embed] - definition: String, + pub definition: String, } pub fn fake_definitions() -> Vec { @@ -30,13 +30,6 @@ pub fn fake_definitions() -> Vec { ] } -pub fn fake_definition(id: String) -> FakeDefinition { - FakeDefinition { - id, - definition: "Definition of *flumbuzzle (noun)*: A sudden, inexplicable urge to rearrange or reorganize small objects, such as desk items or books, for no apparent reason.".to_string() - } -} - // Schema of table in LanceDB. pub fn schema(dims: usize) -> Schema { Schema::new(Fields::from(vec![ diff --git a/rig-lancedb/examples/vector_search_local_ann.rs b/rig-lancedb/examples/vector_search_local_ann.rs index 1042e34c..1b7870fb 100644 --- a/rig-lancedb/examples/vector_search_local_ann.rs +++ b/rig-lancedb/examples/vector_search_local_ann.rs @@ -1,7 +1,7 @@ use std::{env, sync::Arc}; use arrow_array::RecordBatchIterator; -use fixture::{as_record_batch, fake_definition, fake_definitions, schema, FakeDefinition}; +use fixture::{as_record_batch, fake_definitions, schema, FakeDefinition}; use lancedb::index::vector::IvfPqIndexBuilder; use rig::vector_store::VectorStoreIndex; use rig::{ @@ -31,7 +31,10 @@ async fn main() -> Result<(), anyhow::Error> { // Note: need at least 256 rows in order to create an index so copy the definition 256 times for testing purposes. .documents( (0..256) - .map(|i| fake_definition(format!("doc{}", i))) + .map(|i| FakeDefinition { + id: format!("doc{}", i), + definition: "Definition of *flumbuzzle (noun)*: A sudden, inexplicable urge to rearrange or reorganize small objects, such as desk items or books, for no apparent reason.".to_string() + }) .collect(), )? 
.build() diff --git a/rig-lancedb/examples/vector_search_s3_ann.rs b/rig-lancedb/examples/vector_search_s3_ann.rs index 8d65e37a..8c10409b 100644 --- a/rig-lancedb/examples/vector_search_s3_ann.rs +++ b/rig-lancedb/examples/vector_search_s3_ann.rs @@ -1,7 +1,7 @@ use std::{env, sync::Arc}; use arrow_array::RecordBatchIterator; -use fixture::{as_record_batch, fake_definition, fake_definitions, schema, FakeDefinition}; +use fixture::{as_record_batch, fake_definitions, schema, FakeDefinition}; use lancedb::{index::vector::IvfPqIndexBuilder, DistanceType}; use rig::{ embeddings::{builder::EmbeddingsBuilder, embedding::EmbeddingModel}, @@ -37,7 +37,10 @@ async fn main() -> Result<(), anyhow::Error> { // Note: need at least 256 rows in order to create an index so copy the definition 256 times for testing purposes. .documents( (0..256) - .map(|i| fake_definition(format!("doc{}", i))) + .map(|i| FakeDefinition { + id: format!("doc{}", i), + definition: "Definition of *flumbuzzle (noun)*: A sudden, inexplicable urge to rearrange or reorganize small objects, such as desk items or books, for no apparent reason.".to_string() + }) .collect(), )? 
.build() From 79dea4536c1ed96b60f49d8234d96445e9c2e9c9 Mon Sep 17 00:00:00 2001 From: Garance Date: Tue, 15 Oct 2024 16:13:25 -0400 Subject: [PATCH 23/47] fix(test): remove useless assert false statement --- rig-core/src/embeddings/embeddable.rs | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/rig-core/src/embeddings/embeddable.rs b/rig-core/src/embeddings/embeddable.rs index b46ed28c..c7b7bf0c 100644 --- a/rig-core/src/embeddings/embeddable.rs +++ b/rig-core/src/embeddings/embeddable.rs @@ -242,9 +242,7 @@ mod tests { assert_eq!( fake_definition.embeddable().unwrap(), vec!["a building in which people live; residence for human beings.".to_string()] - ); - - assert!(false) + ) } #[derive(Embeddable)] From 636234458755ca265e324fa1ab252d40e99232ab Mon Sep 17 00:00:00 2001 From: Garance Date: Wed, 16 Oct 2024 14:32:00 -0400 Subject: [PATCH 24/47] cleanup: split up branch into 2 branches for readability --- Cargo.lock | 2 - rig-core/examples/calculator_chatbot.rs | 17 +- rig-core/examples/rag.rs | 51 +-- rig-core/examples/rag_dynamic_tools.rs | 13 +- rig-core/examples/vector_search.rs | 59 +--- rig-core/examples/vector_search_cohere.rs | 62 +--- rig-core/src/embeddings/builder.rs | 332 ++++++++++-------- rig-core/src/embeddings/mod.rs | 3 +- rig-core/src/embeddings/tool.rs | 25 -- rig-core/src/providers/cohere.rs | 10 +- rig-core/src/providers/openai.rs | 18 +- rig-core/src/tool.rs | 18 +- rig-lancedb/Cargo.toml | 1 - rig-lancedb/examples/fixtures/lib.rs | 56 ++- .../examples/vector_search_local_ann.rs | 33 +- .../examples/vector_search_local_enn.rs | 9 +- rig-lancedb/examples/vector_search_s3_ann.rs | 33 +- rig-mongodb/Cargo.toml | 2 - rig-mongodb/examples/vector_search_mongodb.rs | 82 +---- rig-mongodb/src/lib.rs | 39 +- 20 files changed, 350 insertions(+), 515 deletions(-) delete mode 100644 rig-core/src/embeddings/tool.rs diff --git a/Cargo.lock b/Cargo.lock index c47f71b8..75f67709 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4026,7 +4026,6 @@ 
dependencies = [ "futures", "lancedb", "rig-core", - "rig-derive", "serde", "serde_json", "tokio", @@ -4040,7 +4039,6 @@ dependencies = [ "futures", "mongodb", "rig-core", - "rig-derive", "serde", "serde_json", "tokio", diff --git a/rig-core/examples/calculator_chatbot.rs b/rig-core/examples/calculator_chatbot.rs index 0b994265..949073b3 100644 --- a/rig-core/examples/calculator_chatbot.rs +++ b/rig-core/examples/calculator_chatbot.rs @@ -2,7 +2,7 @@ use anyhow::Result; use rig::{ cli_chatbot::cli_chatbot, completion::ToolDefinition, - embeddings::builder::EmbeddingsBuilder, + embeddings::{builder::DocumentEmbeddings, builder::EmbeddingsBuilder}, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, tool::{Tool, ToolEmbedding, ToolSet}, vector_store::in_memory_store::InMemoryVectorStore, @@ -25,7 +25,7 @@ struct MathError; #[error("Init error")] struct InitError; -#[derive(Deserialize, Serialize, Clone)] +#[derive(Deserialize, Serialize)] struct Add; impl Tool for Add { const NAME: &'static str = "add"; @@ -77,7 +77,7 @@ impl ToolEmbedding for Add { fn context(&self) -> Self::Context {} } -#[derive(Deserialize, Serialize, Clone)] +#[derive(Deserialize, Serialize)] struct Subtract; impl Tool for Subtract { const NAME: &'static str = "subtract"; @@ -247,7 +247,7 @@ async fn main() -> Result<(), anyhow::Error> { let embedding_model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); let embeddings = EmbeddingsBuilder::new(embedding_model.clone()) - .documents(toolset.embedabble_tools()?)? + .tools(&toolset)? .build() .await?; @@ -255,8 +255,13 @@ async fn main() -> Result<(), anyhow::Error> { .add_documents( embeddings .into_iter() - .enumerate() - .map(|(i, (tool, embedding))| (i.to_string(), tool, vec![embedding])) + .map( + |DocumentEmbeddings { + id, + document, + embeddings, + }| { (id, document, embeddings) }, + ) .collect(), )? 
.index(embedding_model); diff --git a/rig-core/examples/rag.rs b/rig-core/examples/rag.rs index 43270a7b..936a3d05 100644 --- a/rig-core/examples/rag.rs +++ b/rig-core/examples/rag.rs @@ -1,22 +1,11 @@ -use std::{env, vec}; +use std::env; use rig::{ completion::Prompt, - embeddings::builder::EmbeddingsBuilder, + embeddings::{builder::DocumentEmbeddings, builder::EmbeddingsBuilder}, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, vector_store::in_memory_store::InMemoryVectorStore, - Embeddable, }; -use serde::Serialize; - -// Shape of data that needs to be RAG'ed. -// The definition field will be used to generate embeddings. -#[derive(Embeddable, Clone, Debug, Serialize, Eq, PartialEq, Default)] -struct FakeDefinition { - id: String, - #[embed] - definitions: Vec, -} #[tokio::main] async fn main() -> Result<(), anyhow::Error> { @@ -27,29 +16,9 @@ async fn main() -> Result<(), anyhow::Error> { let embedding_model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); let embeddings = EmbeddingsBuilder::new(embedding_model.clone()) - .documents(vec![ - FakeDefinition { - id: "doc0".to_string(), - definitions: vec![ - "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets.".to_string(), - "Definition of a *flurbo*: A fictional digital currency that originated in the animated series Rick and Morty.".to_string() - ] - }, - FakeDefinition { - id: "doc1".to_string(), - definitions: vec![ - "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string(), - "Definition of a *glarb-glarb*: A fictional creature found in the distant, swampy marshlands of the planet Glibbo in the Andromeda galaxy.".to_string() - ] - }, - FakeDefinition { - id: "doc2".to_string(), - definitions: vec![ - "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.".to_string(), - "Definition of a *linglingdong*: A rare, mystical 
instrument crafted by the ancient monks of the Nebulon Mountain Ranges on the planet Quarm.".to_string() - ] - }, - ])? + .simple_document("doc0", "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets") + .simple_document("doc1", "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.") + .simple_document("doc2", "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.") .build() .await?; @@ -57,9 +26,13 @@ async fn main() -> Result<(), anyhow::Error> { .add_documents( embeddings .into_iter() - .map(|(fake_definition, embedding_vec)| { - (fake_definition.id.clone(), fake_definition, embedding_vec) - }) + .map( + |DocumentEmbeddings { + id, + document, + embeddings, + }| { (id, document, embeddings) }, + ) .collect(), )? .index(embedding_model); diff --git a/rig-core/examples/rag_dynamic_tools.rs b/rig-core/examples/rag_dynamic_tools.rs index 51c56ca8..f00543a2 100644 --- a/rig-core/examples/rag_dynamic_tools.rs +++ b/rig-core/examples/rag_dynamic_tools.rs @@ -1,7 +1,7 @@ use anyhow::Result; use rig::{ completion::{Prompt, ToolDefinition}, - embeddings::builder::EmbeddingsBuilder, + embeddings::{builder::DocumentEmbeddings, builder::EmbeddingsBuilder}, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, tool::{Tool, ToolEmbedding, ToolSet}, vector_store::in_memory_store::InMemoryVectorStore, @@ -156,7 +156,7 @@ async fn main() -> Result<(), anyhow::Error> { .build(); let embeddings = EmbeddingsBuilder::new(embedding_model.clone()) - .documents(toolset.embedabble_tools()?)? + .tools(&toolset)? 
.build() .await?; @@ -164,8 +164,13 @@ async fn main() -> Result<(), anyhow::Error> { .add_documents( embeddings .into_iter() - .enumerate() - .map(|(i, (tool, embedding))| (i.to_string(), tool, vec![embedding])) + .map( + |DocumentEmbeddings { + id, + document, + embeddings, + }| { (id, document, embeddings) }, + ) .collect(), )? .index(embedding_model); diff --git a/rig-core/examples/vector_search.rs b/rig-core/examples/vector_search.rs index a97ef8b0..26692706 100644 --- a/rig-core/examples/vector_search.rs +++ b/rig-core/examples/vector_search.rs @@ -1,22 +1,10 @@ use std::env; use rig::{ - embeddings::builder::EmbeddingsBuilder, + embeddings::{builder::DocumentEmbeddings, builder::EmbeddingsBuilder}, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, vector_store::{in_memory_store::InMemoryVectorStore, VectorStoreIndex}, - Embeddable, }; -use serde::{Deserialize, Serialize}; - -// Shape of data that needs to be RAG'ed. -// The definition field will be used to generate embeddings. 
-#[derive(Embeddable, Clone, Deserialize, Debug, Serialize, Eq, PartialEq, Default)] -struct FakeDefinition { - id: String, - word: String, - #[embed] - definitions: Vec, -} #[tokio::main] async fn main() -> Result<(), anyhow::Error> { @@ -27,32 +15,9 @@ async fn main() -> Result<(), anyhow::Error> { let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); let embeddings = EmbeddingsBuilder::new(model.clone()) - .documents(vec![ - FakeDefinition { - id: "doc0".to_string(), - word: "flurbo".to_string(), - definitions: vec![ - "A green alien that lives on cold planets.".to_string(), - "A fictional digital currency that originated in the animated series Rick and Morty.".to_string() - ] - }, - FakeDefinition { - id: "doc1".to_string(), - word: "glarb-glarb".to_string(), - definitions: vec![ - "An ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string(), - "A fictional creature found in the distant, swampy marshlands of the planet Glibbo in the Andromeda galaxy.".to_string() - ] - }, - FakeDefinition { - id: "doc2".to_string(), - word: "linglingdong".to_string(), - definitions: vec![ - "A term used by inhabitants of the sombrero galaxy to describe humans.".to_string(), - "A rare, mystical instrument crafted by the ancient monks of the Nebulon Mountain Ranges on the planet Quarm.".to_string() - ] - }, - ])? 
+ .simple_document("doc0", "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets") + .simple_document("doc1", "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.") + .simple_document("doc2", "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.") .build() .await?; @@ -60,24 +25,28 @@ async fn main() -> Result<(), anyhow::Error> { .add_documents( embeddings .into_iter() - .map(|(fake_definition, embedding_vec)| { - (fake_definition.id.clone(), fake_definition, embedding_vec) - }) + .map( + |DocumentEmbeddings { + id, + document, + embeddings, + }| { (id, document, embeddings) }, + ) .collect(), )? .index(model); let results = index - .top_n::("I need to buy something in a fictional universe. What type of money can I use for this?", 1) + .top_n::("What is a linglingdong?", 1) .await? .into_iter() - .map(|(score, id, doc)| (score, id, doc.word)) + .map(|(score, id, doc)| (score, id, doc)) .collect::>(); println!("Results: {:?}", results); let id_results = index - .top_n_ids("I need to buy something in a fictional universe. What type of money can I use for this?", 1) + .top_n_ids("What is a linglingdong?", 1) .await? .into_iter() .map(|(score, id)| (score, id)) diff --git a/rig-core/examples/vector_search_cohere.rs b/rig-core/examples/vector_search_cohere.rs index 16ddb775..c14fe0ce 100644 --- a/rig-core/examples/vector_search_cohere.rs +++ b/rig-core/examples/vector_search_cohere.rs @@ -1,22 +1,10 @@ use std::env; use rig::{ - embeddings::builder::EmbeddingsBuilder, + embeddings::{builder::DocumentEmbeddings, builder::EmbeddingsBuilder}, providers::cohere::{Client, EMBED_ENGLISH_V3}, vector_store::{in_memory_store::InMemoryVectorStore, VectorStoreIndex}, - Embeddable, }; -use serde::{Deserialize, Serialize}; - -// Shape of data that needs to be RAG'ed. 
-// The definition field will be used to generate embeddings. -#[derive(Embeddable, Clone, Deserialize, Debug, Serialize, Eq, PartialEq, Default)] -struct FakeDefinition { - id: String, - word: String, - #[embed] - definitions: Vec, -} #[tokio::main] async fn main() -> Result<(), anyhow::Error> { @@ -27,33 +15,10 @@ async fn main() -> Result<(), anyhow::Error> { let document_model = cohere_client.embedding_model(EMBED_ENGLISH_V3, "search_document"); let search_model = cohere_client.embedding_model(EMBED_ENGLISH_V3, "search_query"); - let embeddings = EmbeddingsBuilder::new(document_model.clone()) - .documents(vec![ - FakeDefinition { - id: "doc0".to_string(), - word: "flurbo".to_string(), - definitions: vec![ - "A green alien that lives on cold planets.".to_string(), - "A fictional digital currency that originated in the animated series Rick and Morty.".to_string() - ] - }, - FakeDefinition { - id: "doc1".to_string(), - word: "glarb-glarb".to_string(), - definitions: vec![ - "An ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string(), - "A fictional creature found in the distant, swampy marshlands of the planet Glibbo in the Andromeda galaxy.".to_string() - ] - }, - FakeDefinition { - id: "doc2".to_string(), - word: "linglingdong".to_string(), - definitions: vec![ - "A term used by inhabitants of the sombrero galaxy to describe humans.".to_string(), - "A rare, mystical instrument crafted by the ancient monks of the Nebulon Mountain Ranges on the planet Quarm.".to_string() - ] - }, - ])? 
+ let embeddings = EmbeddingsBuilder::new(document_model) + .simple_document("doc0", "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets") + .simple_document("doc1", "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.") + .simple_document("doc2", "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.") .build() .await?; @@ -61,21 +26,22 @@ async fn main() -> Result<(), anyhow::Error> { .add_documents( embeddings .into_iter() - .map(|(fake_definition, embedding_vec)| { - (fake_definition.id.clone(), fake_definition, embedding_vec) - }) + .map( + |DocumentEmbeddings { + id, + document, + embeddings, + }| { (id, document, embeddings) }, + ) .collect(), )? .index(search_model); let results = index - .top_n::( - "Which instrument is found in the Nebulon Mountain Ranges?", - 1, - ) + .top_n::("What is a linglingdong?", 1) .await? .into_iter() - .map(|(score, id, doc)| (score, id, doc.word)) + .map(|(score, id, doc)| (score, id, doc)) .collect::>(); println!("Results: {:?}", results); diff --git a/rig-core/src/embeddings/builder.rs b/rig-core/src/embeddings/builder.rs index ea525e8b..50e56537 100644 --- a/rig-core/src/embeddings/builder.rs +++ b/rig-core/src/embeddings/builder.rs @@ -3,207 +3,233 @@ //! //! # Example //! ```rust -//! use std::env; +//! use rig::providers::openai::{Client, self}; +//! use rig::embeddings::{EmbeddingModel, EmbeddingsBuilder}; //! -//! use rig::{ -//! embeddings::builder::EmbeddingsBuilder, -//! providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, -//! vector_store::{in_memory_store::InMemoryVectorStore, VectorStoreIndex}, -//! Embeddable, -//! }; -//! use serde::{Deserialize, Serialize}; +//! // Initialize the OpenAI client +//! let openai = Client::new("your-openai-api-key"); //! -//! // Shape of data that needs to be RAG'ed. -//! 
// The definition field will be used to generate embeddings. -//! #[derive(Embeddable, Clone, Deserialize, Debug, Serialize, Eq, PartialEq, Default)] -//! struct FakeDefinition { -//! id: String, -//! word: String, -//! #[embed] -//! definitions: Vec, -//! } +//! // Create an instance of the `text-embedding-ada-002` model +//! let embedding_model = openai.embedding_model(openai::TEXT_EMBEDDING_ADA_002); //! -//! // Create OpenAI client -//! let openai_api_key = env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set"); -//! let openai_client = Client::new(&openai_api_key); -//! -//! let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); -//! -//! let embeddings = EmbeddingsBuilder::new(model.clone()) -//! .documents(vec![ -//! FakeDefinition { -//! id: "doc0".to_string(), -//! word: "flurbo".to_string(), -//! definitions: vec![ -//! "A green alien that lives on cold planets.".to_string(), -//! "A fictional digital currency that originated in the animated series Rick and Morty.".to_string() -//! ] -//! }, -//! FakeDefinition { -//! id: "doc1".to_string(), -//! word: "glarb-glarb".to_string(), -//! definitions: vec![ -//! "An ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string(), -//! "A fictional creature found in the distant, swampy marshlands of the planet Glibbo in the Andromeda galaxy.".to_string() -//! ] -//! }, -//! FakeDefinition { -//! id: "doc2".to_string(), -//! word: "linglingdong".to_string(), -//! definitions: vec![ -//! "A term used by inhabitants of the sombrero galaxy to describe humans.".to_string(), -//! "A rare, mystical instrument crafted by the ancient monks of the Nebulon Mountain Ranges on the planet Quarm.".to_string() -//! ] -//! }, -//! ])? +//! // Create an embeddings builder and add documents +//! let embeddings = EmbeddingsBuilder::new(embedding_model) +//! .simple_document("doc1", "This is the first document.") +//! .simple_document("doc2", "This is the second document.") //! 
.build() -//! .await?; +//! .await +//! .expect("Failed to build embeddings."); //! //! // Use the generated embeddings //! // ... //! ``` -use std::{cmp::max, collections::HashMap, marker::PhantomData}; +use std::{cmp::max, collections::HashMap}; use futures::{stream, StreamExt, TryStreamExt}; +use serde::{Deserialize, Serialize}; + +use crate::tool::{ToolEmbedding, ToolSet, ToolType}; + +use super::embedding::{ Embedding, EmbeddingError, EmbeddingModel}; + +/// Struct that holds a document and its embeddings. +/// +/// The struct is designed to model any kind of documents that can be serialized to JSON +/// (including a simple string). +/// +/// Moreover, it can hold multiple embeddings for the same document, thus allowing a +/// large document to be retrieved from a query that matches multiple smaller and +/// distinct text documents. For example, if the document is a textbook, a summary of +/// each chapter could serve as the book's embeddings. +#[derive(Clone, Eq, PartialEq, Serialize, Deserialize)] +pub struct DocumentEmbeddings { + #[serde(rename = "_id")] + pub id: String, + pub document: serde_json::Value, + pub embeddings: Vec, +} -use super::{ - embeddable::{Embeddable, EmbeddingKind, ManyEmbedding, SingleEmbedding}, - embedding::{Embedding, EmbeddingError, EmbeddingModel}, -}; +type Embeddings = Vec; -/// Builder for creating a collection of embeddings. -pub struct EmbeddingsBuilder { - kind: PhantomData, +/// Builder for creating a collection of embeddings +pub struct EmbeddingsBuilder { model: M, - documents: Vec<(D, Vec)>, + documents: Vec<(String, serde_json::Value, Vec)>, } -impl, K: EmbeddingKind> EmbeddingsBuilder { +impl EmbeddingsBuilder { /// Create a new embedding builder with the given embedding model pub fn new(model: M) -> Self { Self { - kind: PhantomData, model, documents: vec![], } } - /// Add a document that implements `Embeddable` to the builder. 
- pub fn document(mut self, document: D) -> Result { - let embed_targets = document.embeddable()?; + /// Add a simple document to the embedding collection. + /// The provided document string will be used for the embedding. + pub fn simple_document(mut self, id: &str, document: &str) -> Self { + self.documents.push(( + id.to_string(), + serde_json::Value::String(document.to_string()), + vec![document.to_string()], + )); + self + } - self.documents.push((document, embed_targets)); - Ok(self) + /// Add multiple documents to the embedding collection. + /// Each element of the vector is a tuple of the form (id, document). + pub fn simple_documents(mut self, documents: Vec<(String, String)>) -> Self { + self.documents + .extend(documents.into_iter().map(|(id, document)| { + ( + id, + serde_json::Value::String(document.clone()), + vec![document], + ) + })); + self } - /// Add many documents that implement `Embeddable` to the builder. - pub fn documents(mut self, documents: Vec) -> Result { - for doc in documents.into_iter() { - let embed_targets = doc.embeddable()?; + /// Add a tool to the embedding collection. + /// The `tool.context()` corresponds to the document being stored while + /// `tool.embedding_docs()` corresponds to the documents that will be used to generate the embeddings. + pub fn tool(mut self, tool: impl ToolEmbedding + 'static) -> Result { + self.documents.push(( + tool.name(), + serde_json::to_value(tool.context())?, + tool.embedding_docs(), + )); + Ok(self) + } - self.documents.push((doc, embed_targets)); + /// Add the tools from the given toolset to the embedding collection. 
+ pub fn tools(mut self, toolset: &ToolSet) -> Result { + for (name, tool) in toolset.tools.iter() { + if let ToolType::Embedding(tool) = tool { + self.documents.push(( + name.clone(), + tool.context().map_err(|e| { + EmbeddingError::DocumentError(format!( + "Failed to generate context for tool {}: {}", + name, e + )) + })?, + tool.embedding_docs(), + )); + } } - Ok(self) } -} -impl - EmbeddingsBuilder -{ - /// Generate embeddings for all documents in the builder. - /// The method only applies when documents in the builder each contain multiple embedding targets. - /// Returns a vector of tuples, where the first element is the document and the second element is the vector of embeddings. - pub async fn build(&self) -> Result)>, EmbeddingError> { - // Use this for reference later to merge a document back with its embeddings. + /// Add a document to the embedding collection. + /// `embed_documents` are the documents that will be used to generate the embeddings + /// for `document`. + pub fn document( + mut self, + id: &str, + document: T, + embed_documents: Vec, + ) -> Self { + self.documents.push(( + id.to_string(), + serde_json::to_value(document).expect("Document should serialize"), + embed_documents, + )); + self + } + + /// Add multiple documents to the embedding collection. + /// Each element of the vector is a tuple of the form (id, document, embed_documents). + pub fn documents(mut self, documents: Vec<(String, T, Vec)>) -> Self { + self.documents.extend( + documents + .into_iter() + .map(|(id, document, embed_documents)| { + ( + id, + serde_json::to_value(document).expect("Document should serialize"), + embed_documents, + ) + }), + ); + self + } + + /// Add a json document to the embedding collection. 
+ pub fn json_document( + mut self, + id: &str, + document: serde_json::Value, + embed_documents: Vec, + ) -> Self { + self.documents + .push((id.to_string(), document, embed_documents)); + self + } + + /// Add multiple json documents to the embedding collection. + pub fn json_documents( + mut self, + documents: Vec<(String, serde_json::Value, Vec)>, + ) -> Self { + self.documents.extend(documents); + self + } + + /// Generate the embeddings for the given documents + pub async fn build(self) -> Result { + // Create a temporary store for the documents let documents_map = self .documents - .clone() .into_iter() - .enumerate() - .map(|(id, (document, _))| (id, document)) + .map(|(id, document, docs)| (id, (document, docs))) .collect::>(); - let embeddings = stream::iter(self.documents.iter().enumerate()) - // Merge the embedding targets of each document into a single list of embedding targets. - .flat_map(|(i, (_, embed_targets))| { - stream::iter(embed_targets.iter().map(move |target| (i, target.clone()))) + let embeddings = stream::iter(documents_map.iter()) + // Flatten the documents + .flat_map(|(id, (_, docs))| { + stream::iter(docs.iter().map(|doc| (id.clone(), doc.clone()))) }) - // Chunk them into N (the emebdding API limit per request). + // Chunk them into N (the emebdding API limit per request) .chunks(M::MAX_DOCUMENTS) - // Generate the embeddings for a chunk at a time. 
+ // Generate the embeddings .map(|docs| async { - let (document_indices, embed_targets): (Vec<_>, Vec<_>) = docs.into_iter().unzip(); - + let (ids, docs): (Vec<_>, Vec<_>) = docs.into_iter().unzip(); Ok::<_, EmbeddingError>( - document_indices - .into_iter() - .zip(self.model.embed_documents(embed_targets).await?.into_iter()) + ids.into_iter() + .zip(self.model.embed_documents(docs).await?.into_iter()) .collect::>(), ) }) .boxed() // Parallelize the embeddings generation over 10 concurrent requests .buffer_unordered(max(1, 1024 / M::MAX_DOCUMENTS)) - .try_fold( - HashMap::new(), - |mut acc: HashMap<_, Vec<_>>, embeddings| async move { - embeddings.into_iter().for_each(|(i, embedding)| { - acc.entry(i).or_default().push(embedding); + .try_fold(vec![], |mut acc, mut embeddings| async move { + Ok({ + acc.append(&mut embeddings); + acc + }) + }) + .await?; + + // Assemble the DocumentEmbeddings + let mut document_embeddings: HashMap = HashMap::new(); + embeddings.into_iter().for_each(|(id, embedding)| { + let (document, _) = documents_map.get(&id).expect("Document not found"); + let document_embedding = + document_embeddings + .entry(id.clone()) + .or_insert_with(|| DocumentEmbeddings { + id: id.clone(), + document: document.clone(), + embeddings: vec![], }); - Ok(acc) - }, - ) - .await? - .iter() - .fold(vec![], |mut acc, (i, embeddings_vec)| { - acc.push(( - documents_map.get(i).cloned().unwrap(), - embeddings_vec.clone(), - )); - acc - }); + document_embedding.embeddings.push(embedding); + }); - Ok(embeddings) + Ok(document_embeddings.into_values().collect()) } -} - -impl - EmbeddingsBuilder -{ - /// Generate embeddings for all documents in the builder. - /// The method only applies when documents in the builder each contain a single embedding target. - /// Returns a vector of tuples, where the first element is the document and the second element is the embedding. 
- pub async fn build(&self) -> Result, EmbeddingError> { - let embeddings = stream::iter( - self.documents - .clone() - .into_iter() - .map(|(document, embed_target)| (document, embed_target.first().cloned().unwrap())), - ) - // Chunk them into N (the emebdding API limit per request) - .chunks(M::MAX_DOCUMENTS) - // Generate the embeddings - .map(|docs| async { - let (documents, embed_targets): (Vec<_>, Vec<_>) = docs.into_iter().unzip(); - Ok::<_, EmbeddingError>( - documents - .into_iter() - .zip(self.model.embed_documents(embed_targets).await?.into_iter()) - .collect::>(), - ) - }) - .boxed() - // Parallelize the embeddings generation over 10 concurrent requests - .buffer_unordered(max(1, 1024 / M::MAX_DOCUMENTS)) - .try_fold(vec![], |mut acc, embeddings| async move { - acc.extend(embeddings); - Ok(acc) - }) - .await?; - - Ok(embeddings) - } -} +} \ No newline at end of file diff --git a/rig-core/src/embeddings/mod.rs b/rig-core/src/embeddings/mod.rs index d590b7d0..061066c1 100644 --- a/rig-core/src/embeddings/mod.rs +++ b/rig-core/src/embeddings/mod.rs @@ -5,5 +5,4 @@ pub mod builder; pub mod embeddable; -pub mod embedding; -pub mod tool; +pub mod embedding; \ No newline at end of file diff --git a/rig-core/src/embeddings/tool.rs b/rig-core/src/embeddings/tool.rs deleted file mode 100644 index 5e8ecd35..00000000 --- a/rig-core/src/embeddings/tool.rs +++ /dev/null @@ -1,25 +0,0 @@ -use crate::{self as rig, tool::ToolEmbeddingDyn}; -use rig::embeddings::embeddable::Embeddable; -use rig_derive::Embeddable; -use serde::Serialize; - -use super::embeddable::EmbeddableError; - -/// Used by EmbeddingsBuilder to embed anything that implements ToolEmbedding. -#[derive(Embeddable, Clone, Serialize, Default, Eq, PartialEq)] -pub struct EmbeddableTool { - name: String, - #[embed] - context: serde_json::Value, -} - -impl EmbeddableTool { - /// Convert item that implements ToolEmbedding to an EmbeddableTool. 
- pub fn try_from(tool: &dyn ToolEmbeddingDyn) -> Result { - Ok(EmbeddableTool { - name: tool.name(), - context: serde_json::to_value(tool.context().map_err(EmbeddableError::SerdeError)?) - .map_err(EmbeddableError::SerdeError)?, - }) - } -} diff --git a/rig-core/src/providers/cohere.rs b/rig-core/src/providers/cohere.rs index fa798205..eba4ce3d 100644 --- a/rig-core/src/providers/cohere.rs +++ b/rig-core/src/providers/cohere.rs @@ -13,9 +13,7 @@ use std::collections::HashMap; use crate::{ agent::AgentBuilder, completion::{self, CompletionError}, - embeddings::{ - self, builder::EmbeddingsBuilder, embeddable::Embeddable, embedding::EmbeddingError, - }, + embeddings::{self, embedding::EmbeddingError, builder::EmbeddingsBuilder}, extractor::ExtractorBuilder, json_utils, }; @@ -87,11 +85,7 @@ impl Client { EmbeddingModel::new(self.clone(), model, input_type, ndims) } - pub fn embeddings( - &self, - model: &str, - input_type: &str, - ) -> EmbeddingsBuilder { + pub fn embeddings(&self, model: &str, input_type: &str) -> EmbeddingsBuilder { EmbeddingsBuilder::new(self.embedding_model(model, input_type)) } diff --git a/rig-core/src/providers/openai.rs b/rig-core/src/providers/openai.rs index d21bf8fb..eb373414 100644 --- a/rig-core/src/providers/openai.rs +++ b/rig-core/src/providers/openai.rs @@ -11,12 +11,7 @@ use crate::{ agent::AgentBuilder, completion::{self, CompletionError, CompletionRequest}, - embeddings::{ - self, - builder::EmbeddingsBuilder, - embeddable::Embeddable, - embedding::{Embedding, EmbeddingError}, - }, + embeddings::{self, embedding::EmbeddingError}, extractor::ExtractorBuilder, json_utils, }; @@ -126,11 +121,8 @@ impl Client { /// .await /// .expect("Failed to embed documents"); /// ``` - pub fn embeddings( - &self, - model: &str, - ) -> EmbeddingsBuilder { - EmbeddingsBuilder::new(self.embedding_model(model)) + pub fn embeddings(&self, model: &str) -> embeddings::builder::EmbeddingsBuilder { + 
embeddings::builder::EmbeddingsBuilder::new(self.embedding_model(model)) } /// Create a completion model with the given name. @@ -250,7 +242,7 @@ impl embeddings::embedding::EmbeddingModel for EmbeddingModel { async fn embed_documents( &self, documents: Vec, - ) -> Result, EmbeddingError> { + ) -> Result, EmbeddingError> { let response = self .client .post("/v1/embeddings") @@ -276,7 +268,7 @@ impl embeddings::embedding::EmbeddingModel for EmbeddingModel { .data .into_iter() .zip(documents.into_iter()) - .map(|(embedding, document)| Embedding { + .map(|(embedding, document)| embeddings::embedding::Embedding { document, vec: embedding.embedding, }) diff --git a/rig-core/src/tool.rs b/rig-core/src/tool.rs index e92896b8..4e5e7675 100644 --- a/rig-core/src/tool.rs +++ b/rig-core/src/tool.rs @@ -5,7 +5,7 @@ use serde::{Deserialize, Serialize}; use crate::{ completion::{self, ToolDefinition}, - embeddings::{embeddable::EmbeddableError, tool::EmbeddableTool}, + embeddings::{embeddable::EmbeddableError}, }; #[derive(Debug, thiserror::Error)] @@ -326,22 +326,6 @@ impl ToolSet { } Ok(docs) } - - /// Convert tools in self to objects of type EmbeddableTool. - /// This is necessary because when adding tools to the EmbeddingBuilder because all - /// documents added to the builder must all be of the same type. 
- pub fn embedabble_tools(&self) -> Result, EmbeddableError> { - self.tools - .values() - .filter_map(|tool_type| { - if let ToolType::Embedding(tool) = tool_type { - Some(EmbeddableTool::try_from(&**tool)) - } else { - None - } - }) - .collect::, _>>() - } } #[derive(Default)] diff --git a/rig-lancedb/Cargo.toml b/rig-lancedb/Cargo.toml index 6ee41d8d..031df2d3 100644 --- a/rig-lancedb/Cargo.toml +++ b/rig-lancedb/Cargo.toml @@ -6,7 +6,6 @@ edition = "2021" [dependencies] lancedb = "0.10.0" rig-core = { path = "../rig-core", version = "0.2.1" } -rig-derive = { path = "../rig-core/rig-core-derive" } arrow-array = "52.2.0" serde_json = "1.0.128" serde = "1.0.210" diff --git a/rig-lancedb/examples/fixtures/lib.rs b/rig-lancedb/examples/fixtures/lib.rs index 415422f0..d95a42e4 100644 --- a/rig-lancedb/examples/fixtures/lib.rs +++ b/rig-lancedb/examples/fixtures/lib.rs @@ -2,39 +2,13 @@ use std::sync::Arc; use arrow_array::{types::Float64Type, ArrayRef, FixedSizeListArray, RecordBatch, StringArray}; use lancedb::arrow::arrow_schema::{DataType, Field, Fields, Schema}; -use rig::embeddings::embedding::Embedding; -use rig::Embeddable; -use serde::Deserialize; - -#[derive(Embeddable, Clone, Deserialize, Debug)] -pub struct FakeDefinition { - pub id: String, - #[embed] - pub definition: String, -} - -pub fn fake_definitions() -> Vec { - vec![ - FakeDefinition { - id: "doc0".to_string(), - definition: "Definition of *flumbrel (noun)*: a small, seemingly insignificant item that you constantly lose or misplace, such as a pen, hair tie, or remote control.".to_string() - }, - FakeDefinition { - id: "doc1".to_string(), - definition: "Definition of *zindle (verb)*: to pretend to be working on something important while actually doing something completely unrelated or unproductive.".to_string() - }, - FakeDefinition { - id: "doc2".to_string(), - definition: "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.".to_string() - } - 
] -} +use rig::embeddings::builder::DocumentEmbeddings; // Schema of table in LanceDB. pub fn schema(dims: usize) -> Schema { Schema::new(Fields::from(vec![ Field::new("id", DataType::Utf8, false), - Field::new("definition", DataType::Utf8, false), + Field::new("content", DataType::Utf8, false), Field::new( "embedding", DataType::FixedSizeList( @@ -46,36 +20,48 @@ pub fn schema(dims: usize) -> Schema { ])) } -// Convert FakeDefinition objects and their embedding to a RecordBatch. +// Convert DocumentEmbeddings objects to a RecordBatch. pub fn as_record_batch( - records: Vec<(FakeDefinition, Embedding)>, + records: Vec, dims: usize, ) -> Result { let id = StringArray::from_iter_values( records .iter() - .map(|(FakeDefinition { id, .. }, _)| id) + .flat_map(|record| (0..record.embeddings.len()).map(|i| format!("{}-{i}", record.id))) .collect::>(), ); - let definition = StringArray::from_iter_values( + let content = StringArray::from_iter_values( records .iter() - .map(|(FakeDefinition { definition, .. }, _)| definition) + .flat_map(|record| { + record + .embeddings + .iter() + .map(|embedding| embedding.document.clone()) + }) .collect::>(), ); let embedding = FixedSizeListArray::from_iter_primitive::( records .into_iter() - .map(|(_, Embedding { vec, .. 
})| Some(vec.into_iter().map(Some).collect::>())) + .flat_map(|record| { + record + .embeddings + .into_iter() + .map(|embedding| embedding.vec.into_iter().map(Some).collect::>()) + .map(Some) + .collect::>() + }) .collect::>(), dims as i32, ); RecordBatch::try_from_iter(vec![ ("id", Arc::new(id) as ArrayRef), - ("definition", Arc::new(definition) as ArrayRef), + ("content", Arc::new(content) as ArrayRef), ("embedding", Arc::new(embedding) as ArrayRef), ]) } diff --git a/rig-lancedb/examples/vector_search_local_ann.rs b/rig-lancedb/examples/vector_search_local_ann.rs index 1b7870fb..a9744210 100644 --- a/rig-lancedb/examples/vector_search_local_ann.rs +++ b/rig-lancedb/examples/vector_search_local_ann.rs @@ -1,18 +1,25 @@ use std::{env, sync::Arc}; use arrow_array::RecordBatchIterator; -use fixture::{as_record_batch, fake_definitions, schema, FakeDefinition}; +use fixture::{as_record_batch, schema}; use lancedb::index::vector::IvfPqIndexBuilder; use rig::vector_store::VectorStoreIndex; use rig::{ - embeddings::{builder::EmbeddingsBuilder, embedding::EmbeddingModel}, + embeddings::{embedding::EmbeddingModel, builder::EmbeddingsBuilder}, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, }; use rig_lancedb::{LanceDbVectorStore, SearchParams}; +use serde::Deserialize; #[path = "./fixtures/lib.rs"] mod fixture; +#[derive(Deserialize, Debug)] +pub struct VectorSearchResult { + pub id: String, + pub content: String, +} + #[tokio::main] async fn main() -> Result<(), anyhow::Error> { // Initialize OpenAI client. Use this to generate embeddings (and generate test data for RAG demo). @@ -25,18 +32,18 @@ async fn main() -> Result<(), anyhow::Error> { // Initialize LanceDB locally. 
let db = lancedb::connect("data/lancedb-store").execute().await?; + // Set up test data for RAG demo + let definition = "Definition of *flumbuzzle (verb)*: to bewilder or confuse someone completely, often by using nonsensical or overly complex explanations or instructions.".to_string(); + + // Note: need at least 256 rows in order to create an index so copy the definition 256 times for testing purposes. + let definitions = vec![definition; 256]; + // Generate embeddings for the test data. let embeddings = EmbeddingsBuilder::new(model.clone()) - .documents(fake_definitions())? - // Note: need at least 256 rows in order to create an index so copy the definition 256 times for testing purposes. - .documents( - (0..256) - .map(|i| FakeDefinition { - id: format!("doc{}", i), - definition: "Definition of *flumbuzzle (noun)*: A sudden, inexplicable urge to rearrange or reorganize small objects, such as desk items or books, for no apparent reason.".to_string() - }) - .collect(), - )? + .simple_document("doc0", "Definition of *flumbrel (noun)*: a small, seemingly insignificant item that you constantly lose or misplace, such as a pen, hair tie, or remote control.") + .simple_document("doc1", "Definition of *zindle (verb)*: to pretend to be working on something important while actually doing something completely unrelated or unproductive") + .simple_document("doc2", "Definition of *glimber (adjective)*: describing a state of excitement mixed with nervousness, often experienced before an important event or decision.") + .simple_documents(definitions.clone().into_iter().enumerate().map(|(i, def)| (format!("doc{}", i+3), def)).collect()) .build() .await?; @@ -65,7 +72,7 @@ async fn main() -> Result<(), anyhow::Error> { // Query the index let results = vector_store - .top_n::("My boss says I zindle too much, what does that mean?", 1) + .top_n::("My boss says I zindle too much, what does that mean?", 1) .await?; println!("Results: {:?}", results); diff --git 
a/rig-lancedb/examples/vector_search_local_enn.rs b/rig-lancedb/examples/vector_search_local_enn.rs index 630acc1a..69a8599b 100644 --- a/rig-lancedb/examples/vector_search_local_enn.rs +++ b/rig-lancedb/examples/vector_search_local_enn.rs @@ -1,9 +1,9 @@ use std::{env, sync::Arc}; use arrow_array::RecordBatchIterator; -use fixture::{as_record_batch, fake_definitions, schema}; +use fixture::{as_record_batch, schema}; use rig::{ - embeddings::{builder::EmbeddingsBuilder, embedding::EmbeddingModel}, + embeddings::{embedding::EmbeddingModel, builder::EmbeddingsBuilder}, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, vector_store::VectorStoreIndexDyn, }; @@ -21,9 +21,10 @@ async fn main() -> Result<(), anyhow::Error> { // Select the embedding model and generate our embeddings let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); - // Generate embeddings for the test data. let embeddings = EmbeddingsBuilder::new(model.clone()) - .documents(fake_definitions())? + .simple_document("doc0", "Definition of *flumbrel (noun)*: a small, seemingly insignificant item that you constantly lose or misplace, such as a pen, hair tie, or remote control.") + .simple_document("doc1", "Definition of *zindle (verb)*: to pretend to be working on something important while actually doing something completely unrelated or unproductive") + .simple_document("doc2", "Definition of *glimber (adjective)*: describing a state of excitement mixed with nervousness, often experienced before an important event or decision.") .build() .await?; diff --git a/rig-lancedb/examples/vector_search_s3_ann.rs b/rig-lancedb/examples/vector_search_s3_ann.rs index 8c10409b..9197fc42 100644 --- a/rig-lancedb/examples/vector_search_s3_ann.rs +++ b/rig-lancedb/examples/vector_search_s3_ann.rs @@ -1,18 +1,25 @@ use std::{env, sync::Arc}; use arrow_array::RecordBatchIterator; -use fixture::{as_record_batch, fake_definitions, schema, FakeDefinition}; +use fixture::{as_record_batch, schema}; use 
lancedb::{index::vector::IvfPqIndexBuilder, DistanceType}; use rig::{ - embeddings::{builder::EmbeddingsBuilder, embedding::EmbeddingModel}, + embeddings::{embedding::EmbeddingModel, builder::EmbeddingsBuilder}, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, vector_store::VectorStoreIndex, }; use rig_lancedb::{LanceDbVectorStore, SearchParams}; +use serde::Deserialize; #[path = "./fixtures/lib.rs"] mod fixture; +#[derive(Deserialize, Debug)] +pub struct VectorSearchResult { + pub id: String, + pub content: String, +} + // Note: see docs to deploy LanceDB on other cloud providers such as google and azure. // https://lancedb.github.io/lancedb/guides/storage/ #[tokio::main] @@ -31,18 +38,18 @@ async fn main() -> Result<(), anyhow::Error> { .execute() .await?; + // Set up test data for RAG demo + let definition = "Definition of *flumbuzzle (verb)*: to bewilder or confuse someone completely, often by using nonsensical or overly complex explanations or instructions.".to_string(); + + // Note: need at least 256 rows in order to create an index so copy the definition 256 times for testing purposes. + let definitions = vec![definition; 256]; + // Generate embeddings for the test data. let embeddings = EmbeddingsBuilder::new(model.clone()) - .documents(fake_definitions())? - // Note: need at least 256 rows in order to create an index so copy the definition 256 times for testing purposes. - .documents( - (0..256) - .map(|i| FakeDefinition { - id: format!("doc{}", i), - definition: "Definition of *flumbuzzle (noun)*: A sudden, inexplicable urge to rearrange or reorganize small objects, such as desk items or books, for no apparent reason.".to_string() - }) - .collect(), - )? 
+ .simple_document("doc0", "Definition of *flumbrel (noun)*: a small, seemingly insignificant item that you constantly lose or misplace, such as a pen, hair tie, or remote control.") + .simple_document("doc1", "Definition of *zindle (verb)*: to pretend to be working on something important while actually doing something completely unrelated or unproductive") + .simple_document("doc2", "Definition of *glimber (adjective)*: describing a state of excitement mixed with nervousness, often experienced before an important event or decision.") + .simple_documents(definitions.clone().into_iter().enumerate().map(|(i, def)| (format!("doc{}", i+3), def)).collect()) .build() .await?; @@ -77,7 +84,7 @@ async fn main() -> Result<(), anyhow::Error> { // Query the index let results = vector_store - .top_n::("I'm always looking for my phone, I always seem to forget it in the most counterintuitive places. What's the word for this feeling?", 1) + .top_n::("I'm always looking for my phone, I always seem to forget it in the most counterintuitive places. 
What's the word for this feeling?", 1) .await?; println!("Results: {:?}", results); diff --git a/rig-mongodb/Cargo.toml b/rig-mongodb/Cargo.toml index 78b48892..6f313838 100644 --- a/rig-mongodb/Cargo.toml +++ b/rig-mongodb/Cargo.toml @@ -13,8 +13,6 @@ repository = "https://github.com/0xPlaygrounds/rig" futures = "0.3.30" mongodb = "2.8.2" rig-core = { path = "../rig-core", version = "0.2.1" } -rig-derive = { path = "../rig-core/rig-core-derive" } - serde = { version = "1.0.203", features = ["derive"] } serde_json = "1.0.117" tracing = "0.1.40" diff --git a/rig-mongodb/examples/vector_search_mongodb.rs b/rig-mongodb/examples/vector_search_mongodb.rs index a816c9c0..cca7a6d4 100644 --- a/rig-mongodb/examples/vector_search_mongodb.rs +++ b/rig-mongodb/examples/vector_search_mongodb.rs @@ -1,39 +1,13 @@ -use mongodb::{bson::doc, options::ClientOptions, Client as MongoClient, Collection}; -use rig::providers::openai::TEXT_EMBEDDING_ADA_002; -use serde::{Deserialize, Serialize}; +use mongodb::{options::ClientOptions, Client as MongoClient, Collection}; use std::env; -use rig::Embeddable; use rig::{ - embeddings::builder::EmbeddingsBuilder, providers::openai::Client, + embeddings::{builder::DocumentEmbeddings, builder::EmbeddingsBuilder}, + providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, vector_store::VectorStoreIndex, }; use rig_mongodb::{MongoDbVectorStore, SearchParams}; -// Shape of data that needs to be RAG'ed. -// The definition field will be used to generate embeddings. -#[derive(Embeddable, Clone, Deserialize, Debug)] -struct FakeDefinition { - id: String, - #[embed] - definition: String, -} - -#[derive(Clone, Deserialize, Debug, Serialize)] -struct Link { - word: String, - link: String, -} - -// Shape of the document to be stored in MongoDB, with embeddings. 
-#[derive(Serialize, Debug)] -struct Document { - #[serde(rename = "_id")] - id: String, - definition: String, - embedding: Vec, -} - #[tokio::main] async fn main() -> Result<(), anyhow::Error> { // Initialize OpenAI client @@ -51,61 +25,37 @@ async fn main() -> Result<(), anyhow::Error> { MongoClient::with_options(options).expect("MongoDB client options should be valid"); // Initialize MongoDB vector store - let collection: Collection = mongodb_client + let collection: Collection = mongodb_client .database("knowledgebase") .collection("context"); // Select the embedding model and generate our embeddings let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); - let fake_definitions = vec![ - FakeDefinition { - id: "doc0".to_string(), - definition: "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets".to_string(), - }, - FakeDefinition { - id: "doc1".to_string(), - definition: "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string(), - }, - FakeDefinition { - id: "doc2".to_string(), - definition: "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.".to_string(), - } - ]; - let embeddings = EmbeddingsBuilder::new(model.clone()) - .documents(fake_definitions)? + .simple_document("doc0", "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets") + .simple_document("doc1", "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.") + .simple_document("doc2", "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.") .build() .await?; - let mongo_documents = embeddings - .iter() - .map( - |(FakeDefinition { id, definition, .. 
}, embedding)| Document { - id: id.clone(), - definition: definition.clone(), - embedding: embedding.vec.clone(), - }, - ) - .collect::>(); - - match collection.insert_many(mongo_documents, None).await { + match collection.insert_many(embeddings, None).await { Ok(_) => println!("Documents added successfully"), Err(e) => println!("Error adding documents: {:?}", e), - }; + } // Create a vector index on our vector store // IMPORTANT: Reuse the same model that was used to generate the embeddings - let index = MongoDbVectorStore::new(collection).index( - model, - "definitions_vector_index", - SearchParams::new("embedding"), - ); + let index = + MongoDbVectorStore::new(collection).index(model, "vector_index", SearchParams::default()); // Query the index let results = index - .top_n::("What is a linglingdong?", 1) - .await?; + .top_n::("What is a linglingdong?", 1) + .await? + .into_iter() + .map(|(score, id, doc)| (score, id, doc.document)) + .collect::>(); println!("Results: {:?}", results); diff --git a/rig-mongodb/src/lib.rs b/rig-mongodb/src/lib.rs index 4778e454..17dda463 100644 --- a/rig-mongodb/src/lib.rs +++ b/rig-mongodb/src/lib.rs @@ -2,23 +2,23 @@ use futures::StreamExt; use mongodb::bson::{self, doc}; use rig::{ - embeddings::embedding::{Embedding, EmbeddingModel}, + embeddings::{builder::DocumentEmbeddings, embedding::Embedding, embedding::EmbeddingModel}, vector_store::{VectorStoreError, VectorStoreIndex}, }; use serde::Deserialize; /// A MongoDB vector store. -pub struct MongoDbVectorStore { - collection: mongodb::Collection, +pub struct MongoDbVectorStore { + collection: mongodb::Collection, } fn mongodb_to_rig_error(e: mongodb::error::Error) -> VectorStoreError { VectorStoreError::DatastoreError(Box::new(e)) } -impl MongoDbVectorStore { +impl MongoDbVectorStore { /// Create a new `MongoDbVectorStore` from a MongoDB collection. 
- pub fn new(collection: mongodb::Collection) -> Self { + pub fn new(collection: mongodb::Collection) -> Self { Self { collection } } @@ -31,20 +31,20 @@ impl MongoDbVectorStore { model: M, index_name: &str, search_params: SearchParams, - ) -> MongoDbVectorIndex { + ) -> MongoDbVectorIndex { MongoDbVectorIndex::new(self.collection.clone(), model, index_name, search_params) } } /// A vector index for a MongoDB collection. -pub struct MongoDbVectorIndex { - collection: mongodb::Collection, +pub struct MongoDbVectorIndex { + collection: mongodb::Collection, model: M, index_name: String, search_params: SearchParams, } -impl MongoDbVectorIndex { +impl MongoDbVectorIndex { /// Vector search stage of aggregation pipeline of mongoDB collection. /// To be used by implementations of top_n and top_n_ids methods on VectorStoreIndex trait for MongoDbVectorIndex. fn pipeline_search_stage(&self, prompt_embedding: &Embedding, n: usize) -> bson::Document { @@ -52,13 +52,12 @@ impl MongoDbVectorIndex { filter, exact, num_candidates, - path, } = &self.search_params; doc! { "$vectorSearch": { "index": &self.index_name, - "path": path, + "path": "embeddings.vec", "queryVector": &prompt_embedding.vec, "numCandidates": num_candidates.unwrap_or((n * 10) as u32), "limit": n as u32, @@ -79,9 +78,9 @@ impl MongoDbVectorIndex { } } -impl MongoDbVectorIndex { +impl MongoDbVectorIndex { pub fn new( - collection: mongodb::Collection, + collection: mongodb::Collection, model: M, index_name: &str, search_params: SearchParams, @@ -99,19 +98,17 @@ impl MongoDbVectorIndex { /// on each of the fields pub struct SearchParams { filter: mongodb::bson::Document, - path: String, exact: Option, num_candidates: Option, } impl SearchParams { /// Initializes a new `SearchParams` with default values. - pub fn new(path: &str) -> Self { + pub fn new() -> Self { Self { filter: doc! 
{}, exact: None, num_candidates: None, - path: path.to_string(), } } @@ -141,9 +138,13 @@ impl SearchParams { } } -impl VectorStoreIndex - for MongoDbVectorIndex -{ +impl Default for SearchParams { + fn default() -> Self { + Self::new() + } +} + +impl VectorStoreIndex for MongoDbVectorIndex { async fn top_n Deserialize<'a> + std::marker::Send>( &self, query: &str, From b5e1bf3a505d0ecde421ef9bef72bc499e800a35 Mon Sep 17 00:00:00 2001 From: Garance Date: Wed, 16 Oct 2024 14:32:21 -0400 Subject: [PATCH 25/47] cleanup: revert certain changes during branch split --- rig-core/src/embeddings/builder.rs | 4 ++-- rig-core/src/embeddings/mod.rs | 2 +- rig-core/src/providers/cohere.rs | 2 +- rig-core/src/providers/openai.rs | 5 ++++- rig-core/src/tool.rs | 5 +---- rig-lancedb/examples/vector_search_local_ann.rs | 2 +- rig-lancedb/examples/vector_search_local_enn.rs | 2 +- rig-lancedb/examples/vector_search_s3_ann.rs | 2 +- 8 files changed, 12 insertions(+), 12 deletions(-) diff --git a/rig-core/src/embeddings/builder.rs b/rig-core/src/embeddings/builder.rs index 50e56537..afbe2056 100644 --- a/rig-core/src/embeddings/builder.rs +++ b/rig-core/src/embeddings/builder.rs @@ -31,7 +31,7 @@ use serde::{Deserialize, Serialize}; use crate::tool::{ToolEmbedding, ToolSet, ToolType}; -use super::embedding::{ Embedding, EmbeddingError, EmbeddingModel}; +use super::embedding::{Embedding, EmbeddingError, EmbeddingModel}; /// Struct that holds a document and its embeddings. 
/// @@ -232,4 +232,4 @@ impl EmbeddingsBuilder { Ok(document_embeddings.into_values().collect()) } -} \ No newline at end of file +} diff --git a/rig-core/src/embeddings/mod.rs b/rig-core/src/embeddings/mod.rs index 061066c1..37e720cb 100644 --- a/rig-core/src/embeddings/mod.rs +++ b/rig-core/src/embeddings/mod.rs @@ -5,4 +5,4 @@ pub mod builder; pub mod embeddable; -pub mod embedding; \ No newline at end of file +pub mod embedding; diff --git a/rig-core/src/providers/cohere.rs b/rig-core/src/providers/cohere.rs index eba4ce3d..87f2334d 100644 --- a/rig-core/src/providers/cohere.rs +++ b/rig-core/src/providers/cohere.rs @@ -13,7 +13,7 @@ use std::collections::HashMap; use crate::{ agent::AgentBuilder, completion::{self, CompletionError}, - embeddings::{self, embedding::EmbeddingError, builder::EmbeddingsBuilder}, + embeddings::{self, builder::EmbeddingsBuilder, embedding::EmbeddingError}, extractor::ExtractorBuilder, json_utils, }; diff --git a/rig-core/src/providers/openai.rs b/rig-core/src/providers/openai.rs index eb373414..9adf6680 100644 --- a/rig-core/src/providers/openai.rs +++ b/rig-core/src/providers/openai.rs @@ -121,7 +121,10 @@ impl Client { /// .await /// .expect("Failed to embed documents"); /// ``` - pub fn embeddings(&self, model: &str) -> embeddings::builder::EmbeddingsBuilder { + pub fn embeddings( + &self, + model: &str, + ) -> embeddings::builder::EmbeddingsBuilder { embeddings::builder::EmbeddingsBuilder::new(self.embedding_model(model)) } diff --git a/rig-core/src/tool.rs b/rig-core/src/tool.rs index 4e5e7675..98394ecf 100644 --- a/rig-core/src/tool.rs +++ b/rig-core/src/tool.rs @@ -3,10 +3,7 @@ use std::{collections::HashMap, pin::Pin}; use futures::Future; use serde::{Deserialize, Serialize}; -use crate::{ - completion::{self, ToolDefinition}, - embeddings::{embeddable::EmbeddableError}, -}; +use crate::completion::{self, ToolDefinition}; #[derive(Debug, thiserror::Error)] pub enum ToolError { diff --git 
a/rig-lancedb/examples/vector_search_local_ann.rs b/rig-lancedb/examples/vector_search_local_ann.rs index a9744210..52ee213a 100644 --- a/rig-lancedb/examples/vector_search_local_ann.rs +++ b/rig-lancedb/examples/vector_search_local_ann.rs @@ -5,7 +5,7 @@ use fixture::{as_record_batch, schema}; use lancedb::index::vector::IvfPqIndexBuilder; use rig::vector_store::VectorStoreIndex; use rig::{ - embeddings::{embedding::EmbeddingModel, builder::EmbeddingsBuilder}, + embeddings::{builder::EmbeddingsBuilder, embedding::EmbeddingModel}, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, }; use rig_lancedb::{LanceDbVectorStore, SearchParams}; diff --git a/rig-lancedb/examples/vector_search_local_enn.rs b/rig-lancedb/examples/vector_search_local_enn.rs index 69a8599b..9f0ec934 100644 --- a/rig-lancedb/examples/vector_search_local_enn.rs +++ b/rig-lancedb/examples/vector_search_local_enn.rs @@ -3,7 +3,7 @@ use std::{env, sync::Arc}; use arrow_array::RecordBatchIterator; use fixture::{as_record_batch, schema}; use rig::{ - embeddings::{embedding::EmbeddingModel, builder::EmbeddingsBuilder}, + embeddings::{builder::EmbeddingsBuilder, embedding::EmbeddingModel}, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, vector_store::VectorStoreIndexDyn, }; diff --git a/rig-lancedb/examples/vector_search_s3_ann.rs b/rig-lancedb/examples/vector_search_s3_ann.rs index 9197fc42..7ce61309 100644 --- a/rig-lancedb/examples/vector_search_s3_ann.rs +++ b/rig-lancedb/examples/vector_search_s3_ann.rs @@ -4,7 +4,7 @@ use arrow_array::RecordBatchIterator; use fixture::{as_record_batch, schema}; use lancedb::{index::vector::IvfPqIndexBuilder, DistanceType}; use rig::{ - embeddings::{embedding::EmbeddingModel, builder::EmbeddingsBuilder}, + embeddings::{builder::EmbeddingsBuilder, embedding::EmbeddingModel}, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, vector_store::VectorStoreIndex, }; From 7caf134db193a9890c612f6bf61d08c6b6512cf5 Mon Sep 17 00:00:00 2001 From: Garance Date: Wed, 16 
Oct 2024 14:37:44 -0400 Subject: [PATCH 26/47] docs: revert doc string --- rig-core/src/embeddings/builder.rs | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/rig-core/src/embeddings/builder.rs b/rig-core/src/embeddings/builder.rs index afbe2056..2458b98d 100644 --- a/rig-core/src/embeddings/builder.rs +++ b/rig-core/src/embeddings/builder.rs @@ -1,5 +1,6 @@ -//! The module defines the [EmbeddingsBuilder] struct which accumulates objects to be embedded and generates the embeddings for each object when built. -//! Only types that implement the [Embeddable] trait can be added to the [EmbeddingsBuilder]. +//! The module provides an implementation of the [EmbeddingsBuilder] +//! struct, which allows users to build collections of document embeddings using different embedding +//! models and document sources. //! //! # Example //! ```rust From fb979eccbe751fa82b94fa5b486f0677010bc1e3 Mon Sep 17 00:00:00 2001 From: Garance Date: Wed, 16 Oct 2024 17:59:08 -0400 Subject: [PATCH 27/47] refactor: use OneOrMany in Embbedable trait, make derive macro crate feature flag --- rig-core/Cargo.toml | 11 +- rig-core/rig-core-derive/src/embeddable.rs | 65 +---- rig-core/src/embeddings/embeddable.rs | 281 +++++++-------------- rig-core/src/lib.rs | 2 + rig-core/tests/embeddable_macro.rs | 141 +++++++++++ 5 files changed, 261 insertions(+), 239 deletions(-) create mode 100644 rig-core/tests/embeddable_macro.rs diff --git a/rig-core/Cargo.toml b/rig-core/Cargo.toml index 658faf4b..d2ab06c6 100644 --- a/rig-core/Cargo.toml +++ b/rig-core/Cargo.toml @@ -23,9 +23,16 @@ futures = "0.3.29" ordered-float = "4.2.0" schemars = "0.8.16" thiserror = "1.0.61" -rig-derive = { path = "./rig-core-derive" } +rig-derive = { path = "./rig-core-derive", optional = true } [dev-dependencies] anyhow = "1.0.75" tokio = { version = "1.34.0", features = ["full"] } -tracing-subscriber = "0.3.18" \ No newline at end of file +tracing-subscriber = "0.3.18" + +[features] +rig_derive = 
["dep:rig-derive"] + +[[test]] +name = "embeddable_macro" +required-features = ["rig_derive"] \ No newline at end of file diff --git a/rig-core/rig-core-derive/src/embeddable.rs b/rig-core/rig-core-derive/src/embeddable.rs index 7c3c883e..e36a7eea 100644 --- a/rig-core/rig-core-derive/src/embeddable.rs +++ b/rig-core/rig-core-derive/src/embeddable.rs @@ -1,21 +1,18 @@ use proc_macro2::TokenStream; use quote::quote; -use syn::{parse_str, DataStruct}; +use syn::DataStruct; use crate::{ basic::{add_struct_bounds, basic_embed_fields}, custom::custom_embed_fields, }; -const VEC_TYPE: &str = "Vec"; -const MANY_EMBEDDING: &str = "ManyEmbedding"; -const SINGLE_EMBEDDING: &str = "SingleEmbedding"; pub(crate) fn expand_derive_embedding(input: &mut syn::DeriveInput) -> syn::Result { let name = &input.ident; let data = &input.data; let generics = &mut input.generics; - let (target_stream, embed_kind) = match data { + let target_stream = match data { syn::Data::Struct(data_struct) => { let (basic_targets, basic_target_size) = data_struct.basic(generics); let (custom_targets, custom_target_size) = data_struct.custom()?; @@ -30,13 +27,10 @@ pub(crate) fn expand_derive_embedding(input: &mut syn::DeriveInput) -> syn::Resu } // Determine whether the Embeddable::Kind should be SingleEmbedding or ManyEmbedding - ( - quote! { - let mut embed_targets = #basic_targets; - embed_targets.extend(#custom_targets) - }, - embed_kind(data_struct)?, - ) + quote! { + let mut embed_targets = #basic_targets; + embed_targets.extend(#custom_targets) + } } _ => { return Err(syn::Error::new_spanned( @@ -52,19 +46,16 @@ pub(crate) fn expand_derive_embedding(input: &mut syn::DeriveInput) -> syn::Resu // Note: Embeddable trait is imported with the macro. 
impl #impl_generics Embeddable for #name #ty_generics #where_clause { - type Kind = rig::embeddings::embeddable::#embed_kind; type Error = rig::embeddings::embeddable::EmbeddableError; - fn embeddable(&self) -> Result, Self::Error> { + fn embeddable(&self) -> Result, Self::Error> { #target_stream; - let targets = embed_targets.into_iter() - .collect::, _>>()? - .into_iter() - .flatten() - .collect::>(); - - Ok(targets) + rig::embeddings::embeddable::OneOrMany::try_from( + embed_targets.into_iter() + .collect::, _>>() + .map_err(|e| rig::embeddings::embeddable::EmbeddableError::Error(e.to_string()))? + ) } } }; @@ -72,38 +63,6 @@ pub(crate) fn expand_derive_embedding(input: &mut syn::DeriveInput) -> syn::Resu Ok(gen) } -/// If the total number of fields tagged with #[embed] or #[embed(embed_with = "...")] is 1, -/// returns the kind of embedding that field should be. -/// If the total number of fields tagged with #[embed] or #[embed(embed_with = "...")] is greater than 1, -/// return ManyEmbedding. -fn embed_kind(data_struct: &DataStruct) -> syn::Result { - fn embed_kind(field: &syn::Field) -> syn::Result { - match &field.ty { - syn::Type::Path(path) => { - if path.path.segments.first().unwrap().ident == VEC_TYPE { - parse_str(MANY_EMBEDDING) - } else { - parse_str(SINGLE_EMBEDDING) - } - } - _ => parse_str(SINGLE_EMBEDDING), - } - } - let fields = basic_embed_fields(data_struct) - .chain( - custom_embed_fields(data_struct)? 
- .into_iter() - .map(|(f, _)| f), - ) - .collect::>(); - - if fields.len() == 1 { - fields.iter().map(embed_kind).next().unwrap() - } else { - parse_str(MANY_EMBEDDING) - } -} - trait StructParser { // Handles fields tagged with #[embed] fn basic(&self, generics: &mut syn::Generics) -> (TokenStream, usize); diff --git a/rig-core/src/embeddings/embeddable.rs b/rig-core/src/embeddings/embeddable.rs index c7b7bf0c..c4cf90da 100644 --- a/rig-core/src/embeddings/embeddable.rs +++ b/rig-core/src/embeddings/embeddable.rs @@ -35,6 +35,8 @@ impl EmbeddingKind for ManyEmbedding {} pub enum EmbeddableError { #[error("SerdeError: {0}")] SerdeError(#[from] serde_json::Error), + #[error("Error: {0}")] + Error(String), } /// Trait for types that can be embedded. @@ -42,266 +44,177 @@ pub enum EmbeddableError { /// If the type `Kind` is `SingleEmbedding`, the list of strings contains a single item, otherwise, the list can contain many items. /// If there is an error generating the list of strings, the method should return an error that implements `std::error::Error`. 
pub trait Embeddable { - type Kind: EmbeddingKind; type Error: std::error::Error; - fn embeddable(&self) -> Result, Self::Error>; + fn embeddable(&self) -> Result, Self::Error>; +} + +#[derive(PartialEq, Eq, Debug)] +pub struct OneOrMany { + first: T, + rest: Vec, +} + +impl OneOrMany { + pub fn first(&self) -> T { + self.first.clone() + } + + pub fn rest(&self) -> Vec { + self.rest.clone() + } + + pub fn all(&self) -> Vec { + let mut all = vec![self.first.clone()]; + all.extend(self.rest.clone().into_iter()); + all + } +} + +impl From for OneOrMany { + fn from(item: T) -> Self { + OneOrMany { + first: item, + rest: vec![], + } + } +} + +impl TryFrom> for OneOrMany { + type Error = EmbeddableError; + + fn try_from(items: Vec) -> Result { + let mut iter = items.into_iter(); + Ok(OneOrMany { + first: match iter.next() { + Some(item) => item, + None => { + return Err(EmbeddableError::Error(format!( + "Cannot convert empty Vec to OneOrMany" + ))) + } + }, + rest: iter.collect(), + }) + } +} + +impl TryFrom>> for OneOrMany { + type Error = EmbeddableError; + + fn try_from(value: Vec>) -> Result { + let items = value + .into_iter() + .flat_map(|one_or_many| one_or_many.all()) + .collect::>(); + + OneOrMany::try_from(items) + } } ////////////////////////////////////////////////////// /// Implementations of Embeddable for common types /// ////////////////////////////////////////////////////// impl Embeddable for String { - type Kind = SingleEmbedding; type Error = EmbeddableError; - fn embeddable(&self) -> Result, Self::Error> { - Ok(vec![self.clone()]) + fn embeddable(&self) -> Result, Self::Error> { + Ok(OneOrMany::from(self.clone())) } } impl Embeddable for i8 { - type Kind = SingleEmbedding; type Error = EmbeddableError; - fn embeddable(&self) -> Result, Self::Error> { - Ok(vec![self.to_string()]) + fn embeddable(&self) -> Result, Self::Error> { + Ok(OneOrMany::from(self.to_string())) } } impl Embeddable for i16 { - type Kind = SingleEmbedding; type Error = 
EmbeddableError; - fn embeddable(&self) -> Result, Self::Error> { - Ok(vec![self.to_string()]) + fn embeddable(&self) -> Result, Self::Error> { + Ok(OneOrMany::from(self.to_string())) } } impl Embeddable for i32 { - type Kind = SingleEmbedding; type Error = EmbeddableError; - fn embeddable(&self) -> Result, Self::Error> { - Ok(vec![self.to_string()]) + fn embeddable(&self) -> Result, Self::Error> { + Ok(OneOrMany::from(self.to_string())) } } impl Embeddable for i64 { - type Kind = SingleEmbedding; type Error = EmbeddableError; - fn embeddable(&self) -> Result, Self::Error> { - Ok(vec![self.to_string()]) + fn embeddable(&self) -> Result, Self::Error> { + Ok(OneOrMany::from(self.to_string())) } } impl Embeddable for i128 { - type Kind = SingleEmbedding; type Error = EmbeddableError; - fn embeddable(&self) -> Result, Self::Error> { - Ok(vec![self.to_string()]) + fn embeddable(&self) -> Result, Self::Error> { + Ok(OneOrMany::from(self.to_string())) } } impl Embeddable for f32 { - type Kind = SingleEmbedding; type Error = EmbeddableError; - fn embeddable(&self) -> Result, Self::Error> { - Ok(vec![self.to_string()]) + fn embeddable(&self) -> Result, Self::Error> { + Ok(OneOrMany::from(self.to_string())) } } impl Embeddable for f64 { - type Kind = SingleEmbedding; type Error = EmbeddableError; - fn embeddable(&self) -> Result, Self::Error> { - Ok(vec![self.to_string()]) + fn embeddable(&self) -> Result, Self::Error> { + Ok(OneOrMany::from(self.to_string())) } } impl Embeddable for bool { - type Kind = SingleEmbedding; type Error = EmbeddableError; - fn embeddable(&self) -> Result, Self::Error> { - Ok(vec![self.to_string()]) + fn embeddable(&self) -> Result, Self::Error> { + Ok(OneOrMany::from(self.to_string())) } } impl Embeddable for char { - type Kind = SingleEmbedding; type Error = EmbeddableError; - fn embeddable(&self) -> Result, Self::Error> { - Ok(vec![self.to_string()]) + fn embeddable(&self) -> Result, Self::Error> { + Ok(OneOrMany::from(self.to_string())) } } 
impl Embeddable for serde_json::Value { - type Kind = SingleEmbedding; type Error = EmbeddableError; - fn embeddable(&self) -> Result, Self::Error> { - Ok(vec![ - serde_json::to_string(self).map_err(EmbeddableError::SerdeError)? - ]) + fn embeddable(&self) -> Result, Self::Error> { + Ok(OneOrMany::from( + serde_json::to_string(self).map_err(EmbeddableError::SerdeError)?, + )) } } impl Embeddable for Vec { - type Kind = ManyEmbedding; - type Error = T::Error; + type Error = EmbeddableError; - fn embeddable(&self) -> Result, Self::Error> { - Ok(self + fn embeddable(&self) -> Result, Self::Error> { + let items = self .iter() - .map(|i| i.embeddable()) - .collect::, _>>()? - .into_iter() - .flatten() - .collect()) - } -} - -#[cfg(test)] -mod tests { - use crate as rig; - use rig::embeddings::embeddable::{Embeddable, EmbeddableError}; - use rig_derive::Embeddable; - use serde::Serialize; - - fn serialize(definition: Definition) -> Result, EmbeddableError> { - Ok(vec![ - serde_json::to_string(&definition).map_err(EmbeddableError::SerdeError)? 
- ]) - } - - #[derive(Embeddable)] - struct FakeDefinition { - id: String, - word: String, - #[embed(embed_with = "serialize")] - definition: Definition, - } - - #[derive(Serialize, Clone)] - struct Definition { - word: String, - link: String, - speech: String, - } - - #[test] - fn test_custom_embed() { - let fake_definition = FakeDefinition { - id: "doc1".to_string(), - word: "house".to_string(), - definition: Definition { - speech: "noun".to_string(), - word: "a building in which people live; residence for human beings.".to_string(), - link: "https://www.dictionary.com/browse/house".to_string(), - }, - }; - - println!( - "FakeDefinition: {}, {}", - fake_definition.id, fake_definition.word - ); - - assert_eq!( - fake_definition.embeddable().unwrap(), - vec!["{\"word\":\"a building in which people live; residence for human beings.\",\"link\":\"https://www.dictionary.com/browse/house\",\"speech\":\"noun\"}".to_string()] - ) - } - - #[derive(Embeddable)] - struct FakeDefinition2 { - id: String, - word: String, - #[embed] - definition: String, - } - - #[test] - fn test_single_embed() { - let fake_definition = FakeDefinition2 { - id: "doc1".to_string(), - word: "house".to_string(), - definition: "a building in which people live; residence for human beings.".to_string(), - }; - - println!( - "FakeDefinition2: {}, {}", - fake_definition.id, fake_definition.word - ); - - assert_eq!( - fake_definition.embeddable().unwrap(), - vec!["a building in which people live; residence for human beings.".to_string()] - ) - } - - #[derive(Embeddable)] - struct Company { - id: String, - company: String, - #[embed] - employee_ages: Vec, - } - - #[test] - fn test_multiple_embed() { - let company = Company { - id: "doc1".to_string(), - company: "Google".to_string(), - employee_ages: vec![25, 30, 35, 40], - }; - - println!("Company: {}, {}", company.id, company.company); - - assert_eq!( - company.embeddable().unwrap(), - vec![ - "25".to_string(), - "30".to_string(), - "35".to_string(), - 
"40".to_string() - ] - ); - } - - #[derive(Embeddable)] - struct Company2 { - id: String, - #[embed] - company: String, - #[embed] - employee_ages: Vec, - } + .map(|item| item.embeddable()) + .collect::, _>>() + .map_err(|e| EmbeddableError::Error(e.to_string()))?; - #[test] - fn test_many_embed() { - let company = Company2 { - id: "doc1".to_string(), - company: "Google".to_string(), - employee_ages: vec![25, 30, 35, 40], - }; - - println!("Company2: {}", company.id); - - assert_eq!( - company.embeddable().unwrap(), - vec![ - "Google".to_string(), - "25".to_string(), - "30".to_string(), - "35".to_string(), - "40".to_string() - ] - ); + OneOrMany::try_from(items) } } diff --git a/rig-core/src/lib.rs b/rig-core/src/lib.rs index b7f0615e..cc17efa3 100644 --- a/rig-core/src/lib.rs +++ b/rig-core/src/lib.rs @@ -78,4 +78,6 @@ pub mod vector_store; // Export Embeddable trait and Embeddable together. pub use embeddings::embeddable::Embeddable; + +#[cfg(feature = "rig_derive")] pub use rig_derive::Embeddable; diff --git a/rig-core/tests/embeddable_macro.rs b/rig-core/tests/embeddable_macro.rs new file mode 100644 index 00000000..130da109 --- /dev/null +++ b/rig-core/tests/embeddable_macro.rs @@ -0,0 +1,141 @@ +use rig::embeddings::embeddable::{EmbeddableError, OneOrMany}; +use rig::Embeddable; +use serde::Serialize; + +fn serialize(definition: Definition) -> Result, EmbeddableError> { + Ok(OneOrMany::from( + serde_json::to_string(&definition).map_err(EmbeddableError::SerdeError)?, + )) +} + +#[derive(Embeddable)] +struct FakeDefinition { + id: String, + word: String, + #[embed(embed_with = "serialize")] + definition: Definition, +} + +#[derive(Serialize, Clone)] +struct Definition { + word: String, + link: String, + speech: String, +} + +#[test] +fn test_custom_embed() { + let fake_definition = FakeDefinition { + id: "doc1".to_string(), + word: "house".to_string(), + definition: Definition { + speech: "noun".to_string(), + word: "a building in which people live; residence 
for human beings.".to_string(), + link: "https://www.dictionary.com/browse/house".to_string(), + }, + }; + + println!( + "FakeDefinition: {}, {}", + fake_definition.id, fake_definition.word + ); + + assert_eq!( + fake_definition.embeddable().unwrap(), + OneOrMany::from( + "{\"word\":\"a building in which people live; residence for human beings.\",\"link\":\"https://www.dictionary.com/browse/house\",\"speech\":\"noun\"}".to_string() + ) + + ) +} + +#[derive(Embeddable)] +struct FakeDefinition2 { + id: String, + word: String, + #[embed] + definition: String, +} + +#[test] +fn test_single_embed() { + let definition = "a building in which people live; residence for human beings.".to_string(); + + let fake_definition = FakeDefinition2 { + id: "doc1".to_string(), + word: "house".to_string(), + definition: definition.clone(), + }; + + println!( + "FakeDefinition2: {}, {}", + fake_definition.id, fake_definition.word + ); + + assert_eq!( + fake_definition.embeddable().unwrap(), + OneOrMany::from(definition) + ) +} + +#[derive(Embeddable)] +struct Company { + id: String, + company: String, + #[embed] + employee_ages: Vec, +} + +#[test] +fn test_multiple_embed() { + let company = Company { + id: "doc1".to_string(), + company: "Google".to_string(), + employee_ages: vec![25, 30, 35, 40], + }; + + println!("Company: {}, {}", company.id, company.company); + + assert_eq!( + company.embeddable().unwrap(), + OneOrMany::try_from(vec![ + "25".to_string(), + "30".to_string(), + "35".to_string(), + "40".to_string() + ]) + .unwrap() + ); +} + +#[derive(Embeddable)] +struct Company2 { + id: String, + #[embed] + company: String, + #[embed] + employee_ages: Vec, +} + +#[test] +fn test_many_embed() { + let company = Company2 { + id: "doc1".to_string(), + company: "Google".to_string(), + employee_ages: vec![25, 30, 35, 40], + }; + + println!("Company2: {}", company.id); + + assert_eq!( + company.embeddable().unwrap(), + OneOrMany::try_from(vec![ + "Google".to_string(), + "25".to_string(), + 
"30".to_string(), + "35".to_string(), + "40".to_string() + ]) + .unwrap() + ); +} From 690027c7e524653f50273ec6154c0eb5929df352 Mon Sep 17 00:00:00 2001 From: Garance Date: Wed, 16 Oct 2024 18:10:50 -0400 Subject: [PATCH 28/47] tests: add some more tests --- rig-core/tests/embeddable_macro.rs | 56 +++++++++++++++++++++++++++--- 1 file changed, 51 insertions(+), 5 deletions(-) diff --git a/rig-core/tests/embeddable_macro.rs b/rig-core/tests/embeddable_macro.rs index 130da109..bc076fed 100644 --- a/rig-core/tests/embeddable_macro.rs +++ b/rig-core/tests/embeddable_macro.rs @@ -51,6 +51,43 @@ fn test_custom_embed() { #[derive(Embeddable)] struct FakeDefinition2 { + id: String, + #[embed] + word: String, + #[embed(embed_with = "serialize")] + definition: Definition, +} + +#[test] +fn test_custom_and_basic_embed() { + let fake_definition = FakeDefinition2 { + id: "doc1".to_string(), + word: "house".to_string(), + definition: Definition { + speech: "noun".to_string(), + word: "a building in which people live; residence for human beings.".to_string(), + link: "https://www.dictionary.com/browse/house".to_string(), + }, + }; + + println!( + "FakeDefinition: {}, {}", + fake_definition.id, fake_definition.word + ); + + assert_eq!( + fake_definition.embeddable().unwrap().first(), + "house".to_string() + ); + + assert_eq!( + fake_definition.embeddable().unwrap().rest(), + vec!["{\"word\":\"a building in which people live; residence for human beings.\",\"link\":\"https://www.dictionary.com/browse/house\",\"speech\":\"noun\"}".to_string()] + ) +} + +#[derive(Embeddable)] +struct FakeDefinition3 { id: String, word: String, #[embed] @@ -61,14 +98,14 @@ struct FakeDefinition2 { fn test_single_embed() { let definition = "a building in which people live; residence for human beings.".to_string(); - let fake_definition = FakeDefinition2 { + let fake_definition = FakeDefinition3 { id: "doc1".to_string(), word: "house".to_string(), definition: definition.clone(), }; println!( - 
"FakeDefinition2: {}, {}", + "FakeDefinition3: {}, {}", fake_definition.id, fake_definition.word ); @@ -87,7 +124,7 @@ struct Company { } #[test] -fn test_multiple_embed() { +fn test_multiple_embed_strings() { let company = Company { id: "doc1".to_string(), company: "Google".to_string(), @@ -96,8 +133,10 @@ fn test_multiple_embed() { println!("Company: {}, {}", company.id, company.company); + let result = company.embeddable().unwrap(); + assert_eq!( - company.embeddable().unwrap(), + result, OneOrMany::try_from(vec![ "25".to_string(), "30".to_string(), @@ -106,6 +145,13 @@ fn test_multiple_embed() { ]) .unwrap() ); + + assert_eq!(result.first(), "25".to_string()); + + assert_eq!( + result.rest(), + vec!["30".to_string(), "35".to_string(), "40".to_string()] + ) } #[derive(Embeddable)] @@ -118,7 +164,7 @@ struct Company2 { } #[test] -fn test_many_embed() { +fn test_multiple_embed_tags() { let company = Company2 { id: "doc1".to_string(), company: "Google".to_string(), From cca60593b903e68c8843e42397dd233b51f9e2c7 Mon Sep 17 00:00:00 2001 From: Garance Date: Thu, 17 Oct 2024 09:09:17 -0400 Subject: [PATCH 29/47] clippy: cargo clippy --- rig-core/src/embeddings/embeddable.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/rig-core/src/embeddings/embeddable.rs b/rig-core/src/embeddings/embeddable.rs index c4cf90da..6084e07e 100644 --- a/rig-core/src/embeddings/embeddable.rs +++ b/rig-core/src/embeddings/embeddable.rs @@ -66,7 +66,7 @@ impl OneOrMany { pub fn all(&self) -> Vec { let mut all = vec![self.first.clone()]; - all.extend(self.rest.clone().into_iter()); + all.extend(self.rest.clone()); all } } @@ -89,9 +89,9 @@ impl TryFrom> for OneOrMany { first: match iter.next() { Some(item) => item, None => { - return Err(EmbeddableError::Error(format!( - "Cannot convert empty Vec to OneOrMany" - ))) + return Err(EmbeddableError::Error( + "Cannot convert empty Vec to OneOrMany".to_string(), + )) } }, rest: iter.collect(), From 
f785b8c3aa64c2fb240899a39c3ae374e72aa394 Mon Sep 17 00:00:00 2001 From: Garance Date: Thu, 17 Oct 2024 09:56:50 -0400 Subject: [PATCH 30/47] docs: add docstring to oneormany --- rig-core/src/embeddings/embeddable.rs | 80 +++++++++++++-------------- 1 file changed, 40 insertions(+), 40 deletions(-) diff --git a/rig-core/src/embeddings/embeddable.rs b/rig-core/src/embeddings/embeddable.rs index 6084e07e..9e751b00 100644 --- a/rig-core/src/embeddings/embeddable.rs +++ b/rig-core/src/embeddings/embeddable.rs @@ -3,45 +3,41 @@ //! ```rust //! use std::env; //! -//! use rig::Embeddable; //! use serde::{Deserialize, Serialize}; //! -//! #[derive(Embeddable)] //! struct FakeDefinition { //! id: String, //! word: String, -//! #[embed] //! definitions: Vec, //! } //! -//! // Do something with FakeDefinition -//! // ... +//! let fake_definition = FakeDefinition { +//! id: "doc1".to_string(), +//! word: "hello".to_string(), +//! definition: "used as a greeting or to begin a conversation".to_string() +//! }; +//! +//! impl Embeddable for FakeDefinition { +//! type Error = anyhow::Error; +//! +//! fn embeddable(&self) -> Result, Self::Error> { +//! // Embeddigns only need to be generated for `definition` field. +//! // Select it from te struct and return it as a single item. +//! Ok(OneOrMany::from(self.definition.clone())) +//! } +//! } //! ``` -/// The associated type `Kind` on the trait `Embeddable` must implement this trait. -pub trait EmbeddingKind {} - -/// Used for structs that contain a single embedding target. -pub struct SingleEmbedding; -impl EmbeddingKind for SingleEmbedding {} - -/// Used for structs that contain many embedding targets. -pub struct ManyEmbedding; -impl EmbeddingKind for ManyEmbedding {} - /// Error type used for when the `embeddable` method fails. /// Used by default implementations of `Embeddable` for common types. 
#[derive(Debug, thiserror::Error)] pub enum EmbeddableError { #[error("SerdeError: {0}")] SerdeError(#[from] serde_json::Error), - #[error("Error: {0}")] - Error(String), } /// Trait for types that can be embedded. -/// The `embeddable` method returns a list of strings for which embeddings will be generated by the embeddings builder. -/// If the type `Kind` is `SingleEmbedding`, the list of strings contains a single item, otherwise, the list can contain many items. +/// The `embeddable` method returns a OneOrMany which contains strings for which embeddings will be generated by the embeddings builder. /// If there is an error generating the list of strings, the method should return an error that implements `std::error::Error`. pub trait Embeddable { type Error: std::error::Error; @@ -49,21 +45,31 @@ pub trait Embeddable { fn embeddable(&self) -> Result, Self::Error>; } +/// Struct containing either a single item or a list of items of type T. +/// If a single item is present, `first` will contain it and `rest` will be empty. +/// If multiple items are present, `first` will contain the first item and `rest` will contain the rest. +/// IMPORTANT: this struct cannot be created with an empty vector. +/// OneOrMany objects can only be create using OneOrMany::from() or OneOrMany::try_from(). #[derive(PartialEq, Eq, Debug)] pub struct OneOrMany { + /// First item in the list. first: T, + /// Rest of the items in the list. rest: Vec, } impl OneOrMany { + /// Get the first item in the list. pub fn first(&self) -> T { self.first.clone() } + /// Get the rest of the items in the list (excluding the first one). pub fn rest(&self) -> Vec { self.rest.clone() } + /// Get all items in the list (joins the first with the rest). pub fn all(&self) -> Vec { let mut all = vec![self.first.clone()]; all.extend(self.rest.clone()); @@ -71,6 +77,7 @@ impl OneOrMany { } } +/// Create a OneOrMany object with a single item. 
impl From for OneOrMany { fn from(item: T) -> Self { OneOrMany { @@ -80,35 +87,29 @@ impl From for OneOrMany { } } -impl TryFrom> for OneOrMany { - type Error = EmbeddableError; - - fn try_from(items: Vec) -> Result { +/// Create a OneOrMany object with a list of items. +impl From> for OneOrMany { + fn from(items: Vec) -> Self { let mut iter = items.into_iter(); - Ok(OneOrMany { + OneOrMany { first: match iter.next() { Some(item) => item, - None => { - return Err(EmbeddableError::Error( - "Cannot convert empty Vec to OneOrMany".to_string(), - )) - } + None => panic!("Cannot create OneOrMany with an empty vector."), }, rest: iter.collect(), - }) + } } } -impl TryFrom>> for OneOrMany { - type Error = EmbeddableError; - - fn try_from(value: Vec>) -> Result { +/// Merge a list of OneOrMany items into a single OneOrMany item. +impl From>> for OneOrMany { + fn from(value: Vec>) -> Self { let items = value .into_iter() .flat_map(|one_or_many| one_or_many.all()) .collect::>(); - OneOrMany::try_from(items) + OneOrMany::from(items) } } @@ -206,15 +207,14 @@ impl Embeddable for serde_json::Value { } impl Embeddable for Vec { - type Error = EmbeddableError; + type Error = T::Error; fn embeddable(&self) -> Result, Self::Error> { let items = self .iter() .map(|item| item.embeddable()) - .collect::, _>>() - .map_err(|e| EmbeddableError::Error(e.to_string()))?; + .collect::, _>>()?; - OneOrMany::try_from(items) + Ok(OneOrMany::from(items)) } } From 0e2ade9b672baa174efc95cb509fe1c1288e74d3 Mon Sep 17 00:00:00 2001 From: Garance Date: Thu, 17 Oct 2024 10:06:02 -0400 Subject: [PATCH 31/47] fix(macro): update error handling --- rig-core/rig-core-derive/src/embeddable.rs | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/rig-core/rig-core-derive/src/embeddable.rs b/rig-core/rig-core-derive/src/embeddable.rs index e36a7eea..563b8c3c 100644 --- a/rig-core/rig-core-derive/src/embeddable.rs +++ b/rig-core/rig-core-derive/src/embeddable.rs @@ -51,11 +51,10 @@ 
pub(crate) fn expand_derive_embedding(input: &mut syn::DeriveInput) -> syn::Resu fn embeddable(&self) -> Result, Self::Error> { #target_stream; - rig::embeddings::embeddable::OneOrMany::try_from( + Ok(rig::embeddings::embeddable::OneOrMany::from( embed_targets.into_iter() - .collect::, _>>() - .map_err(|e| rig::embeddings::embeddable::EmbeddableError::Error(e.to_string()))? - ) + .collect::, _>>()? + )) } } }; From a98769c2228fa7101f7e4ce8296fa7bec132b01a Mon Sep 17 00:00:00 2001 From: Garance Date: Thu, 17 Oct 2024 10:40:27 -0400 Subject: [PATCH 32/47] refactor: reexport EmbeddingsBuilder in rig and update imports --- rig-core/examples/calculator_chatbot.rs | 3 ++- rig-core/examples/rag.rs | 3 ++- rig-core/examples/rag_dynamic_tools.rs | 3 ++- rig-core/examples/vector_search.rs | 3 ++- rig-core/examples/vector_search_cohere.rs | 3 ++- rig-core/src/embeddings/embeddable.rs | 2 +- rig-core/src/lib.rs | 3 ++- rig-core/src/providers/cohere.rs | 4 ++-- rig-core/src/providers/openai.rs | 9 +++------ rig-lancedb/examples/vector_search_local_ann.rs | 3 ++- rig-lancedb/examples/vector_search_local_enn.rs | 3 ++- rig-lancedb/examples/vector_search_s3_ann.rs | 3 ++- rig-mongodb/examples/vector_search_mongodb.rs | 3 ++- 13 files changed, 26 insertions(+), 19 deletions(-) diff --git a/rig-core/examples/calculator_chatbot.rs b/rig-core/examples/calculator_chatbot.rs index 949073b3..02a029e9 100644 --- a/rig-core/examples/calculator_chatbot.rs +++ b/rig-core/examples/calculator_chatbot.rs @@ -2,10 +2,11 @@ use anyhow::Result; use rig::{ cli_chatbot::cli_chatbot, completion::ToolDefinition, - embeddings::{builder::DocumentEmbeddings, builder::EmbeddingsBuilder}, + embeddings::builder::DocumentEmbeddings, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, tool::{Tool, ToolEmbedding, ToolSet}, vector_store::in_memory_store::InMemoryVectorStore, + EmbeddingsBuilder, }; use serde::{Deserialize, Serialize}; use serde_json::json; diff --git a/rig-core/examples/rag.rs 
b/rig-core/examples/rag.rs index 936a3d05..ec9bebd8 100644 --- a/rig-core/examples/rag.rs +++ b/rig-core/examples/rag.rs @@ -2,9 +2,10 @@ use std::env; use rig::{ completion::Prompt, - embeddings::{builder::DocumentEmbeddings, builder::EmbeddingsBuilder}, + embeddings::builder::DocumentEmbeddings, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, vector_store::in_memory_store::InMemoryVectorStore, + EmbeddingsBuilder, }; #[tokio::main] diff --git a/rig-core/examples/rag_dynamic_tools.rs b/rig-core/examples/rag_dynamic_tools.rs index f00543a2..dce27693 100644 --- a/rig-core/examples/rag_dynamic_tools.rs +++ b/rig-core/examples/rag_dynamic_tools.rs @@ -1,10 +1,11 @@ use anyhow::Result; use rig::{ completion::{Prompt, ToolDefinition}, - embeddings::{builder::DocumentEmbeddings, builder::EmbeddingsBuilder}, + embeddings::builder::DocumentEmbeddings, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, tool::{Tool, ToolEmbedding, ToolSet}, vector_store::in_memory_store::InMemoryVectorStore, + EmbeddingsBuilder, }; use serde::{Deserialize, Serialize}; use serde_json::json; diff --git a/rig-core/examples/vector_search.rs b/rig-core/examples/vector_search.rs index 26692706..67d04424 100644 --- a/rig-core/examples/vector_search.rs +++ b/rig-core/examples/vector_search.rs @@ -1,9 +1,10 @@ use std::env; use rig::{ - embeddings::{builder::DocumentEmbeddings, builder::EmbeddingsBuilder}, + embeddings::builder::DocumentEmbeddings, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, vector_store::{in_memory_store::InMemoryVectorStore, VectorStoreIndex}, + EmbeddingsBuilder, }; #[tokio::main] diff --git a/rig-core/examples/vector_search_cohere.rs b/rig-core/examples/vector_search_cohere.rs index c14fe0ce..1be5de8d 100644 --- a/rig-core/examples/vector_search_cohere.rs +++ b/rig-core/examples/vector_search_cohere.rs @@ -1,9 +1,10 @@ use std::env; use rig::{ - embeddings::{builder::DocumentEmbeddings, builder::EmbeddingsBuilder}, + embeddings::builder::DocumentEmbeddings, 
providers::cohere::{Client, EMBED_ENGLISH_V3}, vector_store::{in_memory_store::InMemoryVectorStore, VectorStoreIndex}, + EmbeddingsBuilder, }; #[tokio::main] diff --git a/rig-core/src/embeddings/embeddable.rs b/rig-core/src/embeddings/embeddable.rs index 9e751b00..aafe2a5c 100644 --- a/rig-core/src/embeddings/embeddable.rs +++ b/rig-core/src/embeddings/embeddable.rs @@ -49,7 +49,7 @@ pub trait Embeddable { /// If a single item is present, `first` will contain it and `rest` will be empty. /// If multiple items are present, `first` will contain the first item and `rest` will contain the rest. /// IMPORTANT: this struct cannot be created with an empty vector. -/// OneOrMany objects can only be create using OneOrMany::from() or OneOrMany::try_from(). +/// OneOrMany objects can only be created using OneOrMany::from() or OneOrMany::try_from(). #[derive(PartialEq, Eq, Debug)] pub struct OneOrMany { /// First item in the list. diff --git a/rig-core/src/lib.rs b/rig-core/src/lib.rs index cc17efa3..5c498c81 100644 --- a/rig-core/src/lib.rs +++ b/rig-core/src/lib.rs @@ -76,7 +76,8 @@ pub mod providers; pub mod tool; pub mod vector_store; -// Export Embeddable trait and Embeddable together. 
+// Re-export commonly used types and traits +pub use embeddings::builder::EmbeddingsBuilder; pub use embeddings::embeddable::Embeddable; #[cfg(feature = "rig_derive")] diff --git a/rig-core/src/providers/cohere.rs b/rig-core/src/providers/cohere.rs index 87f2334d..844849d7 100644 --- a/rig-core/src/providers/cohere.rs +++ b/rig-core/src/providers/cohere.rs @@ -13,9 +13,9 @@ use std::collections::HashMap; use crate::{ agent::AgentBuilder, completion::{self, CompletionError}, - embeddings::{self, builder::EmbeddingsBuilder, embedding::EmbeddingError}, + embeddings::{self, embedding::EmbeddingError}, extractor::ExtractorBuilder, - json_utils, + json_utils, EmbeddingsBuilder, }; use schemars::JsonSchema; diff --git a/rig-core/src/providers/openai.rs b/rig-core/src/providers/openai.rs index 9adf6680..6b95a56a 100644 --- a/rig-core/src/providers/openai.rs +++ b/rig-core/src/providers/openai.rs @@ -13,7 +13,7 @@ use crate::{ completion::{self, CompletionError, CompletionRequest}, embeddings::{self, embedding::EmbeddingError}, extractor::ExtractorBuilder, - json_utils, + json_utils, EmbeddingsBuilder, }; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; @@ -121,11 +121,8 @@ impl Client { /// .await /// .expect("Failed to embed documents"); /// ``` - pub fn embeddings( - &self, - model: &str, - ) -> embeddings::builder::EmbeddingsBuilder { - embeddings::builder::EmbeddingsBuilder::new(self.embedding_model(model)) + pub fn embeddings(&self, model: &str) -> EmbeddingsBuilder { + EmbeddingsBuilder::new(self.embedding_model(model)) } /// Create a completion model with the given name. 
diff --git a/rig-lancedb/examples/vector_search_local_ann.rs b/rig-lancedb/examples/vector_search_local_ann.rs index 52ee213a..8f3ceee6 100644 --- a/rig-lancedb/examples/vector_search_local_ann.rs +++ b/rig-lancedb/examples/vector_search_local_ann.rs @@ -5,8 +5,9 @@ use fixture::{as_record_batch, schema}; use lancedb::index::vector::IvfPqIndexBuilder; use rig::vector_store::VectorStoreIndex; use rig::{ - embeddings::{builder::EmbeddingsBuilder, embedding::EmbeddingModel}, + embeddings::embedding::EmbeddingModel, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, + EmbeddingsBuilder, }; use rig_lancedb::{LanceDbVectorStore, SearchParams}; use serde::Deserialize; diff --git a/rig-lancedb/examples/vector_search_local_enn.rs b/rig-lancedb/examples/vector_search_local_enn.rs index 9f0ec934..90a14fe0 100644 --- a/rig-lancedb/examples/vector_search_local_enn.rs +++ b/rig-lancedb/examples/vector_search_local_enn.rs @@ -3,9 +3,10 @@ use std::{env, sync::Arc}; use arrow_array::RecordBatchIterator; use fixture::{as_record_batch, schema}; use rig::{ - embeddings::{builder::EmbeddingsBuilder, embedding::EmbeddingModel}, + embeddings::embedding::EmbeddingModel, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, vector_store::VectorStoreIndexDyn, + EmbeddingsBuilder, }; use rig_lancedb::{LanceDbVectorStore, SearchParams}; diff --git a/rig-lancedb/examples/vector_search_s3_ann.rs b/rig-lancedb/examples/vector_search_s3_ann.rs index 7ce61309..2d97ed80 100644 --- a/rig-lancedb/examples/vector_search_s3_ann.rs +++ b/rig-lancedb/examples/vector_search_s3_ann.rs @@ -4,9 +4,10 @@ use arrow_array::RecordBatchIterator; use fixture::{as_record_batch, schema}; use lancedb::{index::vector::IvfPqIndexBuilder, DistanceType}; use rig::{ - embeddings::{builder::EmbeddingsBuilder, embedding::EmbeddingModel}, + embeddings::embedding::EmbeddingModel, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, vector_store::VectorStoreIndex, + EmbeddingsBuilder, }; use rig_lancedb::{LanceDbVectorStore, 
SearchParams}; use serde::Deserialize; diff --git a/rig-mongodb/examples/vector_search_mongodb.rs b/rig-mongodb/examples/vector_search_mongodb.rs index cca7a6d4..517bfea6 100644 --- a/rig-mongodb/examples/vector_search_mongodb.rs +++ b/rig-mongodb/examples/vector_search_mongodb.rs @@ -2,9 +2,10 @@ use mongodb::{options::ClientOptions, Client as MongoClient, Collection}; use std::env; use rig::{ - embeddings::{builder::DocumentEmbeddings, builder::EmbeddingsBuilder}, + embeddings::builder::DocumentEmbeddings, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, vector_store::VectorStoreIndex, + EmbeddingsBuilder, }; use rig_mongodb::{MongoDbVectorStore, SearchParams}; From 067894cc27eab07be72445baed03d8928f0eeb74 Mon Sep 17 00:00:00 2001 From: Garance Date: Thu, 17 Oct 2024 12:31:28 -0400 Subject: [PATCH 33/47] feat: implement IntoIterator and Iterator for OneOrMany --- rig-core/src/embeddings/embeddable.rs | 199 ++++++++++++++++++++++---- 1 file changed, 169 insertions(+), 30 deletions(-) diff --git a/rig-core/src/embeddings/embeddable.rs b/rig-core/src/embeddings/embeddable.rs index aafe2a5c..f5f09260 100644 --- a/rig-core/src/embeddings/embeddable.rs +++ b/rig-core/src/embeddings/embeddable.rs @@ -49,8 +49,8 @@ pub trait Embeddable { /// If a single item is present, `first` will contain it and `rest` will be empty. /// If multiple items are present, `first` will contain the first item and `rest` will contain the rest. /// IMPORTANT: this struct cannot be created with an empty vector. -/// OneOrMany objects can only be created using OneOrMany::from() or OneOrMany::try_from(). -#[derive(PartialEq, Eq, Debug)] +/// OneOrMany objects can only be created using OneOrMany::from_single() or OneOrMany::from_many(). +#[derive(PartialEq, Eq, Debug, Clone)] pub struct OneOrMany { /// First item in the list. first: T, @@ -69,27 +69,16 @@ impl OneOrMany { self.rest.clone() } - /// Get all items in the list (joins the first with the rest). 
- pub fn all(&self) -> Vec { - let mut all = vec![self.first.clone()]; - all.extend(self.rest.clone()); - all - } -} - -/// Create a OneOrMany object with a single item. -impl From for OneOrMany { - fn from(item: T) -> Self { + /// Create a OneOrMany object with a single item of any type. + pub fn from_single(item: T) -> Self { OneOrMany { first: item, rest: vec![], } } -} -/// Create a OneOrMany object with a list of items. -impl From> for OneOrMany { - fn from(items: Vec) -> Self { + /// Create a OneOrMany object with a single item of any type. + pub fn from_many(items: Vec) -> Self { let mut iter = items.into_iter(); OneOrMany { first: match iter.next() { @@ -99,6 +88,74 @@ impl From> for OneOrMany { rest: iter.collect(), } } + + /// Use the Iterator trait on OneOrMany + pub fn iter(&self) -> OneOrManyIterator { + OneOrManyIterator { + one_or_many: self, + index: 0, + } + } +} + +/// Implement Iterator for OneOrMany. +/// Iterates over all items in both `first` and `rest`. +/// Borrows the OneOrMany object that is being iterator over. +pub struct OneOrManyIterator<'a, T> { + one_or_many: &'a OneOrMany, + index: usize, +} + +impl<'a, T> Iterator for OneOrManyIterator<'a, T> { + type Item = &'a T; + + fn next(&mut self) -> Option { + let mut item = None; + if self.index == 0 { + item = Some(&self.one_or_many.first) + } else if self.index - 1 < self.one_or_many.rest.len() { + item = Some(&self.one_or_many.rest[self.index - 1]); + }; + + self.index += 1; + item + } +} + +/// Implement IntoIterator for OneOrMany. +/// Iterates over all items in both `first` and `rest`. +/// Takes ownership the OneOrMany object that is being iterator over. 
+pub struct OneOrManyIntoIterator { + one_or_many: OneOrMany, + index: usize, +} + +impl IntoIterator for OneOrMany { + type Item = T; + type IntoIter = OneOrManyIntoIterator; + + fn into_iter(self) -> OneOrManyIntoIterator { + OneOrManyIntoIterator { + one_or_many: self, + index: 0, + } + } +} + +impl Iterator for OneOrManyIntoIterator { + type Item = T; + + fn next(&mut self) -> Option { + let mut item = None; + if self.index == 0 { + item = Some(self.one_or_many.first()) + } else if self.index - 1 < self.one_or_many.rest.len() { + item = Some(self.one_or_many.rest[self.index - 1].clone()); + }; + + self.index += 1; + item + } } /// Merge a list of OneOrMany items into a single OneOrMany item. @@ -106,10 +163,10 @@ impl From>> for OneOrMany { fn from(value: Vec>) -> Self { let items = value .into_iter() - .flat_map(|one_or_many| one_or_many.all()) + .flat_map(|one_or_many| one_or_many.into_iter()) .collect::>(); - OneOrMany::from(items) + OneOrMany::from_many(items) } } @@ -120,7 +177,7 @@ impl Embeddable for String { type Error = EmbeddableError; fn embeddable(&self) -> Result, Self::Error> { - Ok(OneOrMany::from(self.clone())) + Ok(OneOrMany::from_single(self.clone())) } } @@ -128,7 +185,7 @@ impl Embeddable for i8 { type Error = EmbeddableError; fn embeddable(&self) -> Result, Self::Error> { - Ok(OneOrMany::from(self.to_string())) + Ok(OneOrMany::from_single(self.to_string())) } } @@ -136,7 +193,7 @@ impl Embeddable for i16 { type Error = EmbeddableError; fn embeddable(&self) -> Result, Self::Error> { - Ok(OneOrMany::from(self.to_string())) + Ok(OneOrMany::from_single(self.to_string())) } } @@ -144,7 +201,7 @@ impl Embeddable for i32 { type Error = EmbeddableError; fn embeddable(&self) -> Result, Self::Error> { - Ok(OneOrMany::from(self.to_string())) + Ok(OneOrMany::from_single(self.to_string())) } } @@ -152,7 +209,7 @@ impl Embeddable for i64 { type Error = EmbeddableError; fn embeddable(&self) -> Result, Self::Error> { - Ok(OneOrMany::from(self.to_string())) 
+ Ok(OneOrMany::from_single(self.to_string())) } } @@ -160,7 +217,7 @@ impl Embeddable for i128 { type Error = EmbeddableError; fn embeddable(&self) -> Result, Self::Error> { - Ok(OneOrMany::from(self.to_string())) + Ok(OneOrMany::from_single(self.to_string())) } } @@ -168,7 +225,7 @@ impl Embeddable for f32 { type Error = EmbeddableError; fn embeddable(&self) -> Result, Self::Error> { - Ok(OneOrMany::from(self.to_string())) + Ok(OneOrMany::from_single(self.to_string())) } } @@ -176,7 +233,7 @@ impl Embeddable for f64 { type Error = EmbeddableError; fn embeddable(&self) -> Result, Self::Error> { - Ok(OneOrMany::from(self.to_string())) + Ok(OneOrMany::from_single(self.to_string())) } } @@ -184,7 +241,7 @@ impl Embeddable for bool { type Error = EmbeddableError; fn embeddable(&self) -> Result, Self::Error> { - Ok(OneOrMany::from(self.to_string())) + Ok(OneOrMany::from_single(self.to_string())) } } @@ -192,7 +249,7 @@ impl Embeddable for char { type Error = EmbeddableError; fn embeddable(&self) -> Result, Self::Error> { - Ok(OneOrMany::from(self.to_string())) + Ok(OneOrMany::from_single(self.to_string())) } } @@ -200,7 +257,7 @@ impl Embeddable for serde_json::Value { type Error = EmbeddableError; fn embeddable(&self) -> Result, Self::Error> { - Ok(OneOrMany::from( + Ok(OneOrMany::from_single( serde_json::to_string(self).map_err(EmbeddableError::SerdeError)?, )) } @@ -218,3 +275,85 @@ impl Embeddable for Vec { Ok(OneOrMany::from(items)) } } + +#[cfg(test)] +mod test { + use super::OneOrMany; + + #[test] + fn test_one_or_many_iter_single() { + let one_or_many = OneOrMany::from_single("hello".to_string()); + + assert_eq!(one_or_many.iter().count(), 1); + + one_or_many.iter().for_each(|i| { + assert_eq!(i, "hello"); + }); + } + + #[test] + fn test_one_or_many_iter() { + let one_or_many = OneOrMany::from_many(vec!["hello".to_string(), "word".to_string()]); + + assert_eq!(one_or_many.iter().count(), 2); + + one_or_many.iter().enumerate().for_each(|(i, item)| { + if i == 0 
{ + assert_eq!(item, "hello"); + } + if i == 1 { + assert_eq!(item, "word"); + } + }); + } + + #[test] + fn test_one_or_many_into_iter_single() { + let one_or_many = OneOrMany::from_single("hello".to_string()); + + assert_eq!(one_or_many.clone().into_iter().count(), 1); + + one_or_many.into_iter().for_each(|i| { + assert_eq!(i, "hello".to_string()); + }); + } + + #[test] + fn test_one_or_many_into_iter() { + let one_or_many = OneOrMany::from_many(vec!["hello".to_string(), "word".to_string()]); + + assert_eq!(one_or_many.clone().into_iter().count(), 2); + + one_or_many.into_iter().enumerate().for_each(|(i, item)| { + if i == 0 { + assert_eq!(item, "hello".to_string()); + } + if i == 1 { + assert_eq!(item, "word".to_string()); + } + }); + } + + #[test] + fn test_one_or_many_merge() { + let one_or_many_1 = OneOrMany::from_many(vec!["hello".to_string(), "word".to_string()]); + + let one_or_many_2 = OneOrMany::from_single("sup".to_string()); + + let merged = OneOrMany::from(vec![one_or_many_1, one_or_many_2]); + + assert_eq!(merged.iter().count(), 3); + + merged.iter().enumerate().for_each(|(i, item)| { + if i == 0 { + assert_eq!(item, "hello"); + } + if i == 1 { + assert_eq!(item, "word"); + } + if i == 2 { + assert_eq!(item, "sup"); + } + }); + } +} From 32bcc61fa7af785e6ea4c5233a2734c0973ce6fd Mon Sep 17 00:00:00 2001 From: Garance Date: Thu, 17 Oct 2024 12:35:37 -0400 Subject: [PATCH 34/47] refactor: rename from methods --- rig-core/src/embeddings/embeddable.rs | 48 +++++++++++++-------------- 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/rig-core/src/embeddings/embeddable.rs b/rig-core/src/embeddings/embeddable.rs index f5f09260..a4bb85c5 100644 --- a/rig-core/src/embeddings/embeddable.rs +++ b/rig-core/src/embeddings/embeddable.rs @@ -22,8 +22,8 @@ //! //! fn embeddable(&self) -> Result, Self::Error> { //! // Embeddigns only need to be generated for `definition` field. -//! // Select it from te struct and return it as a single item. -//! 
Ok(OneOrMany::from(self.definition.clone())) +//! // Select it from the struct and return it as a single item. +//! Ok(OneOrMany::one(self.definition.clone())) //! } //! } //! ``` @@ -49,7 +49,7 @@ pub trait Embeddable { /// If a single item is present, `first` will contain it and `rest` will be empty. /// If multiple items are present, `first` will contain the first item and `rest` will contain the rest. /// IMPORTANT: this struct cannot be created with an empty vector. -/// OneOrMany objects can only be created using OneOrMany::from_single() or OneOrMany::from_many(). +/// OneOrMany objects can only be created using OneOrMany::one() or OneOrMany::many(). #[derive(PartialEq, Eq, Debug, Clone)] pub struct OneOrMany { /// First item in the list. @@ -70,15 +70,15 @@ impl OneOrMany { } /// Create a OneOrMany object with a single item of any type. - pub fn from_single(item: T) -> Self { + pub fn one(item: T) -> Self { OneOrMany { first: item, rest: vec![], } } - /// Create a OneOrMany object with a single item of any type. - pub fn from_many(items: Vec) -> Self { + /// Create a OneOrMany object with a vector of items of any type. 
+ pub fn many(items: Vec) -> Self { let mut iter = items.into_iter(); OneOrMany { first: match iter.next() { @@ -166,7 +166,7 @@ impl From>> for OneOrMany { .flat_map(|one_or_many| one_or_many.into_iter()) .collect::>(); - OneOrMany::from_many(items) + OneOrMany::many(items) } } @@ -177,7 +177,7 @@ impl Embeddable for String { type Error = EmbeddableError; fn embeddable(&self) -> Result, Self::Error> { - Ok(OneOrMany::from_single(self.clone())) + Ok(OneOrMany::one(self.clone())) } } @@ -185,7 +185,7 @@ impl Embeddable for i8 { type Error = EmbeddableError; fn embeddable(&self) -> Result, Self::Error> { - Ok(OneOrMany::from_single(self.to_string())) + Ok(OneOrMany::one(self.to_string())) } } @@ -193,7 +193,7 @@ impl Embeddable for i16 { type Error = EmbeddableError; fn embeddable(&self) -> Result, Self::Error> { - Ok(OneOrMany::from_single(self.to_string())) + Ok(OneOrMany::one(self.to_string())) } } @@ -201,7 +201,7 @@ impl Embeddable for i32 { type Error = EmbeddableError; fn embeddable(&self) -> Result, Self::Error> { - Ok(OneOrMany::from_single(self.to_string())) + Ok(OneOrMany::one(self.to_string())) } } @@ -209,7 +209,7 @@ impl Embeddable for i64 { type Error = EmbeddableError; fn embeddable(&self) -> Result, Self::Error> { - Ok(OneOrMany::from_single(self.to_string())) + Ok(OneOrMany::one(self.to_string())) } } @@ -217,7 +217,7 @@ impl Embeddable for i128 { type Error = EmbeddableError; fn embeddable(&self) -> Result, Self::Error> { - Ok(OneOrMany::from_single(self.to_string())) + Ok(OneOrMany::one(self.to_string())) } } @@ -225,7 +225,7 @@ impl Embeddable for f32 { type Error = EmbeddableError; fn embeddable(&self) -> Result, Self::Error> { - Ok(OneOrMany::from_single(self.to_string())) + Ok(OneOrMany::one(self.to_string())) } } @@ -233,7 +233,7 @@ impl Embeddable for f64 { type Error = EmbeddableError; fn embeddable(&self) -> Result, Self::Error> { - Ok(OneOrMany::from_single(self.to_string())) + Ok(OneOrMany::one(self.to_string())) } } @@ -241,7 +241,7 @@ 
impl Embeddable for bool { type Error = EmbeddableError; fn embeddable(&self) -> Result, Self::Error> { - Ok(OneOrMany::from_single(self.to_string())) + Ok(OneOrMany::one(self.to_string())) } } @@ -249,7 +249,7 @@ impl Embeddable for char { type Error = EmbeddableError; fn embeddable(&self) -> Result, Self::Error> { - Ok(OneOrMany::from_single(self.to_string())) + Ok(OneOrMany::one(self.to_string())) } } @@ -257,7 +257,7 @@ impl Embeddable for serde_json::Value { type Error = EmbeddableError; fn embeddable(&self) -> Result, Self::Error> { - Ok(OneOrMany::from_single( + Ok(OneOrMany::one( serde_json::to_string(self).map_err(EmbeddableError::SerdeError)?, )) } @@ -282,7 +282,7 @@ mod test { #[test] fn test_one_or_many_iter_single() { - let one_or_many = OneOrMany::from_single("hello".to_string()); + let one_or_many = OneOrMany::one("hello".to_string()); assert_eq!(one_or_many.iter().count(), 1); @@ -293,7 +293,7 @@ mod test { #[test] fn test_one_or_many_iter() { - let one_or_many = OneOrMany::from_many(vec!["hello".to_string(), "word".to_string()]); + let one_or_many = OneOrMany::many(vec!["hello".to_string(), "word".to_string()]); assert_eq!(one_or_many.iter().count(), 2); @@ -309,7 +309,7 @@ mod test { #[test] fn test_one_or_many_into_iter_single() { - let one_or_many = OneOrMany::from_single("hello".to_string()); + let one_or_many = OneOrMany::one("hello".to_string()); assert_eq!(one_or_many.clone().into_iter().count(), 1); @@ -320,7 +320,7 @@ mod test { #[test] fn test_one_or_many_into_iter() { - let one_or_many = OneOrMany::from_many(vec!["hello".to_string(), "word".to_string()]); + let one_or_many = OneOrMany::many(vec!["hello".to_string(), "word".to_string()]); assert_eq!(one_or_many.clone().into_iter().count(), 2); @@ -336,9 +336,9 @@ mod test { #[test] fn test_one_or_many_merge() { - let one_or_many_1 = OneOrMany::from_many(vec!["hello".to_string(), "word".to_string()]); + let one_or_many_1 = OneOrMany::many(vec!["hello".to_string(), "word".to_string()]); - 
let one_or_many_2 = OneOrMany::from_single("sup".to_string()); + let one_or_many_2 = OneOrMany::one("sup".to_string()); let merged = OneOrMany::from(vec![one_or_many_1, one_or_many_2]); From 564bef408a1d8c22d317015fe3b697839869dcd0 Mon Sep 17 00:00:00 2001 From: Garance Date: Thu, 17 Oct 2024 12:50:47 -0400 Subject: [PATCH 35/47] tests: fix failing tests --- rig-core/tests/embeddable_macro.rs | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/rig-core/tests/embeddable_macro.rs b/rig-core/tests/embeddable_macro.rs index bc076fed..96e42c7f 100644 --- a/rig-core/tests/embeddable_macro.rs +++ b/rig-core/tests/embeddable_macro.rs @@ -3,7 +3,7 @@ use rig::Embeddable; use serde::Serialize; fn serialize(definition: Definition) -> Result, EmbeddableError> { - Ok(OneOrMany::from( + Ok(OneOrMany::one( serde_json::to_string(&definition).map_err(EmbeddableError::SerdeError)?, )) } @@ -42,7 +42,7 @@ fn test_custom_embed() { assert_eq!( fake_definition.embeddable().unwrap(), - OneOrMany::from( + OneOrMany::one( "{\"word\":\"a building in which people live; residence for human beings.\",\"link\":\"https://www.dictionary.com/browse/house\",\"speech\":\"noun\"}".to_string() ) @@ -111,7 +111,7 @@ fn test_single_embed() { assert_eq!( fake_definition.embeddable().unwrap(), - OneOrMany::from(definition) + OneOrMany::one(definition) ) } @@ -137,13 +137,12 @@ fn test_multiple_embed_strings() { assert_eq!( result, - OneOrMany::try_from(vec![ + OneOrMany::many(vec![ "25".to_string(), "30".to_string(), "35".to_string(), "40".to_string() ]) - .unwrap() ); assert_eq!(result.first(), "25".to_string()); @@ -175,13 +174,12 @@ fn test_multiple_embed_tags() { assert_eq!( company.embeddable().unwrap(), - OneOrMany::try_from(vec![ + OneOrMany::many(vec![ "Google".to_string(), "25".to_string(), "30".to_string(), "35".to_string(), "40".to_string() ]) - .unwrap() ); } From 04f1f3e3f488d7a0c6b7f4445cb16ab223abe981 Mon Sep 17 00:00:00 2001 From: Garance Date: Thu, 17 Oct 2024 
14:39:23 -0400 Subject: [PATCH 36/47] refactor&fix: make PR review changes --- rig-core/Cargo.toml | 4 +- rig-core/examples/calculator_chatbot.rs | 2 +- rig-core/examples/rag.rs | 2 +- rig-core/examples/rag_dynamic_tools.rs | 2 +- rig-core/examples/vector_search.rs | 2 +- rig-core/examples/vector_search_cohere.rs | 2 +- rig-core/src/embeddings/builder.rs | 7 +- rig-core/src/embeddings/embeddable.rs | 238 ++---------------- rig-core/src/embeddings/embedding.rs | 2 +- rig-core/src/embeddings/mod.rs | 4 + rig-core/src/lib.rs | 4 +- rig-core/src/providers/cohere.rs | 10 +- rig-core/src/providers/openai.rs | 10 +- rig-core/src/vec_utils.rs | 191 ++++++++++++++ rig-core/src/vector_store/in_memory_store.rs | 2 +- rig-core/src/vector_store/mod.rs | 2 +- .../examples/vector_search_local_ann.rs | 5 +- .../examples/vector_search_local_enn.rs | 3 +- rig-lancedb/examples/vector_search_s3_ann.rs | 3 +- rig-mongodb/examples/vector_search_mongodb.rs | 2 +- rig-mongodb/src/lib.rs | 3 +- 21 files changed, 244 insertions(+), 256 deletions(-) create mode 100644 rig-core/src/vec_utils.rs diff --git a/rig-core/Cargo.toml b/rig-core/Cargo.toml index d2ab06c6..bc65409a 100644 --- a/rig-core/Cargo.toml +++ b/rig-core/Cargo.toml @@ -31,8 +31,8 @@ tokio = { version = "1.34.0", features = ["full"] } tracing-subscriber = "0.3.18" [features] -rig_derive = ["dep:rig-derive"] +derive = ["dep:rig-derive"] [[test]] name = "embeddable_macro" -required-features = ["rig_derive"] \ No newline at end of file +required-features = ["derive"] \ No newline at end of file diff --git a/rig-core/examples/calculator_chatbot.rs b/rig-core/examples/calculator_chatbot.rs index 02a029e9..8b622482 100644 --- a/rig-core/examples/calculator_chatbot.rs +++ b/rig-core/examples/calculator_chatbot.rs @@ -3,10 +3,10 @@ use rig::{ cli_chatbot::cli_chatbot, completion::ToolDefinition, embeddings::builder::DocumentEmbeddings, + embeddings::EmbeddingsBuilder, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, tool::{Tool, 
ToolEmbedding, ToolSet}, vector_store::in_memory_store::InMemoryVectorStore, - EmbeddingsBuilder, }; use serde::{Deserialize, Serialize}; use serde_json::json; diff --git a/rig-core/examples/rag.rs b/rig-core/examples/rag.rs index ec9bebd8..674d028c 100644 --- a/rig-core/examples/rag.rs +++ b/rig-core/examples/rag.rs @@ -3,9 +3,9 @@ use std::env; use rig::{ completion::Prompt, embeddings::builder::DocumentEmbeddings, + embeddings::EmbeddingsBuilder, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, vector_store::in_memory_store::InMemoryVectorStore, - EmbeddingsBuilder, }; #[tokio::main] diff --git a/rig-core/examples/rag_dynamic_tools.rs b/rig-core/examples/rag_dynamic_tools.rs index dce27693..5476eb80 100644 --- a/rig-core/examples/rag_dynamic_tools.rs +++ b/rig-core/examples/rag_dynamic_tools.rs @@ -2,10 +2,10 @@ use anyhow::Result; use rig::{ completion::{Prompt, ToolDefinition}, embeddings::builder::DocumentEmbeddings, + embeddings::EmbeddingsBuilder, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, tool::{Tool, ToolEmbedding, ToolSet}, vector_store::in_memory_store::InMemoryVectorStore, - EmbeddingsBuilder, }; use serde::{Deserialize, Serialize}; use serde_json::json; diff --git a/rig-core/examples/vector_search.rs b/rig-core/examples/vector_search.rs index 67d04424..feadbfb4 100644 --- a/rig-core/examples/vector_search.rs +++ b/rig-core/examples/vector_search.rs @@ -2,9 +2,9 @@ use std::env; use rig::{ embeddings::builder::DocumentEmbeddings, + embeddings::EmbeddingsBuilder, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, vector_store::{in_memory_store::InMemoryVectorStore, VectorStoreIndex}, - EmbeddingsBuilder, }; #[tokio::main] diff --git a/rig-core/examples/vector_search_cohere.rs b/rig-core/examples/vector_search_cohere.rs index 1be5de8d..6b93bcdd 100644 --- a/rig-core/examples/vector_search_cohere.rs +++ b/rig-core/examples/vector_search_cohere.rs @@ -2,9 +2,9 @@ use std::env; use rig::{ embeddings::builder::DocumentEmbeddings, + 
embeddings::EmbeddingsBuilder, providers::cohere::{Client, EMBED_ENGLISH_V3}, vector_store::{in_memory_store::InMemoryVectorStore, VectorStoreIndex}, - EmbeddingsBuilder, }; #[tokio::main] diff --git a/rig-core/src/embeddings/builder.rs b/rig-core/src/embeddings/builder.rs index 2458b98d..46eef092 100644 --- a/rig-core/src/embeddings/builder.rs +++ b/rig-core/src/embeddings/builder.rs @@ -30,9 +30,10 @@ use std::{cmp::max, collections::HashMap}; use futures::{stream, StreamExt, TryStreamExt}; use serde::{Deserialize, Serialize}; -use crate::tool::{ToolEmbedding, ToolSet, ToolType}; - -use super::embedding::{Embedding, EmbeddingError, EmbeddingModel}; +use crate::{ + embeddings::{Embedding, EmbeddingError, EmbeddingModel}, + tool::{ToolEmbedding, ToolSet, ToolType}, +}; /// Struct that holds a document and its embeddings. /// diff --git a/rig-core/src/embeddings/embeddable.rs b/rig-core/src/embeddings/embeddable.rs index a4bb85c5..e303ecda 100644 --- a/rig-core/src/embeddings/embeddable.rs +++ b/rig-core/src/embeddings/embeddable.rs @@ -1,5 +1,5 @@ //! The module defines the [Embeddable] trait, which must be implemented for types that can be embedded. -//! //! # Example +//! # Example //! ```rust //! use std::env; //! @@ -8,7 +8,7 @@ //! struct FakeDefinition { //! id: String, //! word: String, -//! definitions: Vec, +//! definition: String, //! } //! //! let fake_definition = FakeDefinition { @@ -28,151 +28,26 @@ //! } //! ``` +use crate::vec_utils::OneOrMany; + /// Error type used for when the `embeddable` method fails. /// Used by default implementations of `Embeddable` for common types. #[derive(Debug, thiserror::Error)] -pub enum EmbeddableError { - #[error("SerdeError: {0}")] - SerdeError(#[from] serde_json::Error), -} +#[error("{0}")] +pub struct EmbeddableError(#[from] Box); /// Trait for types that can be embedded. -/// The `embeddable` method returns a OneOrMany which contains strings for which embeddings will be generated by the embeddings builder. 
+/// The `embeddable` method returns a `OneOrMany` which contains strings for which embeddings will be generated by the embeddings builder. /// If there is an error generating the list of strings, the method should return an error that implements `std::error::Error`. pub trait Embeddable { - type Error: std::error::Error; + type Error: std::error::Error + Sync + Send + 'static; fn embeddable(&self) -> Result, Self::Error>; } -/// Struct containing either a single item or a list of items of type T. -/// If a single item is present, `first` will contain it and `rest` will be empty. -/// If multiple items are present, `first` will contain the first item and `rest` will contain the rest. -/// IMPORTANT: this struct cannot be created with an empty vector. -/// OneOrMany objects can only be created using OneOrMany::one() or OneOrMany::many(). -#[derive(PartialEq, Eq, Debug, Clone)] -pub struct OneOrMany { - /// First item in the list. - first: T, - /// Rest of the items in the list. - rest: Vec, -} - -impl OneOrMany { - /// Get the first item in the list. - pub fn first(&self) -> T { - self.first.clone() - } - - /// Get the rest of the items in the list (excluding the first one). - pub fn rest(&self) -> Vec { - self.rest.clone() - } - - /// Create a OneOrMany object with a single item of any type. - pub fn one(item: T) -> Self { - OneOrMany { - first: item, - rest: vec![], - } - } - - /// Create a OneOrMany object with a vector of items of any type. - pub fn many(items: Vec) -> Self { - let mut iter = items.into_iter(); - OneOrMany { - first: match iter.next() { - Some(item) => item, - None => panic!("Cannot create OneOrMany with an empty vector."), - }, - rest: iter.collect(), - } - } - - /// Use the Iterator trait on OneOrMany - pub fn iter(&self) -> OneOrManyIterator { - OneOrManyIterator { - one_or_many: self, - index: 0, - } - } -} - -/// Implement Iterator for OneOrMany. -/// Iterates over all items in both `first` and `rest`. 
-/// Borrows the OneOrMany object that is being iterator over. -pub struct OneOrManyIterator<'a, T> { - one_or_many: &'a OneOrMany, - index: usize, -} - -impl<'a, T> Iterator for OneOrManyIterator<'a, T> { - type Item = &'a T; - - fn next(&mut self) -> Option { - let mut item = None; - if self.index == 0 { - item = Some(&self.one_or_many.first) - } else if self.index - 1 < self.one_or_many.rest.len() { - item = Some(&self.one_or_many.rest[self.index - 1]); - }; - - self.index += 1; - item - } -} - -/// Implement IntoIterator for OneOrMany. -/// Iterates over all items in both `first` and `rest`. -/// Takes ownership the OneOrMany object that is being iterator over. -pub struct OneOrManyIntoIterator { - one_or_many: OneOrMany, - index: usize, -} - -impl IntoIterator for OneOrMany { - type Item = T; - type IntoIter = OneOrManyIntoIterator; - - fn into_iter(self) -> OneOrManyIntoIterator { - OneOrManyIntoIterator { - one_or_many: self, - index: 0, - } - } -} - -impl Iterator for OneOrManyIntoIterator { - type Item = T; - - fn next(&mut self) -> Option { - let mut item = None; - if self.index == 0 { - item = Some(self.one_or_many.first()) - } else if self.index - 1 < self.one_or_many.rest.len() { - item = Some(self.one_or_many.rest[self.index - 1].clone()); - }; - - self.index += 1; - item - } -} - -/// Merge a list of OneOrMany items into a single OneOrMany item. 
-impl From>> for OneOrMany { - fn from(value: Vec>) -> Self { - let items = value - .into_iter() - .flat_map(|one_or_many| one_or_many.into_iter()) - .collect::>(); - - OneOrMany::many(items) - } -} - -////////////////////////////////////////////////////// -/// Implementations of Embeddable for common types /// -////////////////////////////////////////////////////// +// ================================================================ +// Implementations of Embeddable for common types +// ================================================================ impl Embeddable for String { type Error = EmbeddableError; @@ -258,102 +133,21 @@ impl Embeddable for serde_json::Value { fn embeddable(&self) -> Result, Self::Error> { Ok(OneOrMany::one( - serde_json::to_string(self).map_err(EmbeddableError::SerdeError)?, + serde_json::to_string(self).map_err(|e| EmbeddableError(Box::new(e)))?, )) } } impl Embeddable for Vec { - type Error = T::Error; + type Error = EmbeddableError; fn embeddable(&self) -> Result, Self::Error> { let items = self .iter() .map(|item| item.embeddable()) - .collect::, _>>()?; - - Ok(OneOrMany::from(items)) - } -} - -#[cfg(test)] -mod test { - use super::OneOrMany; - - #[test] - fn test_one_or_many_iter_single() { - let one_or_many = OneOrMany::one("hello".to_string()); - - assert_eq!(one_or_many.iter().count(), 1); - - one_or_many.iter().for_each(|i| { - assert_eq!(i, "hello"); - }); - } - - #[test] - fn test_one_or_many_iter() { - let one_or_many = OneOrMany::many(vec!["hello".to_string(), "word".to_string()]); - - assert_eq!(one_or_many.iter().count(), 2); - - one_or_many.iter().enumerate().for_each(|(i, item)| { - if i == 0 { - assert_eq!(item, "hello"); - } - if i == 1 { - assert_eq!(item, "word"); - } - }); - } - - #[test] - fn test_one_or_many_into_iter_single() { - let one_or_many = OneOrMany::one("hello".to_string()); - - assert_eq!(one_or_many.clone().into_iter().count(), 1); - - one_or_many.into_iter().for_each(|i| { - assert_eq!(i, 
"hello".to_string()); - }); - } - - #[test] - fn test_one_or_many_into_iter() { - let one_or_many = OneOrMany::many(vec!["hello".to_string(), "word".to_string()]); - - assert_eq!(one_or_many.clone().into_iter().count(), 2); - - one_or_many.into_iter().enumerate().for_each(|(i, item)| { - if i == 0 { - assert_eq!(item, "hello".to_string()); - } - if i == 1 { - assert_eq!(item, "word".to_string()); - } - }); - } - - #[test] - fn test_one_or_many_merge() { - let one_or_many_1 = OneOrMany::many(vec!["hello".to_string(), "word".to_string()]); - - let one_or_many_2 = OneOrMany::one("sup".to_string()); - - let merged = OneOrMany::from(vec![one_or_many_1, one_or_many_2]); - - assert_eq!(merged.iter().count(), 3); + .collect::, _>>() + .map_err(|e| EmbeddableError(Box::new(e)))?; - merged.iter().enumerate().for_each(|(i, item)| { - if i == 0 { - assert_eq!(item, "hello"); - } - if i == 1 { - assert_eq!(item, "word"); - } - if i == 2 { - assert_eq!(item, "sup"); - } - }); + OneOrMany::merge(items).map_err(|e| EmbeddableError(Box::new(e))) } } diff --git a/rig-core/src/embeddings/embedding.rs b/rig-core/src/embeddings/embedding.rs index ff284a05..81a980ef 100644 --- a/rig-core/src/embeddings/embedding.rs +++ b/rig-core/src/embeddings/embedding.rs @@ -1,5 +1,5 @@ //! The module defines the [EmbeddingModel] trait, which represents an embedding model that can -//! generate embeddings for documents. It also provides an implementation of the [EmbeddingsBuilder] +//! generate embeddings for documents. It also provides an implementation of the [embeddings::EmbeddingsBuilder] //! struct, which allows users to build collections of document embeddings using different embedding //! models and document sources. //! 
diff --git a/rig-core/src/embeddings/mod.rs b/rig-core/src/embeddings/mod.rs index 37e720cb..a9eda7c3 100644 --- a/rig-core/src/embeddings/mod.rs +++ b/rig-core/src/embeddings/mod.rs @@ -6,3 +6,7 @@ pub mod builder; pub mod embeddable; pub mod embedding; + +pub use builder::EmbeddingsBuilder; +pub use embeddable::Embeddable; +pub use embedding::{Embedding, EmbeddingError, EmbeddingModel}; diff --git a/rig-core/src/lib.rs b/rig-core/src/lib.rs index 5c498c81..219d9be0 100644 --- a/rig-core/src/lib.rs +++ b/rig-core/src/lib.rs @@ -74,11 +74,11 @@ pub mod extractor; pub mod json_utils; pub mod providers; pub mod tool; +mod vec_utils; pub mod vector_store; // Re-export commonly used types and traits -pub use embeddings::builder::EmbeddingsBuilder; pub use embeddings::embeddable::Embeddable; -#[cfg(feature = "rig_derive")] +#[cfg(feature = "derive")] pub use rig_derive::Embeddable; diff --git a/rig-core/src/providers/cohere.rs b/rig-core/src/providers/cohere.rs index 844849d7..ae874b21 100644 --- a/rig-core/src/providers/cohere.rs +++ b/rig-core/src/providers/cohere.rs @@ -13,9 +13,9 @@ use std::collections::HashMap; use crate::{ agent::AgentBuilder, completion::{self, CompletionError}, - embeddings::{self, embedding::EmbeddingError}, + embeddings::{self, EmbeddingError, EmbeddingsBuilder}, extractor::ExtractorBuilder, - json_utils, EmbeddingsBuilder, + json_utils, }; use schemars::JsonSchema; @@ -183,7 +183,7 @@ pub struct EmbeddingModel { ndims: usize, } -impl embeddings::embedding::EmbeddingModel for EmbeddingModel { +impl embeddings::EmbeddingModel for EmbeddingModel { const MAX_DOCUMENTS: usize = 96; fn ndims(&self) -> usize { @@ -193,7 +193,7 @@ impl embeddings::embedding::EmbeddingModel for EmbeddingModel { async fn embed_documents( &self, documents: Vec, - ) -> Result, EmbeddingError> { + ) -> Result, EmbeddingError> { let response = self .client .post("/v1/embed") @@ -222,7 +222,7 @@ impl embeddings::embedding::EmbeddingModel for EmbeddingModel { .embeddings 
.into_iter() .zip(documents.into_iter()) - .map(|(embedding, document)| embeddings::embedding::Embedding { + .map(|(embedding, document)| embeddings::Embedding { document, vec: embedding, }) diff --git a/rig-core/src/providers/openai.rs b/rig-core/src/providers/openai.rs index 6b95a56a..c9ba9afa 100644 --- a/rig-core/src/providers/openai.rs +++ b/rig-core/src/providers/openai.rs @@ -11,9 +11,9 @@ use crate::{ agent::AgentBuilder, completion::{self, CompletionError, CompletionRequest}, - embeddings::{self, embedding::EmbeddingError}, + embeddings::{self, EmbeddingError, EmbeddingsBuilder}, extractor::ExtractorBuilder, - json_utils, EmbeddingsBuilder, + json_utils, }; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; @@ -232,7 +232,7 @@ pub struct EmbeddingModel { ndims: usize, } -impl embeddings::embedding::EmbeddingModel for EmbeddingModel { +impl embeddings::EmbeddingModel for EmbeddingModel { const MAX_DOCUMENTS: usize = 1024; fn ndims(&self) -> usize { @@ -242,7 +242,7 @@ impl embeddings::embedding::EmbeddingModel for EmbeddingModel { async fn embed_documents( &self, documents: Vec, - ) -> Result, EmbeddingError> { + ) -> Result, EmbeddingError> { let response = self .client .post("/v1/embeddings") @@ -268,7 +268,7 @@ impl embeddings::embedding::EmbeddingModel for EmbeddingModel { .data .into_iter() .zip(documents.into_iter()) - .map(|(embedding, document)| embeddings::embedding::Embedding { + .map(|(embedding, document)| embeddings::Embedding { document, vec: embedding.embedding, }) diff --git a/rig-core/src/vec_utils.rs b/rig-core/src/vec_utils.rs new file mode 100644 index 00000000..d10a8608 --- /dev/null +++ b/rig-core/src/vec_utils.rs @@ -0,0 +1,191 @@ +/// Struct containing either a single item or a list of items of type T. +/// If a single item is present, `first` will contain it and `rest` will be empty. +/// If multiple items are present, `first` will contain the first item and `rest` will contain the rest. 
+/// IMPORTANT: this struct cannot be created with an empty vector. +/// OneOrMany objects can only be created using OneOrMany::from() or OneOrMany::try_from(). +#[derive(PartialEq, Eq, Debug, Clone)] +pub struct OneOrMany { + /// First item in the list. + first: T, + /// Rest of the items in the list. + rest: Vec, +} + +/// Error type for when trying to create a OneOrMany object with an empty vector. +#[derive(Debug)] +pub struct EmptyListError; + +impl std::fmt::Display for EmptyListError { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "Cannot create OneOrMany with an empty vector.") + } +} +impl std::error::Error for EmptyListError {} + +impl OneOrMany { + /// Get the first item in the list. + pub fn first(&self) -> T { + self.first.clone() + } + + /// Get the rest of the items in the list (excluding the first one). + pub fn rest(&self) -> Vec { + self.rest.clone() + } + + /// Use the Iterator trait on OneOrMany + pub fn iter(&self) -> OneOrManyIterator { + OneOrManyIterator { + one_or_many: self, + index: 0, + } + } + + /// Create a OneOrMany object with a single item of any type. + pub fn one(item: T) -> Self { + OneOrMany { + first: item, + rest: vec![], + } + } + + /// Create a OneOrMany object with a vector of items of any type. + pub fn many(items: Vec) -> Result { + let mut iter = items.into_iter(); + Ok(OneOrMany { + first: match iter.next() { + Some(item) => item, + None => return Err(EmptyListError), + }, + rest: iter.collect(), + }) + } + + /// Merge a list of OneOrMany items into a single OneOrMany item. + pub fn merge(one_or_many_items: Vec>) -> Result { + let items = one_or_many_items + .into_iter() + .flat_map(|one_or_many| one_or_many.into_iter()) + .collect::>(); + + OneOrMany::many(items) + } +} + +/// Implement Iterator for OneOrMany. +/// Iterates over all items in both `first` and `rest`. +/// Borrows the OneOrMany object that is being iterator over. 
+pub struct OneOrManyIterator<'a, T> { + one_or_many: &'a OneOrMany, + index: usize, +} + +impl<'a, T> Iterator for OneOrManyIterator<'a, T> { + type Item = &'a T; + + fn next(&mut self) -> Option { + let mut item = None; + if self.index == 0 { + item = Some(&self.one_or_many.first) + } else if self.index - 1 < self.one_or_many.rest.len() { + item = Some(&self.one_or_many.rest[self.index - 1]); + }; + + self.index += 1; + item + } +} + +/// Implement IntoIterator for OneOrMany. +/// Iterates over all items in both `first` and `rest`. +/// Takes ownership the OneOrMany object that is being iterator over. +impl IntoIterator for OneOrMany { + type Item = T; + type IntoIter = std::iter::Chain, std::vec::IntoIter>; + + fn into_iter(self) -> Self::IntoIter { + std::iter::once(self.first).chain(self.rest) + } +} + +#[cfg(test)] +mod test { + use super::OneOrMany; + + #[test] + fn test_one_or_many_iter_single() { + let one_or_many = OneOrMany::one("hello".to_string()); + + assert_eq!(one_or_many.iter().count(), 1); + + one_or_many.iter().for_each(|i| { + assert_eq!(i, "hello"); + }); + } + + #[test] + fn test_one_or_many_iter() { + let one_or_many = OneOrMany::many(vec!["hello".to_string(), "word".to_string()]).unwrap(); + + assert_eq!(one_or_many.iter().count(), 2); + + one_or_many.iter().enumerate().for_each(|(i, item)| { + if i == 0 { + assert_eq!(item, "hello"); + } + if i == 1 { + assert_eq!(item, "word"); + } + }); + } + + #[test] + fn test_one_or_many_into_iter_single() { + let one_or_many = OneOrMany::one("hello".to_string()); + + assert_eq!(one_or_many.clone().into_iter().count(), 1); + + one_or_many.into_iter().for_each(|i| { + assert_eq!(i, "hello".to_string()); + }); + } + + #[test] + fn test_one_or_many_into_iter() { + let one_or_many = OneOrMany::many(vec!["hello".to_string(), "word".to_string()]).unwrap(); + + assert_eq!(one_or_many.clone().into_iter().count(), 2); + + one_or_many.into_iter().enumerate().for_each(|(i, item)| { + if i == 0 { + 
assert_eq!(item, "hello".to_string()); + } + if i == 1 { + assert_eq!(item, "word".to_string()); + } + }); + } + + #[test] + fn test_one_or_many_merge() { + let one_or_many_1 = OneOrMany::many(vec!["hello".to_string(), "word".to_string()]).unwrap(); + + let one_or_many_2 = OneOrMany::one("sup".to_string()); + + let merged = OneOrMany::merge(vec![one_or_many_1, one_or_many_2]).unwrap(); + + assert_eq!(merged.iter().count(), 3); + + merged.iter().enumerate().for_each(|(i, item)| { + if i == 0 { + assert_eq!(item, "hello"); + } + if i == 1 { + assert_eq!(item, "word"); + } + if i == 2 { + assert_eq!(item, "sup"); + } + }); + } +} diff --git a/rig-core/src/vector_store/in_memory_store.rs b/rig-core/src/vector_store/in_memory_store.rs index 9bbc85f1..bfe2bd29 100644 --- a/rig-core/src/vector_store/in_memory_store.rs +++ b/rig-core/src/vector_store/in_memory_store.rs @@ -8,7 +8,7 @@ use ordered_float::OrderedFloat; use serde::{Deserialize, Serialize}; use super::{VectorStoreError, VectorStoreIndex}; -use crate::embeddings::embedding::{Embedding, EmbeddingModel}; +use crate::embeddings::{Embedding, EmbeddingModel}; /// InMemoryVectorStore is a simple in-memory vector store that stores embeddings /// in-memory using a HashMap. 
diff --git a/rig-core/src/vector_store/mod.rs b/rig-core/src/vector_store/mod.rs index 6f112b81..396b5514 100644 --- a/rig-core/src/vector_store/mod.rs +++ b/rig-core/src/vector_store/mod.rs @@ -2,7 +2,7 @@ use futures::future::BoxFuture; use serde::Deserialize; use serde_json::Value; -use crate::embeddings::embedding::EmbeddingError; +use crate::embeddings::EmbeddingError; pub mod in_memory_store; diff --git a/rig-lancedb/examples/vector_search_local_ann.rs b/rig-lancedb/examples/vector_search_local_ann.rs index 8f3ceee6..6466c2b9 100644 --- a/rig-lancedb/examples/vector_search_local_ann.rs +++ b/rig-lancedb/examples/vector_search_local_ann.rs @@ -3,11 +3,10 @@ use std::{env, sync::Arc}; use arrow_array::RecordBatchIterator; use fixture::{as_record_batch, schema}; use lancedb::index::vector::IvfPqIndexBuilder; -use rig::vector_store::VectorStoreIndex; use rig::{ - embeddings::embedding::EmbeddingModel, + embeddings::{EmbeddingModel, EmbeddingsBuilder}, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, - EmbeddingsBuilder, + vector_store::VectorStoreIndex, }; use rig_lancedb::{LanceDbVectorStore, SearchParams}; use serde::Deserialize; diff --git a/rig-lancedb/examples/vector_search_local_enn.rs b/rig-lancedb/examples/vector_search_local_enn.rs index 90a14fe0..5932dcd0 100644 --- a/rig-lancedb/examples/vector_search_local_enn.rs +++ b/rig-lancedb/examples/vector_search_local_enn.rs @@ -3,10 +3,9 @@ use std::{env, sync::Arc}; use arrow_array::RecordBatchIterator; use fixture::{as_record_batch, schema}; use rig::{ - embeddings::embedding::EmbeddingModel, + embeddings::{EmbeddingModel, EmbeddingsBuilder}, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, vector_store::VectorStoreIndexDyn, - EmbeddingsBuilder, }; use rig_lancedb::{LanceDbVectorStore, SearchParams}; diff --git a/rig-lancedb/examples/vector_search_s3_ann.rs b/rig-lancedb/examples/vector_search_s3_ann.rs index 2d97ed80..70f0c8c5 100644 --- a/rig-lancedb/examples/vector_search_s3_ann.rs +++ 
b/rig-lancedb/examples/vector_search_s3_ann.rs @@ -4,10 +4,9 @@ use arrow_array::RecordBatchIterator; use fixture::{as_record_batch, schema}; use lancedb::{index::vector::IvfPqIndexBuilder, DistanceType}; use rig::{ - embeddings::embedding::EmbeddingModel, + embeddings::{EmbeddingModel, EmbeddingsBuilder}, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, vector_store::VectorStoreIndex, - EmbeddingsBuilder, }; use rig_lancedb::{LanceDbVectorStore, SearchParams}; use serde::Deserialize; diff --git a/rig-mongodb/examples/vector_search_mongodb.rs b/rig-mongodb/examples/vector_search_mongodb.rs index 517bfea6..caba89d8 100644 --- a/rig-mongodb/examples/vector_search_mongodb.rs +++ b/rig-mongodb/examples/vector_search_mongodb.rs @@ -3,9 +3,9 @@ use std::env; use rig::{ embeddings::builder::DocumentEmbeddings, + embeddings::EmbeddingsBuilder, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, vector_store::VectorStoreIndex, - EmbeddingsBuilder, }; use rig_mongodb::{MongoDbVectorStore, SearchParams}; diff --git a/rig-mongodb/src/lib.rs b/rig-mongodb/src/lib.rs index 17dda463..c3973092 100644 --- a/rig-mongodb/src/lib.rs +++ b/rig-mongodb/src/lib.rs @@ -2,7 +2,8 @@ use futures::StreamExt; use mongodb::bson::{self, doc}; use rig::{ - embeddings::{builder::DocumentEmbeddings, embedding::Embedding, embedding::EmbeddingModel}, + embeddings::builder::DocumentEmbeddings, + embeddings::{Embedding, EmbeddingModel}, vector_store::{VectorStoreError, VectorStoreIndex}, }; use serde::Deserialize; From c8f6646076ce27d3f7a4513ac603aba09436cd81 Mon Sep 17 00:00:00 2001 From: Garance Date: Thu, 17 Oct 2024 15:00:32 -0400 Subject: [PATCH 37/47] fix: fix tests failing --- rig-core/rig-core-derive/src/embeddable.rs | 6 +++--- rig-core/src/embeddings/embeddable.rs | 13 ++++++++++--- rig-core/src/lib.rs | 3 ++- rig-core/tests/embeddable_macro.rs | 8 +++++--- 4 files changed, 20 insertions(+), 10 deletions(-) diff --git a/rig-core/rig-core-derive/src/embeddable.rs 
b/rig-core/rig-core-derive/src/embeddable.rs index 563b8c3c..4c533d1e 100644 --- a/rig-core/rig-core-derive/src/embeddable.rs +++ b/rig-core/rig-core-derive/src/embeddable.rs @@ -48,13 +48,13 @@ pub(crate) fn expand_derive_embedding(input: &mut syn::DeriveInput) -> syn::Resu impl #impl_generics Embeddable for #name #ty_generics #where_clause { type Error = rig::embeddings::embeddable::EmbeddableError; - fn embeddable(&self) -> Result, Self::Error> { + fn embeddable(&self) -> Result, Self::Error> { #target_stream; - Ok(rig::embeddings::embeddable::OneOrMany::from( + rig::OneOrMany::merge( embed_targets.into_iter() .collect::, _>>()? - )) + ).map_err(rig::embeddings::embeddable::EmbeddableError::new) } } }; diff --git a/rig-core/src/embeddings/embeddable.rs b/rig-core/src/embeddings/embeddable.rs index e303ecda..accdf402 100644 --- a/rig-core/src/embeddings/embeddable.rs +++ b/rig-core/src/embeddings/embeddable.rs @@ -4,6 +4,7 @@ //! use std::env; //! //! use serde::{Deserialize, Serialize}; +//! use rig::OneOrMany; //! //! struct FakeDefinition { //! id: String, @@ -36,6 +37,12 @@ use crate::vec_utils::OneOrMany; #[error("{0}")] pub struct EmbeddableError(#[from] Box); +impl EmbeddableError { + pub fn new(error: E) -> Self { + EmbeddableError(Box::new(error)) + } +} + /// Trait for types that can be embedded. /// The `embeddable` method returns a `OneOrMany` which contains strings for which embeddings will be generated by the embeddings builder. /// If there is an error generating the list of strings, the method should return an error that implements `std::error::Error`. 
@@ -133,7 +140,7 @@ impl Embeddable for serde_json::Value { fn embeddable(&self) -> Result, Self::Error> { Ok(OneOrMany::one( - serde_json::to_string(self).map_err(|e| EmbeddableError(Box::new(e)))?, + serde_json::to_string(self).map_err(EmbeddableError::new)?, )) } } @@ -146,8 +153,8 @@ impl Embeddable for Vec { .iter() .map(|item| item.embeddable()) .collect::, _>>() - .map_err(|e| EmbeddableError(Box::new(e)))?; + .map_err(EmbeddableError::new)?; - OneOrMany::merge(items).map_err(|e| EmbeddableError(Box::new(e))) + OneOrMany::merge(items).map_err(EmbeddableError::new) } } diff --git a/rig-core/src/lib.rs b/rig-core/src/lib.rs index 219d9be0..6997c6ff 100644 --- a/rig-core/src/lib.rs +++ b/rig-core/src/lib.rs @@ -74,11 +74,12 @@ pub mod extractor; pub mod json_utils; pub mod providers; pub mod tool; -mod vec_utils; +pub mod vec_utils; pub mod vector_store; // Re-export commonly used types and traits pub use embeddings::embeddable::Embeddable; +pub use vec_utils::OneOrMany; #[cfg(feature = "derive")] pub use rig_derive::Embeddable; diff --git a/rig-core/tests/embeddable_macro.rs b/rig-core/tests/embeddable_macro.rs index 96e42c7f..d30c5bfa 100644 --- a/rig-core/tests/embeddable_macro.rs +++ b/rig-core/tests/embeddable_macro.rs @@ -1,10 +1,10 @@ -use rig::embeddings::embeddable::{EmbeddableError, OneOrMany}; -use rig::Embeddable; +use rig::embeddings::embeddable::EmbeddableError; +use rig::{Embeddable, OneOrMany}; use serde::Serialize; fn serialize(definition: Definition) -> Result, EmbeddableError> { Ok(OneOrMany::one( - serde_json::to_string(&definition).map_err(EmbeddableError::SerdeError)?, + serde_json::to_string(&definition).map_err(EmbeddableError::new)?, )) } @@ -143,6 +143,7 @@ fn test_multiple_embed_strings() { "35".to_string(), "40".to_string() ]) + .unwrap() ); assert_eq!(result.first(), "25".to_string()); @@ -181,5 +182,6 @@ fn test_multiple_embed_tags() { "35".to_string(), "40".to_string() ]) + .unwrap() ); } From 
40f3c1868137e5ad52d97de544c7c22a0792f349 Mon Sep 17 00:00:00 2001 From: Garance Date: Thu, 17 Oct 2024 15:21:55 -0400 Subject: [PATCH 38/47] test: add test on OneOrMany --- rig-core/rig-core-derive/src/basic.rs | 2 +- rig-core/rig-core-derive/src/embeddable.rs | 1 - rig-core/src/vec_utils.rs | 7 +++++++ 3 files changed, 8 insertions(+), 2 deletions(-) diff --git a/rig-core/rig-core-derive/src/basic.rs b/rig-core/rig-core-derive/src/basic.rs index 7ac5bb47..942e6c7e 100644 --- a/rig-core/rig-core-derive/src/basic.rs +++ b/rig-core/rig-core-derive/src/basic.rs @@ -19,7 +19,7 @@ pub(crate) fn basic_embed_fields(data_struct: &DataStruct) -> impl Iterator syn::Resu )); } - // Determine whether the Embeddable::Kind should be SingleEmbedding or ManyEmbedding quote! { let mut embed_targets = #basic_targets; embed_targets.extend(#custom_targets) diff --git a/rig-core/src/vec_utils.rs b/rig-core/src/vec_utils.rs index d10a8608..d8435685 100644 --- a/rig-core/src/vec_utils.rs +++ b/rig-core/src/vec_utils.rs @@ -188,4 +188,11 @@ mod test { } }); } + + #[test] + fn test_one_or_many_error() { + assert!( + OneOrMany::::many(vec![]).is_err() + ) + } } From 68d88b6c7cf8594d49d770a1125783592d8e80cc Mon Sep 17 00:00:00 2001 From: Garance Date: Thu, 17 Oct 2024 15:22:14 -0400 Subject: [PATCH 39/47] style: cargo fmt --- rig-core/src/vec_utils.rs | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/rig-core/src/vec_utils.rs b/rig-core/src/vec_utils.rs index d8435685..a487e71c 100644 --- a/rig-core/src/vec_utils.rs +++ b/rig-core/src/vec_utils.rs @@ -191,8 +191,6 @@ mod test { #[test] fn test_one_or_many_error() { - assert!( - OneOrMany::::many(vec![]).is_err() - ) + assert!(OneOrMany::::many(vec![]).is_err()) } } From 6ddc3c7efd2464cd666b5c26069c7923604a59dd Mon Sep 17 00:00:00 2001 From: Christophe Date: Thu, 17 Oct 2024 17:52:23 -0400 Subject: [PATCH 40/47] devops: Add cargo check for all features + doc check --- .github/workflows/ci.yaml | 19 ++++++++++++++++++- 1 
file changed, 18 insertions(+), 1 deletion(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index e00d68cd..5f8514c2 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -55,6 +55,8 @@ jobs: - name: Run clippy action uses: clechasseur/rs-clippy-check@v3 + with: + args: --all-features test: name: stable / test @@ -79,4 +81,19 @@ jobs: uses: actions-rs/cargo@v1 with: command: nextest - args: run --all-features \ No newline at end of file + args: run --all-features + + doc: + name: stable / doc + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Install Rust stable + uses: actions-rust-lang/setup-rust-toolchain@v1 + with: + components: rust-docs + + - name: Run cargo doc + run: cargo doc --no-deps --all-features \ No newline at end of file From f799f94c5ff6da76bbdb6f68af6ee0824e0a9712 Mon Sep 17 00:00:00 2001 From: Christophe Date: Thu, 17 Oct 2024 17:55:18 -0400 Subject: [PATCH 41/47] devops: Fix missing dep --- .github/workflows/ci.yaml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 5f8514c2..d8aec427 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -95,5 +95,9 @@ jobs: with: components: rust-docs + # Required to compile rig-lancedb + - name: Install Protoc + uses: arduino/setup-protoc@v3 + - name: Run cargo doc run: cargo doc --no-deps --all-features \ No newline at end of file From 3558085108b15759e53b9afc083c12f584a1b36f Mon Sep 17 00:00:00 2001 From: Christophe Date: Thu, 17 Oct 2024 17:59:39 -0400 Subject: [PATCH 42/47] devops: Make cargo doc strict --- .github/workflows/ci.yaml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index d8aec427..8ab1324e 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -100,4 +100,6 @@ jobs: uses: arduino/setup-protoc@v3 - name: Run cargo doc - run: cargo doc 
--no-deps --all-features \ No newline at end of file + run: cargo doc --no-deps --all-features + env: + RUSTDOCFLAGS: -D warnings \ No newline at end of file From dc248edd58f58046f40ea8d366d907bad48d45cd Mon Sep 17 00:00:00 2001 From: Christophe Date: Thu, 17 Oct 2024 18:02:06 -0400 Subject: [PATCH 43/47] docs: Fix docstring links --- rig-core/src/completion.rs | 4 ++-- rig-core/src/lib.rs | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/rig-core/src/completion.rs b/rig-core/src/completion.rs index 4383e561..4f868786 100644 --- a/rig-core/src/completion.rs +++ b/rig-core/src/completion.rs @@ -439,14 +439,14 @@ impl CompletionRequestBuilder { } /// Sets the max tokens for the completion request. - /// Only required for: [ Anthropic ] + /// Note: This is required if using Anthropic pub fn max_tokens(mut self, max_tokens: u64) -> Self { self.max_tokens = Some(max_tokens); self } /// Sets the max tokens for the completion request. - /// Only required for: [ Anthropic ] + /// Note: This is required if using Anthropic pub fn max_tokens_opt(mut self, max_tokens: Option) -> Self { self.max_tokens = max_tokens; self diff --git a/rig-core/src/lib.rs b/rig-core/src/lib.rs index 86c25209..79d47079 100644 --- a/rig-core/src/lib.rs +++ b/rig-core/src/lib.rs @@ -54,7 +54,7 @@ //! Rig provides a common interface for working with vector stores and indexes. Specifically, the library //! provides the [VectorStore](crate::vector_store::VectorStore) and [VectorStoreIndex](crate::vector_store::VectorStoreIndex) //! traits, which can be implemented to define vector stores and indices respectively. -//! Those can then be used as the knowledgebase for a [RagAgent](crate::rag::RagAgent), or +//! Those can then be used as the knowledgebase for a RAG enabled [Agent](crate::agent::Agent), or //! as a source of context documents in a custom architecture that use multiple LLMs or agents. //! //! 
# Integrations From 4bc7d07e41de81c562472db848e613d2fdb0292d Mon Sep 17 00:00:00 2001 From: Garance Date: Thu, 17 Oct 2024 18:02:11 -0400 Subject: [PATCH 44/47] docs&fix: fix doc strings, implement iter_mut for OneOrMany --- rig-core/src/embeddings/embeddable.rs | 2 +- rig-core/src/embeddings/embedding.rs | 2 +- rig-core/src/lib.rs | 8 +- rig-core/src/{vec_utils.rs => one_or_many.rs} | 143 ++++++++++++++---- 4 files changed, 118 insertions(+), 37 deletions(-) rename rig-core/src/{vec_utils.rs => one_or_many.rs} (59%) diff --git a/rig-core/src/embeddings/embeddable.rs b/rig-core/src/embeddings/embeddable.rs index accdf402..f5a69fd6 100644 --- a/rig-core/src/embeddings/embeddable.rs +++ b/rig-core/src/embeddings/embeddable.rs @@ -29,7 +29,7 @@ //! } //! ``` -use crate::vec_utils::OneOrMany; +use crate::one_or_many::OneOrMany; /// Error type used for when the `embeddable` method fails. /// Used by default implementations of `Embeddable` for common types. diff --git a/rig-core/src/embeddings/embedding.rs b/rig-core/src/embeddings/embedding.rs index 81a980ef..735284a8 100644 --- a/rig-core/src/embeddings/embedding.rs +++ b/rig-core/src/embeddings/embedding.rs @@ -1,5 +1,5 @@ //! The module defines the [EmbeddingModel] trait, which represents an embedding model that can -//! generate embeddings for documents. It also provides an implementation of the [embeddings::EmbeddingsBuilder] +//! generate embeddings for documents. It also provides an implementation of the [crate::embeddings::EmbeddingsBuilder] //! struct, which allows users to build collections of document embeddings using different embedding //! models and document sources. //! diff --git a/rig-core/src/lib.rs b/rig-core/src/lib.rs index 6997c6ff..a4850791 100644 --- a/rig-core/src/lib.rs +++ b/rig-core/src/lib.rs @@ -52,8 +52,8 @@ //! //! ## Vector stores and indexes //! Rig provides a common interface for working with vector stores and indexes. Specifically, the library -//! 
provides the [VectorStore](crate::vector_store::VectorStore) and [VectorStoreIndex](crate::vector_store::VectorStoreIndex) -//! traits, which can be implemented to define vector stores and indices respectively. +//! provides the [VectorStoreIndex](crate::vector_store::VectorStoreIndex) +//! trait, which can be implemented to define vector stores and indices. //! Those can then be used as the knowledgebase for a [RagAgent](crate::rag::RagAgent), or //! as a source of context documents in a custom architecture that use multiple LLMs or agents. //! @@ -72,14 +72,14 @@ pub mod completion; pub mod embeddings; pub mod extractor; pub mod json_utils; +pub mod one_or_many; pub mod providers; pub mod tool; -pub mod vec_utils; pub mod vector_store; // Re-export commonly used types and traits pub use embeddings::embeddable::Embeddable; -pub use vec_utils::OneOrMany; +pub use one_or_many::OneOrMany; #[cfg(feature = "derive")] pub use rig_derive::Embeddable; diff --git a/rig-core/src/vec_utils.rs b/rig-core/src/one_or_many.rs similarity index 59% rename from rig-core/src/vec_utils.rs rename to rig-core/src/one_or_many.rs index a487e71c..23ece94f 100644 --- a/rig-core/src/vec_utils.rs +++ b/rig-core/src/one_or_many.rs @@ -33,14 +33,6 @@ impl OneOrMany { self.rest.clone() } - /// Use the Iterator trait on OneOrMany - pub fn iter(&self) -> OneOrManyIterator { - OneOrManyIterator { - one_or_many: self, - index: 0, - } - } - /// Create a OneOrMany object with a single item of any type. pub fn one(item: T) -> Self { OneOrMany { @@ -70,41 +62,102 @@ impl OneOrMany { OneOrMany::many(items) } + + pub fn iter(&self) -> Iter { + Iter { + first: Some(&self.first), + rest: self.rest.iter(), + } + } + + pub fn iter_mut(&mut self) -> IterMut<'_, T> { + IterMut { + first: Some(&mut self.first), + rest: self.rest.iter_mut(), + } + } } -/// Implement Iterator for OneOrMany. -/// Iterates over all items in both `first` and `rest`. -/// Borrows the OneOrMany object that is being iterator over. 
-pub struct OneOrManyIterator<'a, T> { - one_or_many: &'a OneOrMany, - index: usize, +// ================================================================ +// Implementations of Iterator for OneOrMany +// - OneOrMany::iter() -> iterate over references of T objects +// - OneOrMany::into_iter() -> iterate over owned T objects +// - OneOrMany::iter_mut() -> iterate over mutable references of T objects +// ================================================================ + +/// Struct returned by call to `OneOrMany::iter()`. +pub struct Iter<'a, T> { + // References. + first: Option<&'a T>, + rest: std::slice::Iter<'a, T>, } -impl<'a, T> Iterator for OneOrManyIterator<'a, T> { +/// Implement `Iterator` for `Iter`. +/// The Item type of the `Iterator` trait is a reference of `T`. +impl<'a, T> Iterator for Iter<'a, T> { type Item = &'a T; fn next(&mut self) -> Option { - let mut item = None; - if self.index == 0 { - item = Some(&self.one_or_many.first) - } else if self.index - 1 < self.one_or_many.rest.len() { - item = Some(&self.one_or_many.rest[self.index - 1]); - }; - - self.index += 1; - item + if let Some(first) = self.first.take() { + Some(first) + } else { + self.rest.next() + } } } -/// Implement IntoIterator for OneOrMany. -/// Iterates over all items in both `first` and `rest`. -/// Takes ownership the OneOrMany object that is being iterator over. +/// Struct returned by call to `OneOrMany::into_iter()`. +pub struct IntoIter { + // Owned. + first: Option, + rest: std::vec::IntoIter, +} + +/// Implement `IntoIterator` for `OneOrMany`. impl IntoIterator for OneOrMany { type Item = T; - type IntoIter = std::iter::Chain, std::vec::IntoIter>; + type IntoIter = IntoIter; fn into_iter(self) -> Self::IntoIter { - std::iter::once(self.first).chain(self.rest) + IntoIter { + first: Some(self.first), + rest: self.rest.into_iter(), + } + } +} + +/// Implement `Iterator` for `IntoIter`. +/// The Item type of the `Iterator` trait is an owned `T`. 
+impl Iterator for IntoIter { + type Item = T; + + fn next(&mut self) -> Option { + if let Some(first) = self.first.take() { + Some(first) + } else { + self.rest.next() + } + } +} + +/// Struct returned by call to `OneOrMany::iter_mut()`. +pub struct IterMut<'a, T> { + // Mutable references. + first: Option<&'a mut T>, + rest: std::slice::IterMut<'a, T>, +} + +// Implement `Iterator` for `IterMut`. +// The Item type of the `Iterator` trait is a mutable reference of `T`. +impl<'a, T> Iterator for IterMut<'a, T> { + type Item = &'a mut T; + + fn next(&mut self) -> Option { + if let Some(first) = self.first.take() { + Some(first) + } else { + self.rest.next() + } } } @@ -113,7 +166,7 @@ mod test { use super::OneOrMany; #[test] - fn test_one_or_many_iter_single() { + fn test_single() { let one_or_many = OneOrMany::one("hello".to_string()); assert_eq!(one_or_many.iter().count(), 1); @@ -124,7 +177,7 @@ mod test { } #[test] - fn test_one_or_many_iter() { + fn test() { let one_or_many = OneOrMany::many(vec!["hello".to_string(), "word".to_string()]).unwrap(); assert_eq!(one_or_many.iter().count(), 2); @@ -189,6 +242,34 @@ mod test { }); } + #[test] + fn test_mut_single() { + let mut one_or_many = OneOrMany::one("hello".to_string()); + + assert_eq!(one_or_many.iter_mut().count(), 1); + + one_or_many.iter_mut().for_each(|i| { + assert_eq!(i, "hello"); + }); + } + + #[test] + fn test_mut() { + let mut one_or_many = + OneOrMany::many(vec!["hello".to_string(), "word".to_string()]).unwrap(); + + assert_eq!(one_or_many.iter_mut().count(), 2); + + one_or_many.iter_mut().enumerate().for_each(|(i, item)| { + if i == 0 { + assert_eq!(item, "hello"); + } + if i == 1 { + assert_eq!(item, "word"); + } + }); + } + #[test] fn test_one_or_many_error() { assert!(OneOrMany::::many(vec![]).is_err()) From 4d2ffdb0f8a8c6e425a99028c62816f7bbd5ff3f Mon Sep 17 00:00:00 2001 From: Garance Date: Fri, 18 Oct 2024 09:57:17 -0400 Subject: [PATCH 45/47] fix: update borrow and owning of macro --- 
rig-core/rig-core-derive/src/basic.rs | 22 +++++++++------------- rig-core/rig-core-derive/src/custom.rs | 10 ++++------ rig-core/rig-core-derive/src/embeddable.rs | 4 ++-- rig-core/src/one_or_many.rs | 3 ++- rig-core/tests/embeddable_macro.rs | 19 ------------------- 5 files changed, 17 insertions(+), 41 deletions(-) diff --git a/rig-core/rig-core-derive/src/basic.rs b/rig-core/rig-core-derive/src/basic.rs index 942e6c7e..86bb13ad 100644 --- a/rig-core/rig-core-derive/src/basic.rs +++ b/rig-core/rig-core-derive/src/basic.rs @@ -3,19 +3,15 @@ use syn::{parse_quote, Attribute, DataStruct, Meta}; use crate::EMBED; /// Finds and returns fields with simple #[embed] attribute tags only. -pub(crate) fn basic_embed_fields(data_struct: &DataStruct) -> impl Iterator { - data_struct.fields.clone().into_iter().filter(|field| { - field - .attrs - .clone() - .into_iter() - .any(|attribute| match attribute { - Attribute { - meta: Meta::Path(path), - .. - } => path.is_ident(EMBED), - _ => false, - }) +pub(crate) fn basic_embed_fields(data_struct: &DataStruct) -> impl Iterator { + data_struct.fields.iter().filter(|field| { + field.attrs.iter().any(|attribute| match attribute { + Attribute { + meta: Meta::Path(path), + .. + } => path.is_ident(EMBED), + _ => false, + }) }) } diff --git a/rig-core/rig-core-derive/src/custom.rs b/rig-core/rig-core-derive/src/custom.rs index de754372..f29ed20e 100644 --- a/rig-core/rig-core-derive/src/custom.rs +++ b/rig-core/rig-core-derive/src/custom.rs @@ -9,19 +9,17 @@ const EMBED_WITH: &str = "embed_with"; /// Also returns the "..." part of the tag (ie. the custom function). 
pub(crate) fn custom_embed_fields( data_struct: &syn::DataStruct, -) -> syn::Result> { +) -> syn::Result> { data_struct .fields - .clone() - .into_iter() + .iter() .filter_map(|field| { field .attrs - .clone() - .into_iter() + .iter() .filter_map(|attribute| match attribute.is_custom() { Ok(true) => match attribute.expand_tag() { - Ok(path) => Some(Ok((field.clone(), path))), + Ok(path) => Some(Ok((field, path))), Err(e) => Some(Err(e)), }, Ok(false) => None, diff --git a/rig-core/rig-core-derive/src/embeddable.rs b/rig-core/rig-core-derive/src/embeddable.rs index 13299ba6..8336c9e0 100644 --- a/rig-core/rig-core-derive/src/embeddable.rs +++ b/rig-core/rig-core-derive/src/embeddable.rs @@ -76,7 +76,7 @@ impl StructParser for DataStruct { .map(|field| { add_struct_bounds(generics, &field.ty); - let field_name = field.ident; + let field_name = &field.ident; quote! { self.#field_name @@ -106,7 +106,7 @@ impl StructParser for DataStruct { // Iterate over every field tagged with #[embed(embed_with = "...")] .into_iter() .map(|(field, custom_func_path)| { - let field_name = field.ident; + let field_name = &field.ident; quote! 
{ #custom_func_path(self.#field_name.clone()) diff --git a/rig-core/src/one_or_many.rs b/rig-core/src/one_or_many.rs index 23ece94f..08165873 100644 --- a/rig-core/src/one_or_many.rs +++ b/rig-core/src/one_or_many.rs @@ -262,7 +262,8 @@ mod test { one_or_many.iter_mut().enumerate().for_each(|(i, item)| { if i == 0 { - assert_eq!(item, "hello"); + item.push_str(" world"); + assert_eq!(item, "hello world"); } if i == 1 { assert_eq!(item, "word"); diff --git a/rig-core/tests/embeddable_macro.rs b/rig-core/tests/embeddable_macro.rs index d30c5bfa..d3f93966 100644 --- a/rig-core/tests/embeddable_macro.rs +++ b/rig-core/tests/embeddable_macro.rs @@ -35,11 +35,6 @@ fn test_custom_embed() { }, }; - println!( - "FakeDefinition: {}, {}", - fake_definition.id, fake_definition.word - ); - assert_eq!( fake_definition.embeddable().unwrap(), OneOrMany::one( @@ -70,11 +65,6 @@ fn test_custom_and_basic_embed() { }, }; - println!( - "FakeDefinition: {}, {}", - fake_definition.id, fake_definition.word - ); - assert_eq!( fake_definition.embeddable().unwrap().first(), "house".to_string() @@ -104,11 +94,6 @@ fn test_single_embed() { definition: definition.clone(), }; - println!( - "FakeDefinition3: {}, {}", - fake_definition.id, fake_definition.word - ); - assert_eq!( fake_definition.embeddable().unwrap(), OneOrMany::one(definition) @@ -131,8 +116,6 @@ fn test_multiple_embed_strings() { employee_ages: vec![25, 30, 35, 40], }; - println!("Company: {}, {}", company.id, company.company); - let result = company.embeddable().unwrap(); assert_eq!( @@ -171,8 +154,6 @@ fn test_multiple_embed_tags() { employee_ages: vec![25, 30, 35, 40], }; - println!("Company2: {}", company.id); - assert_eq!( company.embeddable().unwrap(), OneOrMany::many(vec![ From 6f0422567d1d206e1cf50f8649f68043dd720d73 Mon Sep 17 00:00:00 2001 From: Garance Date: Fri, 18 Oct 2024 10:24:04 -0400 Subject: [PATCH 46/47] clippy: add back print statements --- rig-core/tests/embeddable_macro.rs | 18 ++++++++++++++++++ 1 file 
changed, 18 insertions(+) diff --git a/rig-core/tests/embeddable_macro.rs b/rig-core/tests/embeddable_macro.rs index d3f93966..cbc76c80 100644 --- a/rig-core/tests/embeddable_macro.rs +++ b/rig-core/tests/embeddable_macro.rs @@ -35,6 +35,11 @@ fn test_custom_embed() { }, }; + println!( + "FakeDefinition: {}, {}", + fake_definition.id, fake_definition.word + ); + assert_eq!( fake_definition.embeddable().unwrap(), OneOrMany::one( @@ -65,6 +70,11 @@ fn test_custom_and_basic_embed() { }, }; + println!( + "FakeDefinition: {}, {}", + fake_definition.id, fake_definition.word + ); + assert_eq!( fake_definition.embeddable().unwrap().first(), "house".to_string() @@ -93,6 +103,10 @@ fn test_single_embed() { word: "house".to_string(), definition: definition.clone(), }; + println!( + "FakeDefinition3: {}, {}", + fake_definition.id, fake_definition.word + ); assert_eq!( fake_definition.embeddable().unwrap(), @@ -116,6 +130,8 @@ fn test_multiple_embed_strings() { employee_ages: vec![25, 30, 35, 40], }; + println!("Company: {}, {}", company.id, company.company); + let result = company.embeddable().unwrap(); assert_eq!( @@ -154,6 +170,8 @@ fn test_multiple_embed_tags() { employee_ages: vec![25, 30, 35, 40], }; + println!("Company: {}", company.id); + assert_eq!( company.embeddable().unwrap(), OneOrMany::many(vec![ From 485ad3bbe55af7d2dcf0b0256aa9b3e6dfb7d47b Mon Sep 17 00:00:00 2001 From: Garance Date: Fri, 18 Oct 2024 11:49:19 -0400 Subject: [PATCH 47/47] refactor: use thiserror for OneOtMany::EmptyListError --- rig-core/src/one_or_many.rs | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/rig-core/src/one_or_many.rs b/rig-core/src/one_or_many.rs index 08165873..87435537 100644 --- a/rig-core/src/one_or_many.rs +++ b/rig-core/src/one_or_many.rs @@ -12,16 +12,10 @@ pub struct OneOrMany { } /// Error type for when trying to create a OneOrMany object with an empty vector. 
-#[derive(Debug)] +#[derive(Debug, thiserror::Error)] +#[error("Cannot create OneOrMany with an empty vector.")] pub struct EmptyListError; -impl std::fmt::Display for EmptyListError { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - write!(f, "Cannot create OneOrMany with an empty vector.") - } -} -impl std::error::Error for EmptyListError {} - impl OneOrMany { /// Get the first item in the list. pub fn first(&self) -> T {