Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat(embeddings): Embeddable derive macro #59

Merged
merged 52 commits into from
Oct 18, 2024
Merged
Show file tree
Hide file tree
Changes from 38 commits
Commits
Show all changes
52 commits
Select commit Hold shift + click to select a range
b7d146f
feat: setup derive macro
marieaurore123 Oct 2, 2024
5904734
test: test out writing embeddable macro
marieaurore123 Oct 3, 2024
ee9b5c3
test: continue testing custom macro implementation
marieaurore123 Oct 4, 2024
c8c1e9c
feat: macro generate trait bounds
marieaurore123 Oct 6, 2024
dff0aeb
refactor: split up macro into multiple files
marieaurore123 Oct 7, 2024
0d10011
refactor: move macro derive crate inside rig-core
marieaurore123 Oct 7, 2024
982a17b
Merge branch 'main' into feat(embeddings)/derive-macro
marieaurore123 Oct 8, 2024
79754aa
feat: replace embedding logic with new embeddable trait and macro
marieaurore123 Oct 8, 2024
3438fab
refactor: refactor rag examples, delete document embedding struct
marieaurore123 Oct 8, 2024
597e6c3
feat: remove document embedding from in memory store
marieaurore123 Oct 9, 2024
6c7ab8d
revert: revert vector store to main
marieaurore123 Oct 9, 2024
d7d2c19
Merge branch 'refactor(vector-store)/in-memeory-vector-store' into fe…
marieaurore123 Oct 9, 2024
bb712e3
docs: update emebddings builder docstrings
marieaurore123 Oct 9, 2024
efa2b65
refactor: derive macro
marieaurore123 Oct 10, 2024
8325164
feat: add error type on embeddable trait
marieaurore123 Oct 11, 2024
de022c4
refactor: move embeddings to its own module and seperate embeddable
marieaurore123 Oct 11, 2024
220d9fc
refactor: split up macro into more files, fix all imports
marieaurore123 Oct 11, 2024
f5e60f5
Merge branch 'feat/embeddings-overhaul' into feat(embeddings)/derive-…
marieaurore123 Oct 15, 2024
5a8c361
feat: handle tools with embeddingsbuilder
marieaurore123 Oct 15, 2024
dc89e54
bug(macro): fix error when embed tags missing
marieaurore123 Oct 15, 2024
ae66d08
style: cargo fmt
marieaurore123 Oct 15, 2024
4305952
fix(tests): clippy
marieaurore123 Oct 15, 2024
24e3b98
docs&revert: revert embeddable trait error type, add docstrings
marieaurore123 Oct 15, 2024
a7dbf6c
style: cargo clippy
marieaurore123 Oct 15, 2024
886ebcb
clippy(lancedb): fix unused function error
marieaurore123 Oct 15, 2024
79dea45
fix(test): remove useless assert false statement
marieaurore123 Oct 15, 2024
6362344
cleanup: split up branch into 2 branches for readability
marieaurore123 Oct 16, 2024
b5e1bf3
cleanup: revert certain changes during branch split
marieaurore123 Oct 16, 2024
7caf134
docs: revert doc string
marieaurore123 Oct 16, 2024
fb979ec
refactor: use OneOrMany in Embbedable trait, make derive macro crate …
marieaurore123 Oct 16, 2024
690027c
tests: add some more tests
marieaurore123 Oct 16, 2024
cca6059
clippy: cargo clippy
marieaurore123 Oct 17, 2024
f785b8c
docs: add docstring to oneormany
marieaurore123 Oct 17, 2024
0e2ade9
fix(macro): update error handling
marieaurore123 Oct 17, 2024
a98769c
refactor: reexport EmbeddingsBuilder in rig and update imports
marieaurore123 Oct 17, 2024
067894c
feat: implement IntoIterator and Iterator for OneOrMany
marieaurore123 Oct 17, 2024
32bcc61
refactor: rename from methods
marieaurore123 Oct 17, 2024
564bef4
tests: fix failing tests
marieaurore123 Oct 17, 2024
04f1f3e
refactor&fix: make PR review changes
marieaurore123 Oct 17, 2024
c8f6646
fix: fix tests failing
marieaurore123 Oct 17, 2024
40f3c18
test: add test on OneOrMany
marieaurore123 Oct 17, 2024
68d88b6
style: cargo fmt
marieaurore123 Oct 17, 2024
6ddc3c7
devops: Add cargo check for all features + doc check
cvauclair Oct 17, 2024
f799f94
devops: Fix missing dep
cvauclair Oct 17, 2024
3558085
devops: Make cargo doc strict
cvauclair Oct 17, 2024
dc248ed
docs: Fix docstring links
cvauclair Oct 17, 2024
4bc7d07
docs&fix: fix doc strings, implement iter_mut for OneOrMany
marieaurore123 Oct 17, 2024
093c434
Merge pull request #63 from 0xPlaygrounds/devops/ci-fix
cvauclair Oct 17, 2024
4d2ffdb
fix: update borrow and owning of macro
marieaurore123 Oct 18, 2024
6f04225
clippy: add back print statements
marieaurore123 Oct 18, 2024
bdd98e5
Merge branch 'main' into feat(embeddings)/derive-macro
marieaurore123 Oct 18, 2024
485ad3b
refactor: use thiserror for OneOtMany::EmptyListError
marieaurore123 Oct 18, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
769 changes: 405 additions & 364 deletions Cargo.lock

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
[workspace]
resolver = "2"
members = [
"rig-core", "rig-lancedb",
"rig-core", "rig-core/rig-core-derive",
"rig-mongodb",
"rig-lancedb"
]
8 changes: 8 additions & 0 deletions rig-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,16 @@ futures = "0.3.29"
ordered-float = "4.2.0"
schemars = "0.8.16"
thiserror = "1.0.61"
rig-derive = { path = "./rig-core-derive", optional = true }

