Skip to content

Commit

Permalink
Merge pull request #64 from 0xPlaygrounds/feat(embeddings)/add-embedd…
Browse files Browse the repository at this point in the history
…able-trait-to-embeddings-builder

Feat(embeddings): Add embeddable trait to embeddings builder
  • Loading branch information
marieaurore123 authored Oct 21, 2024
2 parents 4039cf6 + f796e12 commit 38ca0db
Show file tree
Hide file tree
Showing 25 changed files with 935 additions and 418 deletions.
3 changes: 1 addition & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,4 @@ resolver = "2"
members = [
"rig-core", "rig-core/rig-core-derive",
"rig-mongodb",
"rig-lancedb"
]
"rig-lancedb"]
14 changes: 13 additions & 1 deletion rig-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,4 +35,16 @@ derive = ["dep:rig-derive"]

[[test]]
name = "embeddable_macro"
required-features = ["derive"]
required-features = ["derive"]

[[example]]
name = "rag"
required-features = ["derive"]

[[example]]
name = "vector_search"
required-features = ["derive"]

[[example]]
name = "vector_search_cohere"
required-features = ["derive"]
11 changes: 2 additions & 9 deletions rig-core/examples/calculator_chatbot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ use anyhow::Result;
use rig::{
cli_chatbot::cli_chatbot,
completion::ToolDefinition,
embeddings::builder::DocumentEmbeddings,
embeddings::EmbeddingsBuilder,
providers::openai::{Client, TEXT_EMBEDDING_ADA_002},
tool::{Tool, ToolEmbedding, ToolSet},
Expand Down Expand Up @@ -248,21 +247,15 @@ async fn main() -> Result<(), anyhow::Error> {

let embedding_model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002);
let embeddings = EmbeddingsBuilder::new(embedding_model.clone())
.tools(&toolset)?
.documents(toolset.embedabble_tools()?)?
.build()
.await?;

let index = InMemoryVectorStore::default()
.add_documents(
embeddings
.into_iter()
.map(
|DocumentEmbeddings {
id,
document,
embeddings,
}| { (id, document, embeddings) },
)
.map(|(tool, embedding)| (tool.name.clone(), tool, embedding))
.collect(),
)?
.index(embedding_model);
Expand Down
50 changes: 38 additions & 12 deletions rig-core/examples/rag.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,22 @@
use std::env;
use std::{env, vec};

use rig::{
completion::Prompt,
embeddings::builder::DocumentEmbeddings,
embeddings::EmbeddingsBuilder,
providers::openai::{Client, TEXT_EMBEDDING_ADA_002},
vector_store::in_memory_store::InMemoryVectorStore,
Embeddable,
};
use serde::Serialize;

// Shape of the data that needs to be RAG'ed.
// The `definitions` field (marked `#[embed]`) is the one used to generate embeddings.
#[derive(Embeddable, Clone, Debug, Serialize, Eq, PartialEq, Default)]
struct FakeDefinition {
// Document identifier; this example clones it into the first element of the
// tuple handed to `InMemoryVectorStore::add_documents`.
id: String,
// Texts to embed. NOTE(review): presumably each string gets its own
// embedding — confirm against the `Embeddable` derive.
#[embed]
definitions: Vec<String>,
}

#[tokio::main]
async fn main() -> Result<(), anyhow::Error> {
Expand All @@ -17,23 +27,39 @@ async fn main() -> Result<(), anyhow::Error> {
let embedding_model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002);

let embeddings = EmbeddingsBuilder::new(embedding_model.clone())
.simple_document("doc0", "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets")
.simple_document("doc1", "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.")
.simple_document("doc2", "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.")
.documents(vec![
FakeDefinition {
id: "doc0".to_string(),
definitions: vec![
"Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets.".to_string(),
"Definition of a *flurbo*: A fictional digital currency that originated in the animated series Rick and Morty.".to_string()
]
},
FakeDefinition {
id: "doc1".to_string(),
definitions: vec![
"Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string(),
"Definition of a *glarb-glarb*: A fictional creature found in the distant, swampy marshlands of the planet Glibbo in the Andromeda galaxy.".to_string()
]
},
FakeDefinition {
id: "doc2".to_string(),
definitions: vec![
"Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.".to_string(),
"Definition of a *linglingdong*: A rare, mystical instrument crafted by the ancient monks of the Nebulon Mountain Ranges on the planet Quarm.".to_string()
]
},
])?
.build()
.await?;

let index = InMemoryVectorStore::default()
.add_documents(
embeddings
.into_iter()
.map(
|DocumentEmbeddings {
id,
document,
embeddings,
}| { (id, document, embeddings) },
)
.map(|(fake_definition, embedding_vec)| {
(fake_definition.id.clone(), fake_definition, embedding_vec)
})
.collect(),
)?
.index(embedding_model);
Expand Down
11 changes: 2 additions & 9 deletions rig-core/examples/rag_dynamic_tools.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use anyhow::Result;
use rig::{
completion::{Prompt, ToolDefinition},
embeddings::builder::DocumentEmbeddings,
embeddings::EmbeddingsBuilder,
providers::openai::{Client, TEXT_EMBEDDING_ADA_002},
tool::{Tool, ToolEmbedding, ToolSet},
Expand Down Expand Up @@ -157,21 +156,15 @@ async fn main() -> Result<(), anyhow::Error> {
.build();

let embeddings = EmbeddingsBuilder::new(embedding_model.clone())
.tools(&toolset)?
.documents(toolset.embedabble_tools()?)?
.build()
.await?;

let index = InMemoryVectorStore::default()
.add_documents(
embeddings
.into_iter()
.map(
|DocumentEmbeddings {
id,
document,
embeddings,
}| { (id, document, embeddings) },
)
.map(|(tool, embedding)| (tool.name.clone(), tool, embedding))
.collect(),
)?
.index(embedding_model);
Expand Down
58 changes: 44 additions & 14 deletions rig-core/examples/vector_search.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,22 @@
use std::env;

use rig::{
embeddings::builder::DocumentEmbeddings,
embeddings::EmbeddingsBuilder,
providers::openai::{Client, TEXT_EMBEDDING_ADA_002},
vector_store::{in_memory_store::InMemoryVectorStore, VectorStoreIndex},
Embeddable,
};
use serde::{Deserialize, Serialize};

// Shape of the data that needs to be RAG'ed.
// The `definitions` field (marked `#[embed]`) is the one used to generate embeddings.
#[derive(Embeddable, Clone, Deserialize, Debug, Serialize, Eq, PartialEq, Default)]
struct FakeDefinition {
// Document identifier; this example clones it into the first element of the
// tuple handed to `InMemoryVectorStore::add_documents`.
id: String,
// The term being defined; the example reads it back from search results via `doc.word`.
word: String,
// Texts to embed. NOTE(review): presumably each string gets its own
// embedding — confirm against the `Embeddable` derive.
#[embed]
definitions: Vec<String>,
}

#[tokio::main]
async fn main() -> Result<(), anyhow::Error> {
Expand All @@ -16,38 +27,57 @@ async fn main() -> Result<(), anyhow::Error> {
let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002);

let embeddings = EmbeddingsBuilder::new(model.clone())
.simple_document("doc0", "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets")
.simple_document("doc1", "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.")
.simple_document("doc2", "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.")
.documents(vec![
FakeDefinition {
id: "doc0".to_string(),
word: "flurbo".to_string(),
definitions: vec![
"A green alien that lives on cold planets.".to_string(),
"A fictional digital currency that originated in the animated series Rick and Morty.".to_string()
]
},
FakeDefinition {
id: "doc1".to_string(),
word: "glarb-glarb".to_string(),
definitions: vec![
"An ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string(),
"A fictional creature found in the distant, swampy marshlands of the planet Glibbo in the Andromeda galaxy.".to_string()
]
},
FakeDefinition {
id: "doc2".to_string(),
word: "linglingdong".to_string(),
definitions: vec![
"A term used by inhabitants of the sombrero galaxy to describe humans.".to_string(),
"A rare, mystical instrument crafted by the ancient monks of the Nebulon Mountain Ranges on the planet Quarm.".to_string()
]
},
])?
.build()
.await?;

let index = InMemoryVectorStore::default()
.add_documents(
embeddings
.into_iter()
.map(
|DocumentEmbeddings {
id,
document,
embeddings,
}| { (id, document, embeddings) },
)
.map(|(fake_definition, embedding_vec)| {
(fake_definition.id.clone(), fake_definition, embedding_vec)
})
.collect(),
)?
.index(model);

let results = index
.top_n::<String>("What is a linglingdong?", 1)
.top_n::<FakeDefinition>("I need to buy something in a fictional universe. What type of money can I use for this?", 1)
.await?
.into_iter()
.map(|(score, id, doc)| (score, id, doc))
.map(|(score, id, doc)| (score, id, doc.word))
.collect::<Vec<_>>();

println!("Results: {:?}", results);

let id_results = index
.top_n_ids("What is a linglingdong?", 1)
.top_n_ids("I need to buy something in a fictional universe. What type of money can I use for this?", 1)
.await?
.into_iter()
.map(|(score, id)| (score, id))
Expand Down
61 changes: 47 additions & 14 deletions rig-core/examples/vector_search_cohere.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,22 @@
use std::env;

use rig::{
embeddings::builder::DocumentEmbeddings,
embeddings::EmbeddingsBuilder,
providers::cohere::{Client, EMBED_ENGLISH_V3},
vector_store::{in_memory_store::InMemoryVectorStore, VectorStoreIndex},
Embeddable,
};
use serde::{Deserialize, Serialize};

// Shape of the data that needs to be RAG'ed.
// The `definitions` field (marked `#[embed]`) is the one used to generate embeddings.
#[derive(Embeddable, Clone, Deserialize, Debug, Serialize, Eq, PartialEq, Default)]
struct FakeDefinition {
// Document identifier; this example clones it into the first element of the
// tuple handed to `InMemoryVectorStore::add_documents`.
id: String,
// The term being defined; the example reads it back from search results via `doc.word`.
word: String,
// Texts to embed. NOTE(review): presumably each string gets its own
// embedding — confirm against the `Embeddable` derive.
#[embed]
definitions: Vec<String>,
}

#[tokio::main]
async fn main() -> Result<(), anyhow::Error> {
Expand All @@ -16,33 +27,55 @@ async fn main() -> Result<(), anyhow::Error> {
let document_model = cohere_client.embedding_model(EMBED_ENGLISH_V3, "search_document");
let search_model = cohere_client.embedding_model(EMBED_ENGLISH_V3, "search_query");

let embeddings = EmbeddingsBuilder::new(document_model)
.simple_document("doc0", "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets")
.simple_document("doc1", "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.")
.simple_document("doc2", "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.")
let embeddings = EmbeddingsBuilder::new(document_model.clone())
.documents(vec![
FakeDefinition {
id: "doc0".to_string(),
word: "flurbo".to_string(),
definitions: vec![
"A green alien that lives on cold planets.".to_string(),
"A fictional digital currency that originated in the animated series Rick and Morty.".to_string()
]
},
FakeDefinition {
id: "doc1".to_string(),
word: "glarb-glarb".to_string(),
definitions: vec![
"An ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string(),
"A fictional creature found in the distant, swampy marshlands of the planet Glibbo in the Andromeda galaxy.".to_string()
]
},
FakeDefinition {
id: "doc2".to_string(),
word: "linglingdong".to_string(),
definitions: vec![
"A term used by inhabitants of the sombrero galaxy to describe humans.".to_string(),
"A rare, mystical instrument crafted by the ancient monks of the Nebulon Mountain Ranges on the planet Quarm.".to_string()
]
},
])?
.build()
.await?;

let index = InMemoryVectorStore::default()
.add_documents(
embeddings
.into_iter()
.map(
|DocumentEmbeddings {
id,
document,
embeddings,
}| { (id, document, embeddings) },
)
.map(|(fake_definition, embedding_vec)| {
(fake_definition.id.clone(), fake_definition, embedding_vec)
})
.collect(),
)?
.index(search_model);

let results = index
.top_n::<String>("What is a linglingdong?", 1)
.top_n::<FakeDefinition>(
"Which instrument is found in the Nebulon Mountain Ranges?",
1,
)
.await?
.into_iter()
.map(|(score, id, doc)| (score, id, doc))
.map(|(score, id, doc)| (score, id, doc.word))
.collect::<Vec<_>>();

println!("Results: {:?}", results);
Expand Down
Loading

0 comments on commit 38ca0db

Please sign in to comment.