Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: embeddings overhaul #120

Merged
merged 106 commits into from
Nov 29, 2024
Merged
Changes from 1 commit
Commits
Show all changes
106 commits
Select commit Hold shift + click to select a range
b7d146f
feat: setup derive macro
marieaurore123 Oct 2, 2024
5904734
test: test out writing embeddable macro
marieaurore123 Oct 3, 2024
ee9b5c3
test: continue testing custom macro implementation
marieaurore123 Oct 4, 2024
c8c1e9c
feat: macro generate trait bounds
marieaurore123 Oct 6, 2024
dff0aeb
refactor: split up macro into multiple files
marieaurore123 Oct 7, 2024
0d10011
refactor: move macro derive crate inside rig-core
marieaurore123 Oct 7, 2024
982a17b
Merge branch 'main' into feat(embeddings)/derive-macro
marieaurore123 Oct 8, 2024
79754aa
feat: replace embedding logic with new embeddable trait and macro
marieaurore123 Oct 8, 2024
3438fab
refactor: refactor rag examples, delete document embedding struct
marieaurore123 Oct 8, 2024
597e6c3
feat: remove document embedding from in memory store
marieaurore123 Oct 9, 2024
5407772
refactor: remove DocumentEmbeddings from in memory vector store
marieaurore123 Oct 9, 2024
c61685e
refactor(examples): combine vector store with vector store index
marieaurore123 Oct 9, 2024
a15d493
docs: add and update docstrings
marieaurore123 Oct 9, 2024
46b4680
fix (examples): fix bugs in examples
marieaurore123 Oct 9, 2024
fe75da1
style: cargo fmt
marieaurore123 Oct 9, 2024
6c7ab8d
revert: revert vector store to main
marieaurore123 Oct 9, 2024
d7d2c19
Merge branch 'refactor(vector-store)/in-memeory-vector-store' into fe…
marieaurore123 Oct 9, 2024
bb712e3
docs: update emebddings builder docstrings
marieaurore123 Oct 9, 2024
efa2b65
refactor: derive macro
marieaurore123 Oct 10, 2024
01dc233
Merge branch 'main' into refactor(vector-store)/in-memeory-vector-store
marieaurore123 Oct 10, 2024
5684c90
tests: add unit tests on in memory store
marieaurore123 Oct 10, 2024
82d9f0c
fic(ci): asterix on pull request sto accomodate for epic branches
marieaurore123 Oct 10, 2024
bf7316b
fix(ci): double asterix
marieaurore123 Oct 10, 2024
8325164
feat: add error type on embeddable trait
marieaurore123 Oct 11, 2024
de022c4
refactor: move embeddings to its own module and seperate embeddable
marieaurore123 Oct 11, 2024
220d9fc
refactor: split up macro into more files, fix all imports
marieaurore123 Oct 11, 2024
1dff738
Merge branch 'main' into refactor(vector-store)/in-memeory-vector-store
marieaurore123 Oct 15, 2024
8c993dd
fix: revert logging change
marieaurore123 Oct 15, 2024
ef00b38
Merge pull request #53 from 0xPlaygrounds/refactor(vector-store)/in-m…
marieaurore123 Oct 15, 2024
f5e60f5
Merge branch 'feat/embeddings-overhaul' into feat(embeddings)/derive-…
marieaurore123 Oct 15, 2024
5a8c361
feat: handle tools with embeddingsbuilder
marieaurore123 Oct 15, 2024
dc89e54
bug(macro): fix error when embed tags missing
marieaurore123 Oct 15, 2024
ae66d08
style: cargo fmt
marieaurore123 Oct 15, 2024
4305952
fix(tests): clippy
marieaurore123 Oct 15, 2024
24e3b98
docs&revert: revert embeddable trait error type, add docstrings
marieaurore123 Oct 15, 2024
a7dbf6c
style: cargo clippy
marieaurore123 Oct 15, 2024
886ebcb
clippy(lancedb): fix unused function error
marieaurore123 Oct 15, 2024
79dea45
fix(test): remove useless assert false statement
marieaurore123 Oct 15, 2024
6362344
cleanup: split up branch into 2 branches for readability
marieaurore123 Oct 16, 2024
b5e1bf3
cleanup: revert certain changes during branch split
marieaurore123 Oct 16, 2024
7caf134
docs: revert doc string
marieaurore123 Oct 16, 2024
8739692
fix: add embedding_docs to embeddable tool
marieaurore123 Oct 16, 2024
fb979ec
refactor: use OneOrMany in Embbedable trait, make derive macro crate …
marieaurore123 Oct 16, 2024
690027c
tests: add some more tests
marieaurore123 Oct 16, 2024
cca6059
clippy: cargo clippy
marieaurore123 Oct 17, 2024
f785b8c
docs: add docstring to oneormany
marieaurore123 Oct 17, 2024
0e2ade9
fix(macro): update error handling
marieaurore123 Oct 17, 2024
a98769c
refactor: reexport EmbeddingsBuilder in rig and update imports
marieaurore123 Oct 17, 2024
aca9134
Merge branch 'feat(embeddings)/derive-macro' into feat(embeddings)/ad…
marieaurore123 Oct 17, 2024
067894c
feat: implement IntoIterator and Iterator for OneOrMany
marieaurore123 Oct 17, 2024
32bcc61
refactor: rename from methods
marieaurore123 Oct 17, 2024
564bef4
tests: fix failing tests
marieaurore123 Oct 17, 2024
04f1f3e
refactor&fix: make PR review changes
marieaurore123 Oct 17, 2024
c8f6646
fix: fix tests failing
marieaurore123 Oct 17, 2024
40f3c18
test: add test on OneOrMany
marieaurore123 Oct 17, 2024
68d88b6
style: cargo fmt
marieaurore123 Oct 17, 2024
4bc7d07
docs&fix: fix doc strings, implement iter_mut for OneOrMany
marieaurore123 Oct 17, 2024
4d2ffdb
fix: update borrow and owning of macro
marieaurore123 Oct 18, 2024
6f04225
clippy: add back print statements
marieaurore123 Oct 18, 2024
bdd98e5
Merge branch 'main' into feat(embeddings)/derive-macro
marieaurore123 Oct 18, 2024
e218cf4
Merge branch 'feat(embeddings)/derive-macro' into feat(embeddings)/ad…
marieaurore123 Oct 18, 2024
5897e22
fix: fix issues caused by merge of derive macro branch
marieaurore123 Oct 18, 2024
2477af8
fix: fix cargo toml of lancedb and mongodb
marieaurore123 Oct 18, 2024
485ad3b
refactor: use thiserror for OneOtMany::EmptyListError
marieaurore123 Oct 18, 2024
4039cf6
Merge pull request #59 from 0xPlaygrounds/feat(embeddings)/derive-macro
marieaurore123 Oct 18, 2024
23336dc
Merge branch 'feat(embeddings)/derive-macro' into feat(embeddings)/ad…
marieaurore123 Oct 18, 2024
763c364
feat: add OneOrMany to in memory vector store
marieaurore123 Oct 18, 2024
3c5e59d
style: cargo fmt
marieaurore123 Oct 18, 2024
dc332c3
fix: update embeddingsbuilder import path
marieaurore123 Oct 18, 2024
db2ec98
tests: add tests for embeddingsbuilder
marieaurore123 Oct 18, 2024
3bb2231
clippy: add is empty method
marieaurore123 Oct 18, 2024
3688b78
fix: add feature flag to examples in mongodb and lancedb crates
marieaurore123 Oct 18, 2024
db8d188
fix: move lancedb fixtures into it's own file
marieaurore123 Oct 18, 2024
803b792
fix: add dummy main function in fextures.rs for compiler
marieaurore123 Oct 18, 2024
97775d6
fix: revert fixture file, remove fixtures from cargo toml examples
marieaurore123 Oct 18, 2024
05ef716
fix: update fixture import in lancedb examples
marieaurore123 Oct 18, 2024
d75e4bb
refactor: rename D to T in embeddingsbuilder generics
marieaurore123 Oct 18, 2024
2865134
refactor: remove clone
marieaurore123 Oct 21, 2024
55e2409
PR: update builder, docstrings, and std::markers tags
marieaurore123 Oct 21, 2024
0cbc5aa
style: replace add with push
marieaurore123 Oct 21, 2024
1176e2f
fix: fix mongodb example
marieaurore123 Oct 21, 2024
f34a5dd
fix: update lancedb and mongodb doc example
marieaurore123 Oct 21, 2024
f796e12
fix: typo
marieaurore123 Oct 21, 2024
38ca0db
Merge pull request #64 from 0xPlaygrounds/feat(embeddings)/add-embedd…
marieaurore123 Oct 21, 2024
223139e
docs: add and fix docstrings and examples
marieaurore123 Oct 22, 2024
ed9e038
docs: add more doc tests
marieaurore123 Oct 22, 2024
c502ea5
feat: rename Embeddable trait to ExtractEmbeddingFields
marieaurore123 Oct 23, 2024
ebc6b81
feat: rename macro files, cargo fmt
marieaurore123 Oct 23, 2024
5c2d451
PR; update docstrings, update `add_documents_with_id` function
marieaurore123 Oct 24, 2024
55b42d8
doc: fix doc linting
marieaurore123 Oct 24, 2024
5135451
Merge pull request #72 from 0xPlaygrounds/cleanup(embeddings)/finaliz…
cvauclair Oct 24, 2024
3627441
Merge branch 'main' into feat/embeddings-overhaul
cvauclair Oct 24, 2024
b5870ce
misc: fmt
cvauclair Oct 24, 2024
8c30b54
test: fix test
cvauclair Oct 24, 2024
474695b
refactor(embeddings): embed trait definition (#89)
marieaurore123 Nov 15, 2024
5eb0014
Merge branch 'main' into feat/embeddings-overhaul
marieaurore123 Nov 23, 2024
1336633
fix/docs: fix erros from merge, cleanup embeddings docstrings
marieaurore123 Nov 23, 2024
4363671
fix: cargo clippy in examples
marieaurore123 Nov 23, 2024
2d6d7c4
Feat: small improvements + fixes + tests (#128)
cvauclair Nov 29, 2024
a041699
style: Small renaming for consistency
cvauclair Nov 29, 2024
bfc3291
docs: Improve docstrings
cvauclair Nov 29, 2024
c21f549
style: fmt
cvauclair Nov 29, 2024
1637c8a
fix: `TextEmbedder::embed` visibility
cvauclair Nov 29, 2024
ca4d9dd
docs: Simplified the `EmbeddingsBuilder` docstring example to focus o…
cvauclair Nov 29, 2024
f760a5f
style: cargo fmt
cvauclair Nov 29, 2024
f21da45
docs: Small edit to lancedb examples
cvauclair Nov 29, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
refactor(embeddings): embed trait definition (#89)
* refactor: Big refactor

* refactor: refactor Embed trait, fix all imports, rename files, fix macro

* fix(embed trait): fix errors while testing

* fix(lancedb): examples

* docs: fix hyperlink

* fmt: cargo fmt

* PR; make requested changes

* fix: change visibility of struct field

* fix: failing tests

---------

Co-authored-by: Christophe <cvauclair@protonmail.com>
  • Loading branch information
marieaurore123 and cvauclair authored Nov 15, 2024
commit 474695bc0f2e60b0ac806ba2ddecbdfa53c4c475
2 changes: 1 addition & 1 deletion rig-core/Cargo.toml
Original file line number Diff line number Diff line change
@@ -35,7 +35,7 @@ tokio-test = "0.4.4"
derive = ["dep:rig-derive"]

[[test]]
name = "extract_embedding_fields_macro"
name = "embed_macro"
required-features = ["derive"]

[[example]]
2 changes: 1 addition & 1 deletion rig-core/examples/calculator_chatbot.rs
Original file line number Diff line number Diff line change
@@ -247,7 +247,7 @@ async fn main() -> Result<(), anyhow::Error> {

let embedding_model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002);
let embeddings = EmbeddingsBuilder::new(embedding_model.clone())
.documents(toolset.embedabble_tools()?)?
.documents(toolset.schemas()?)?
.build()
.await?;

8 changes: 4 additions & 4 deletions rig-core/examples/rag.rs
Original file line number Diff line number Diff line change
@@ -5,14 +5,14 @@ use rig::{
embeddings::EmbeddingsBuilder,
providers::openai::{Client, TEXT_EMBEDDING_ADA_002},
vector_store::in_memory_store::InMemoryVectorStore,
ExtractEmbeddingFields,
Embed,
};
use serde::Serialize;

// Shape of data that needs to be RAG'ed.
// A vector search needs to be performed on the definitions, so we derive the `ExtractEmbeddingFields` trait for `FakeDefinition`
// Data to be RAGged.
// A vector search needs to be performed on the `definitions` field, so we derive the `Embed` trait for `FakeDefinition`
// and tag that field with `#[embed]`.
#[derive(ExtractEmbeddingFields, Serialize, Clone, Debug, Eq, PartialEq, Default)]
#[derive(Embed, Serialize, Clone, Debug, Eq, PartialEq, Default)]
struct FakeDefinition {
id: String,
#[embed]
2 changes: 1 addition & 1 deletion rig-core/examples/rag_dynamic_tools.rs
Original file line number Diff line number Diff line change
@@ -156,7 +156,7 @@ async fn main() -> Result<(), anyhow::Error> {
.build();

let embeddings = EmbeddingsBuilder::new(embedding_model.clone())
.documents(toolset.embedabble_tools()?)?
.documents(toolset.schemas()?)?
.build()
.await?;

4 changes: 2 additions & 2 deletions rig-core/examples/vector_search.rs
Original file line number Diff line number Diff line change
@@ -4,13 +4,13 @@ use rig::{
embeddings::EmbeddingsBuilder,
providers::openai::{Client, TEXT_EMBEDDING_ADA_002},
vector_store::{in_memory_store::InMemoryVectorStore, VectorStoreIndex},
ExtractEmbeddingFields,
Embed,
};
use serde::{Deserialize, Serialize};

// Shape of data that needs to be RAG'ed.
// The definition field will be used to generate embeddings.
#[derive(ExtractEmbeddingFields, Clone, Deserialize, Debug, Serialize, Eq, PartialEq, Default)]
#[derive(Embed, Clone, Deserialize, Debug, Serialize, Eq, PartialEq, Default)]
struct FakeDefinition {
id: String,
word: String,
4 changes: 2 additions & 2 deletions rig-core/examples/vector_search_cohere.rs
Original file line number Diff line number Diff line change
@@ -4,13 +4,13 @@ use rig::{
embeddings::EmbeddingsBuilder,
providers::cohere::{Client, EMBED_ENGLISH_V3},
vector_store::{in_memory_store::InMemoryVectorStore, VectorStoreIndex},
ExtractEmbeddingFields,
Embed,
};
use serde::{Deserialize, Serialize};

// Shape of data that needs to be RAG'ed.
// The definition field will be used to generate embeddings.
#[derive(ExtractEmbeddingFields, Clone, Deserialize, Debug, Serialize, Eq, PartialEq, Default)]
#[derive(Embed, Clone, Deserialize, Debug, Serialize, Eq, PartialEq, Default)]
struct FakeDefinition {
id: String,
word: String,
6 changes: 3 additions & 3 deletions rig-core/rig-core-derive/src/basic.rs
Original file line number Diff line number Diff line change
@@ -2,7 +2,7 @@ use syn::{parse_quote, Attribute, DataStruct, Meta};

use crate::EMBED;

/// Finds and returns fields with simple #[embed] attribute tags only.
/// Finds and returns fields with simple `#[embed]` attribute tags only.
pub(crate) fn basic_embed_fields(data_struct: &DataStruct) -> impl Iterator<Item = &syn::Field> {
data_struct.fields.iter().filter(|field| {
field.attrs.iter().any(|attribute| match attribute {
@@ -15,11 +15,11 @@ pub(crate) fn basic_embed_fields(data_struct: &DataStruct) -> impl Iterator<Item
})
}

/// Adds bounds to where clause that force all fields tagged with #[embed] to implement the ExtractEmbeddingFields trait.
/// Adds bounds to where clause that force all fields tagged with `#[embed]` to implement the `Embed` trait.
pub(crate) fn add_struct_bounds(generics: &mut syn::Generics, field_type: &syn::Type) {
let where_clause = generics.make_where_clause();

where_clause.predicates.push(parse_quote! {
#field_type: ExtractEmbeddingFields
#field_type: Embed
});
}
Original file line number Diff line number Diff line change
@@ -17,8 +17,8 @@ pub(crate) fn expand_derive_embedding(input: &mut syn::DeriveInput) -> syn::Resu
let (basic_targets, basic_target_size) = data_struct.basic(generics);
let (custom_targets, custom_target_size) = data_struct.custom()?;

// If there are no fields tagged with #[embed] or #[embed(embed_with = "...")], return an empty TokenStream.
// ie. do not implement `ExtractEmbeddingFields` trait for the struct.
// If there are no fields tagged with `#[embed]` or `#[embed(embed_with = "...")]`, return an empty TokenStream.
// ie. do not implement `Embed` trait for the struct.
if basic_target_size + custom_target_size == 0 {
return Err(syn::Error::new_spanned(
name,
@@ -27,33 +27,28 @@ pub(crate) fn expand_derive_embedding(input: &mut syn::DeriveInput) -> syn::Resu
}

quote! {
let mut embed_targets = #basic_targets;
embed_targets.extend(#custom_targets)
#basic_targets;
#custom_targets;
}
}
_ => {
return Err(syn::Error::new_spanned(
input,
"ExtractEmbeddingFields derive macro should only be used on structs",
"Embed derive macro should only be used on structs",
))
}
};

let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();

let gen = quote! {
// Note: `ExtractEmbeddingFields` trait is imported with the macro.
// Note: `Embed` trait is imported with the macro.

impl #impl_generics ExtractEmbeddingFields for #name #ty_generics #where_clause {
type Error = rig::embeddings::extract_embedding_fields::ExtractEmbeddingFieldsError;

fn extract_embedding_fields(&self) -> Result<rig::OneOrMany<String>, Self::Error> {
impl #impl_generics Embed for #name #ty_generics #where_clause {
fn embed(&self, embedder: &mut rig::embeddings::embed::TextEmbedder) -> Result<(), rig::embeddings::embed::EmbedError> {
#target_stream;

rig::OneOrMany::merge(
embed_targets.into_iter()
.collect::<Result<Vec<_>, _>>()?
).map_err(rig::embeddings::extract_embedding_fields::ExtractEmbeddingFieldsError::new)
Ok(())
}
}
};
@@ -62,17 +57,17 @@ pub(crate) fn expand_derive_embedding(input: &mut syn::DeriveInput) -> syn::Resu
}

trait StructParser {
// Handles fields tagged with #[embed]
// Handles fields tagged with `#[embed]`
fn basic(&self, generics: &mut syn::Generics) -> (TokenStream, usize);

// Handles fields tagged with #[embed(embed_with = "...")]
// Handles fields tagged with `#[embed(embed_with = "...")]`
fn custom(&self) -> syn::Result<(TokenStream, usize)>;
}

impl StructParser for DataStruct {
fn basic(&self, generics: &mut syn::Generics) -> (TokenStream, usize) {
let embed_targets = basic_embed_fields(self)
// Iterate over every field tagged with #[embed]
// Iterate over every field tagged with `#[embed]`
.map(|field| {
add_struct_bounds(generics, &field.ty);

@@ -84,50 +79,32 @@ impl StructParser for DataStruct {
})
.collect::<Vec<_>>();

if !embed_targets.is_empty() {
(
quote! {
vec![#(#embed_targets.extract_embedding_fields()),*]
},
embed_targets.len(),
)
} else {
(
quote! {
vec![]
},
0,
)
}
(
quote! {
#(#embed_targets.embed(embedder)?;)*
},
embed_targets.len(),
)
}

fn custom(&self) -> syn::Result<(TokenStream, usize)> {
let embed_targets = custom_embed_fields(self)?
// Iterate over every field tagged with #[embed(embed_with = "...")]
// Iterate over every field tagged with `#[embed(embed_with = "...")]`
.into_iter()
.map(|(field, custom_func_path)| {
let field_name = &field.ident;

quote! {
#custom_func_path(self.#field_name.clone())
#custom_func_path(embedder, self.#field_name.clone())?;
}
})
.collect::<Vec<_>>();

Ok(if !embed_targets.is_empty() {
(
quote! {
vec![#(#embed_targets),*]
},
embed_targets.len(),
)
} else {
(
quote! {
vec![]
},
0,
)
})
Ok((
quote! {
#(#embed_targets)*
},
embed_targets.len(),
))
}
}
12 changes: 6 additions & 6 deletions rig-core/rig-core-derive/src/lib.rs
Original file line number Diff line number Diff line change
@@ -4,18 +4,18 @@ use syn::{parse_macro_input, DeriveInput};

mod basic;
mod custom;
mod extract_embedding_fields;
mod embed;

pub(crate) const EMBED: &str = "embed";

// https://doc.rust-lang.org/book/ch19-06-macros.html#how-to-write-a-custom-derive-macro
// https://doc.rust-lang.org/reference/procedural-macros.html

#[proc_macro_derive(ExtractEmbeddingFields, attributes(embed))]
/// References:
/// <https://doc.rust-lang.org/book/ch19-06-macros.html#how-to-write-a-custom-derive-macro>
/// <https://doc.rust-lang.org/reference/procedural-macros.html>
#[proc_macro_derive(Embed, attributes(embed))]
pub fn derive_embedding_trait(item: TokenStream) -> TokenStream {
let mut input = parse_macro_input!(item as DeriveInput);

extract_embedding_fields::expand_derive_embedding(&mut input)
embed::expand_derive_embedding(&mut input)
.unwrap_or_else(syn::Error::into_compile_error)
.into()
}
2 changes: 1 addition & 1 deletion rig-core/src/completion.rs
Original file line number Diff line number Diff line change
@@ -82,7 +82,7 @@ pub enum CompletionError {

/// Error building the completion request
#[error("RequestError: {0}")]
RequestError(#[from] Box<dyn std::error::Error + Send + Sync>),
RequestError(#[from] Box<dyn std::error::Error + Send + Sync + 'static>),

/// Error parsing the completion response
#[error("ResponseError: {0}")]
Loading