[dev-dependencies]
anyhow = "1.0.75"
tokio = { version = "1.34.0", features = ["full"] }
tracing-subscriber = "0.3.18"

[features]
rig_derive = ["dep:rig-derive"]
marieaurore123 marked this conversation as resolved.
Show resolved Hide resolved

[[test]]
name = "embeddable_macro"
required-features = ["rig_derive"]
3 changes: 2 additions & 1 deletion rig-core/examples/calculator_chatbot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@ use anyhow::Result;
use rig::{
cli_chatbot::cli_chatbot,
completion::ToolDefinition,
embeddings::{DocumentEmbeddings, EmbeddingsBuilder},
embeddings::builder::DocumentEmbeddings,
providers::openai::{Client, TEXT_EMBEDDING_ADA_002},
tool::{Tool, ToolEmbedding, ToolSet},
vector_store::in_memory_store::InMemoryVectorStore,
EmbeddingsBuilder,
};
use serde::{Deserialize, Serialize};
use serde_json::json;
Expand Down
3 changes: 2 additions & 1 deletion rig-core/examples/rag.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@ use std::env;

use rig::{
completion::Prompt,
embeddings::{DocumentEmbeddings, EmbeddingsBuilder},
embeddings::builder::DocumentEmbeddings,
providers::openai::{Client, TEXT_EMBEDDING_ADA_002},
vector_store::in_memory_store::InMemoryVectorStore,
EmbeddingsBuilder,
};

#[tokio::main]
Expand Down
3 changes: 2 additions & 1 deletion rig-core/examples/rag_dynamic_tools.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
use anyhow::Result;
use rig::{
completion::{Prompt, ToolDefinition},
embeddings::{DocumentEmbeddings, EmbeddingsBuilder},
embeddings::builder::DocumentEmbeddings,
providers::openai::{Client, TEXT_EMBEDDING_ADA_002},
tool::{Tool, ToolEmbedding, ToolSet},
vector_store::in_memory_store::InMemoryVectorStore,
EmbeddingsBuilder,
};
use serde::{Deserialize, Serialize};
use serde_json::json;
Expand Down
3 changes: 2 additions & 1 deletion rig-core/examples/vector_search.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
use std::env;

use rig::{
embeddings::{DocumentEmbeddings, EmbeddingsBuilder},
embeddings::builder::DocumentEmbeddings,
providers::openai::{Client, TEXT_EMBEDDING_ADA_002},
vector_store::{in_memory_store::InMemoryVectorStore, VectorStoreIndex},
EmbeddingsBuilder,
};

#[tokio::main]
Expand Down
3 changes: 2 additions & 1 deletion rig-core/examples/vector_search_cohere.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
use std::env;

use rig::{
embeddings::{DocumentEmbeddings, EmbeddingsBuilder},
embeddings::builder::DocumentEmbeddings,
providers::cohere::{Client, EMBED_ENGLISH_V3},
vector_store::{in_memory_store::InMemoryVectorStore, VectorStoreIndex},
EmbeddingsBuilder,
};

#[tokio::main]
Expand Down
13 changes: 13 additions & 0 deletions rig-core/rig-core-derive/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
[package]
name = "rig-derive"
version = "0.1.0"
edition = "2021"

[dependencies]
indoc = "2.0.5"
proc-macro2 = { version = "1.0.87", features = ["proc-macro"] }
quote = "1.0.37"
syn = { version = "2.0.79", features = ["full"]}

[lib]
proc-macro = true
29 changes: 29 additions & 0 deletions rig-core/rig-core-derive/src/basic.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
use syn::{parse_quote, Attribute, DataStruct, Meta};

use crate::EMBED;

/// Finds and returns fields with simple #[embed] attribute tags only.
pub(crate) fn basic_embed_fields(data_struct: &DataStruct) -> impl Iterator<Item = syn::Field> {
data_struct.fields.clone().into_iter().filter(|field| {
marieaurore123 marked this conversation as resolved.
Show resolved Hide resolved
field
.attrs
.clone()
.into_iter()
marieaurore123 marked this conversation as resolved.
Show resolved Hide resolved
.any(|attribute| match attribute {
Attribute {
meta: Meta::Path(path),
..
} => path.is_ident(EMBED),
_ => false,
})
})
}

// Adds bounds to where clause that force all fields tagged with #[embed] to implement the Embeddable trait.
pub(crate) fn add_struct_bounds(generics: &mut syn::Generics, field_type: &syn::Type) {
let where_clause = generics.make_where_clause();

where_clause.predicates.push(parse_quote! {
#field_type: Embeddable
});
}
125 changes: 125 additions & 0 deletions rig-core/rig-core-derive/src/custom.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
use quote::ToTokens;
use syn::{meta::ParseNestedMeta, ExprPath};

use crate::EMBED;

const EMBED_WITH: &str = "embed_with";

/// Finds and returns fields with #[embed(embed_with = "...")] attribute tags only.
/// Also returns the "..." part of the tag (ie. the custom function).
pub(crate) fn custom_embed_fields(
data_struct: &syn::DataStruct,
) -> syn::Result<Vec<(syn::Field, syn::ExprPath)>> {
data_struct
.fields
.clone()
.into_iter()
marieaurore123 marked this conversation as resolved.
Show resolved Hide resolved
.filter_map(|field| {
field
.attrs
.clone()
.into_iter()
marieaurore123 marked this conversation as resolved.
Show resolved Hide resolved
.filter_map(|attribute| match attribute.is_custom() {
Ok(true) => match attribute.expand_tag() {
Ok(path) => Some(Ok((field.clone(), path))),
Err(e) => Some(Err(e)),
},
Ok(false) => None,
Err(e) => Some(Err(e)),
})
.next()
})
.collect::<Result<Vec<_>, _>>()
}

trait CustomAttributeParser {
// Determine if field is tagged with an #[embed(embed_with = "...")] attribute.
fn is_custom(&self) -> syn::Result<bool>;

// Get the "..." part of the #[embed(embed_with = "...")] attribute.
// Ex: If attribute is tagged with #[embed(embed_with = "my_embed")], returns "my_embed".
fn expand_tag(&self) -> syn::Result<syn::ExprPath>;
}

impl CustomAttributeParser for syn::Attribute {
fn is_custom(&self) -> syn::Result<bool> {
// Check that the attribute is a list.
match &self.meta {
syn::Meta::List(meta) => {
if meta.tokens.is_empty() {
return Ok(false);
}
}
_ => return Ok(false),
};

// Check the first attribute tag (the first "embed")
if !self.path().is_ident(EMBED) {
return Ok(false);
}

self.parse_nested_meta(|meta| {
// Parse the meta attribute as an expression. Need this to compile.
meta.value()?.parse::<syn::Expr>()?;

if meta.path.is_ident(EMBED_WITH) {
Ok(())
} else {
let path = meta.path.to_token_stream().to_string().replace(' ', "");
Err(syn::Error::new_spanned(
meta.path,
format_args!("unknown embedding field attribute `{}`", path),
))
}
})?;

Ok(true)
}

fn expand_tag(&self) -> syn::Result<syn::ExprPath> {
fn function_path(meta: &ParseNestedMeta<'_>) -> syn::Result<ExprPath> {
// #[embed(embed_with = "...")]
let expr = meta.value()?.parse::<syn::Expr>().unwrap();
let mut value = &expr;
while let syn::Expr::Group(e) = value {
value = &e.expr;
}
let string = if let syn::Expr::Lit(syn::ExprLit {
lit: syn::Lit::Str(lit_str),
..
}) = value
{
let suffix = lit_str.suffix();
if !suffix.is_empty() {
return Err(syn::Error::new_spanned(
lit_str,
format!("unexpected suffix `{}` on string literal", suffix),
));
}
lit_str.clone()
} else {
return Err(syn::Error::new_spanned(
value,
format!(
"expected {} attribute to be a string: `{} = \"...\"`",
EMBED_WITH, EMBED_WITH
),
));
};

string.parse()
}

let mut custom_func_path = None;

self.parse_nested_meta(|meta| match function_path(&meta) {
Ok(path) => {
custom_func_path = Some(path);
Ok(())
}
Err(e) => Err(e),
})?;

Ok(custom_func_path.unwrap())
}
}
134 changes: 134 additions & 0 deletions rig-core/rig-core-derive/src/embeddable.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
use proc_macro2::TokenStream;
use quote::quote;
use syn::DataStruct;

use crate::{
basic::{add_struct_bounds, basic_embed_fields},
custom::custom_embed_fields,
};

pub(crate) fn expand_derive_embedding(input: &mut syn::DeriveInput) -> syn::Result<TokenStream> {
let name = &input.ident;
let data = &input.data;
let generics = &mut input.generics;

let target_stream = match data {
syn::Data::Struct(data_struct) => {
let (basic_targets, basic_target_size) = data_struct.basic(generics);
let (custom_targets, custom_target_size) = data_struct.custom()?;

// If there are no fields tagged with #[embed] or #[embed(embed_with = "...")], return an empty TokenStream.
// ie. do not implement Embeddable trait for the struct.
if basic_target_size + custom_target_size == 0 {
return Err(syn::Error::new_spanned(
name,
"Add at least one field tagged with #[embed] or #[embed(embed_with = \"...\")].",
));
}

// Determine whether the Embeddable::Kind should be SingleEmbedding or ManyEmbedding
quote! {
let mut embed_targets = #basic_targets;
embed_targets.extend(#custom_targets)
}
}
_ => {
return Err(syn::Error::new_spanned(
input,
"Embeddable derive macro should only be used on structs",
))
}
};

let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();

let gen = quote! {
// Note: Embeddable trait is imported with the macro.

impl #impl_generics Embeddable for #name #ty_generics #where_clause {
type Error = rig::embeddings::embeddable::EmbeddableError;

fn embeddable(&self) -> Result<rig::embeddings::embeddable::OneOrMany<String>, Self::Error> {
#target_stream;

Ok(rig::embeddings::embeddable::OneOrMany::from(
embed_targets.into_iter()
.collect::<Result<Vec<_>, _>>()?
))
}
}
};

Ok(gen)
}

trait StructParser {
// Handles fields tagged with #[embed]
fn basic(&self, generics: &mut syn::Generics) -> (TokenStream, usize);

// Handles fields tagged with #[embed(embed_with = "...")]
fn custom(&self) -> syn::Result<(TokenStream, usize)>;
}

impl StructParser for DataStruct {
fn basic(&self, generics: &mut syn::Generics) -> (TokenStream, usize) {
let embed_targets = basic_embed_fields(self)
// Iterate over every field tagged with #[embed]
.map(|field| {
add_struct_bounds(generics, &field.ty);

let field_name = field.ident;

quote! {
self.#field_name
}
})
.collect::<Vec<_>>();

if !embed_targets.is_empty() {
(
quote! {
vec![#(#embed_targets.embeddable()),*]
},
embed_targets.len(),
)
} else {
(
quote! {
vec![]
},
0,
)
}
}

fn custom(&self) -> syn::Result<(TokenStream, usize)> {
let embed_targets = custom_embed_fields(self)?
// Iterate over every field tagged with #[embed(embed_with = "...")]
.into_iter()
.map(|(field, custom_func_path)| {
let field_name = field.ident;

quote! {
#custom_func_path(self.#field_name.clone())
}
})
.collect::<Vec<_>>();

Ok(if !embed_targets.is_empty() {
(
quote! {
vec![#(#embed_targets),*]
},
embed_targets.len(),
)
} else {
(
quote! {
vec![]
},
0,
)
})
}
}
Loading
Loading