diff --git a/Cargo.lock b/Cargo.lock index 76e03fd9..67c4fc9c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -17,6 +17,12 @@ version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" +[[package]] +name = "adler2" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "512761e0bb2578dd7380c6baaa0f4ce03e84f95e960231d1dec8bf4d7d6e2627" + [[package]] name = "ahash" version = "0.8.11" @@ -30,6 +36,15 @@ dependencies = [ "zerocopy", ] +[[package]] +name = "aho-corasick" +version = "1.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e60d3430d3a69478ad0993f19238d2df97c507009a52b3c10addcd7f6bcb916" +dependencies = [ + "memchr", +] + [[package]] name = "android-tzdata" version = "0.1.1" @@ -78,7 +93,7 @@ dependencies = [ "cc", "cfg-if", "libc", - "miniz_oxide", + "miniz_oxide 0.7.3", "object", "rustc-demangle", ] @@ -216,6 +231,40 @@ dependencies = [ "libc", ] +[[package]] +name = "crc32fast" +version = "1.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a97769d94ddab943e4510d138150169a2758b5ef3eb191a9ee688de3e23ef7b3" +dependencies = [ + "cfg-if", +] + +[[package]] +name = "crossbeam-deque" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "613f8cc01fe9cf1a3eb3d7f488fd2fa8388403e97039e2f73692932e291a770d" +dependencies = [ + "crossbeam-epoch", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-epoch" +version = "0.9.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e" +dependencies = [ + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-utils" +version = "0.8.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "22ec99545bb0ed0ea7bb9b8e1e9122ea386ff8a48c0922e43f36d45ab09e0e80" + [[package]] name = "crypto-common" version = "0.1.6" @@ -226,6 +275,27 @@ dependencies = [ "typenum", ] +[[package]] +name = "csv" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ac574ff4d437a7b5ad237ef331c17ccca63c46479e5b5453eb8e10bb99a759fe" +dependencies = [ + "csv-core", + "itoa", + "ryu", + "serde", +] + +[[package]] +name = "csv-core" +version = "0.1.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5efa2b3d7902f4b634a20cae3c9c4e6209dc4779feb6863329607560143efa70" +dependencies = [ + "memchr", +] + [[package]] name = "darling" version = "0.13.4" @@ -317,6 +387,12 @@ version = "1.0.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0d6ef0072f8a535281e4876be788938b528e9a1d43900b82c2569af7da799125" +[[package]] +name = "either" +version = "1.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60b1af1c220855b6ceac025d3f6ecdd2b7c4894bfe9cd9bda4fbb4bc7c0d4cf0" + [[package]] name = "encoding_rs" version = "0.8.34" @@ -366,6 +442,16 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8fcfdc7a0362c9f4444381a9e697c79d435fe65b52a37466fc2c1184cee9edc6" +[[package]] +name = "flate2" +version = "1.0.33" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "324a1be68054ef05ad64b861cc9eaf1d623d2d8cb25b4bf2cb9cdd902b4bf253" +dependencies = [ + "crc32fast", + "miniz_oxide 0.8.0", +] + [[package]] name = "fnv" version = "1.0.7" @@ -402,6 +488,16 @@ version = "2.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e6d5a32815ae3f33302d95fdcb2ce17862f8c65363dcfd29360480ba1001fc9c" +[[package]] +name = "futf" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df420e2e84819663797d1ec6544b13c5be84629e7bb00dc960d6917db2987843" +dependencies = [ + "mac", + "new_debug_unreachable", +] + [[package]] name = "futures" version = "0.3.30" @@ -581,6 +677,20 @@ dependencies = [ "winapi", ] +[[package]] +name = "html5ever" +version = "0.29.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2e15626aaf9c351bc696217cbe29cb9b5e86c43f8a46b5e2f5c6c5cf7cb904ce" +dependencies = [ + "log", + "mac", + "markup5ever", + "proc-macro2", + "quote", + "syn 2.0.65", +] + [[package]] name = "http" version = "0.2.12" @@ -785,6 +895,26 @@ version = "0.4.21" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "90ed8c1e510134f979dbc4f070f87d4313098b704861a105fe34231c70a3901c" +[[package]] +name = "lopdf" +version = "0.34.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c5c8ecfc6c72051981c0459f75ccc585e7ff67c70829560cda8e647882a9abff" +dependencies = [ + "chrono", + "encoding_rs", + "flate2", + "indexmap", + "itoa", + "log", + "md-5", + "nom", + "rangemap", + "rayon", + "time", + "weezl", +] + [[package]] name = "lru-cache" version = "0.1.2" @@ -794,6 +924,37 @@ dependencies = [ "linked-hash-map", ] +[[package]] +name = "mac" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c41e0c4fef86961ac6d6f8a82609f55f31b05e4fce149ac5710e439df7619ba4" + +[[package]] +name = "markdown" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ef3aab6a1d529b112695f72beec5ee80e729cb45af58663ec902c8fac764ecdd" +dependencies = [ + "lazy_static", + "pipeline", + "regex", +] + +[[package]] +name = "markup5ever" +version = "0.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "82c88c6129bd24319e62a0359cb6b958fa7e8be6e19bb1663bc396b90883aca5" +dependencies = [ + "log", + "phf", + "phf_codegen", + "string_cache", + "string_cache_codegen", + "tendril", +] + [[package]] name = "match_cfg" version = "0.1.0" @@ -828,6 +989,12 @@ version = "0.3.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" +[[package]] +name = "minimal-lexical" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" + [[package]] name = "miniz_oxide" version = "0.7.3" @@ -837,6 +1004,15 @@ dependencies = [ "adler", ] +[[package]] +name = "miniz_oxide" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e2d80299ef12ff69b16a84bb182e3b9df68b5a91574d3d4fa6e41b65deec4df1" +dependencies = [ + "adler2", +] + [[package]] name = "mio" version = "0.8.11" @@ -913,6 +1089,22 @@ dependencies = [ "tempfile", ] +[[package]] +name = "new_debug_unreachable" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "650eef8c711430f1a879fdd01d4745a7deea475becfb90269c06775983bbf086" + +[[package]] +name = "nom" +version = "7.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d273983c5a657a70a3e8f2a01329822f3b8c8172b73826411a55751e404a0a4a" +dependencies = [ + "memchr", + "minimal-lexical", +] + [[package]] name = "nu-ansi-term" version = "0.46.0" @@ -1060,6 +1252,63 @@ version = "2.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" +[[package]] +name = "phf" +version = "0.11.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ade2d8b8f33c7333b51bcf0428d37e217e9f32192ae4772156f65063b8ce03dc" +dependencies = [ + "phf_shared 0.11.2", +] + +[[package]] +name = "phf_codegen" +version = "0.11.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e8d39688d359e6b34654d328e262234662d16cc0f60ec8dcbe5e718709342a5a" +dependencies = [ + "phf_generator 0.11.2", + "phf_shared 0.11.2", +] + +[[package]] +name = "phf_generator" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5d5285893bb5eb82e6aaf5d59ee909a06a16737a8970984dd7746ba9283498d6" +dependencies = [ + "phf_shared 0.10.0", + "rand", +] + +[[package]] +name = "phf_generator" +version = "0.11.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "48e4cc64c2ad9ebe670cb8fd69dd50ae301650392e81c05f9bfcb2d5bdbc24b0" +dependencies = [ + "phf_shared 0.11.2", + "rand", +] + +[[package]] +name = "phf_shared" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6796ad771acdc0123d2a88dc428b5e38ef24456743ddb1744ed628f9815c096" +dependencies = [ + "siphasher", +] + +[[package]] +name = "phf_shared" +version = "0.11.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "90fcb95eef784c2ac79119d1dd819e162b5da872ce6f3c3abe1e8ca1c082f72b" +dependencies = [ + "siphasher", +] + [[package]] name = "pin-project-lite" version = "0.2.14" @@ -1072,6 +1321,12 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" +[[package]] +name = "pipeline" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d15b6607fa632996eb8a17c9041cb6071cb75ac057abd45dece578723ea8c7c0" + [[package]] name = "pkg-config" version = "0.3.30" @@ -1090,6 +1345,12 @@ version = "0.2.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" +[[package]] +name = "precomputed-hash" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "925383efa346730478fb4838dbe9137d2a47675ad789c546d150a6e1dd4ab31c" + [[package]] name = "proc-macro2" version = "1.0.83" @@ -1150,6 +1411,32 @@ dependencies = [ "getrandom", ] +[[package]] +name = "rangemap" +version = "1.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f60fcc7d6849342eff22c4350c8b9a989ee8ceabc4b481253e8946b9fe83d684" + +[[package]] +name = "rayon" +version = "1.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b418a60154510ca1a002a752ca9714984e21e4241e804d32555251faf8b78ffa" +dependencies = [ + "either", + "rayon-core", +] + +[[package]] +name = "rayon-core" +version = "1.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1465873a3dfdaa8ae7cb14b4383657caab0b3e8a0aa9ae8e04b044854c8dfce2" +dependencies = [ + "crossbeam-deque", + "crossbeam-utils", +] + [[package]] name = "redox_syscall" version = "0.5.1" @@ -1159,6 +1446,35 @@ dependencies = [ "bitflags 2.5.0", ] +[[package]] +name = "regex" +version = "1.10.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4219d74c6b67a3654a9fbebc4b419e22126d13d2f3c4a07ee0cb61ff79a79619" +dependencies = [ + "aho-corasick", + "memchr", + "regex-automata", + "regex-syntax", +] + +[[package]] +name = "regex-automata" +version = "0.4.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38caf58cc5ef2fed281f89292ef23f6365465ed9a41b7a7754eb4e26496c92df" +dependencies = [ + "aho-corasick", + "memchr", + "regex-syntax", +] + +[[package]] +name = "regex-syntax" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a66a03ae7c801facd77a29370b4faec201768915ac14a721ba36f20bc9c209b" + [[package]] name = "reqwest" version = "0.11.27" @@ -1214,7 +1530,12 @@ name = "rig-core" version = "0.1.0" dependencies = [ "anyhow", + "async-trait", + "csv", "futures", + "html5ever", + "lopdf", + "markdown", "ordered-float", "reqwest", "schemars", @@ -1224,6 +1545,7 @@ dependencies = [ "tokio", "tracing", "tracing-subscriber", + "walkdir", ] [[package]] @@ -1339,6 +1661,15 @@ version = "1.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f3cb5ba0dc43242ce17de99c180e96db90b235b8a9fdc9543c96d2209116bd9f" +[[package]] +name = "same-file" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502" +dependencies = [ + "winapi-util", +] + [[package]] name = "schannel" version = "0.1.23" @@ -1558,6 +1889,12 @@ dependencies = [ "libc", ] +[[package]] +name = "siphasher" +version = "0.3.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38b58827f4464d87d377d175e90bf58eb00fd8716ff0a62f80356b5e61555d0d" + [[package]] name = "slab" version = "0.4.9" @@ -1599,6 +1936,32 @@ version = "0.9.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" +[[package]] +name = "string_cache" +version = "0.8.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f91138e76242f575eb1d3b38b4f1362f10d3a43f47d182a5b359af488a02293b" +dependencies = [ + "new_debug_unreachable", + "once_cell", + "parking_lot", + "phf_shared 0.10.0", + "precomputed-hash", + "serde", +] + +[[package]] +name = "string_cache_codegen" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6bb30289b722be4ff74a408c3cc27edeaad656e06cb1fe8fa9231fa59c728988" +dependencies = [ + "phf_generator 0.10.0", + "phf_shared 0.10.0", + "proc-macro2", + "quote", +] + [[package]] name = "stringprep" version = "0.1.4" @@ -1695,6 +2058,17 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "tendril" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d24a120c5fc464a3458240ee02c299ebcb9d67b5249c8848b09d639dca8d7bb0" +dependencies = [ + "futf", + "mac", + "utf-8", +] + [[package]] name = "thiserror" version = "1.0.61" @@ -2004,6 +2378,12 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "utf-8" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9" + [[package]] name = "uuid" version = "1.8.0" @@ -2032,6 +2412,16 @@ version = "0.9.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" +[[package]] +name = "walkdir" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "29790946404f91d9c5d06f9874efddea1dc06c5efe94541a7d6863108e3a5e4b" +dependencies = [ + "same-file", + "winapi-util", +] + [[package]] name = "want" version = "0.3.1" @@ -2129,6 +2519,12 @@ version = "0.25.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5f20c57d8d7db6d3b86154206ae5d8fba62dd39573114de97c2cb0578251f8e1" +[[package]] +name = "weezl" +version = "0.1.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "53a85b86a771b1c87058196170769dd264f66c0782acf1ae6cc51bfd64b39082" + [[package]] name = "widestring" version = "1.1.0" @@ -2151,6 +2547,15 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" +[[package]] +name = "winapi-util" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cf221c93e13a30d793f7645a0e7762c55d169dbb0a49671918a2319d289b10bb" +dependencies = [ + "windows-sys 0.52.0", +] + [[package]] name = "winapi-x86_64-pc-windows-gnu" version = "0.4.0" diff --git a/rig-core/Cargo.toml b/rig-core/Cargo.toml index dde7027e..237418c1 100644 --- a/rig-core/Cargo.toml +++ b/rig-core/Cargo.toml @@ -11,8 +11,6 @@ repository = "https://github.com/0xPlaygrounds/rig" name="rig" path="src/lib.rs" -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html - [dependencies] reqwest = { version = "0.11.22", features = ["json"] } serde = { version = "1.0.193", features = ["derive"] } @@ -22,8 +20,15 @@ futures = "0.3.29" ordered-float = "4.2.0" schemars = "0.8.16" thiserror = "1.0.61" +async-trait = "0.1.68" +csv = "1.3.0" +lopdf = "0.34.0" +html5ever = "0.29.0" +markdown = "0.3.0" +tokio = { version = "1.34.0", features = ["fs", "io-util"] } +anyhow = "1.0.75" +walkdir = "2.5.0" [dev-dependencies] -anyhow = "1.0.75" tokio = { version = "1.34.0", features = ["full"] } -tracing-subscriber = "0.3.18" +tracing-subscriber = "0.3.18" \ No newline at end of file diff --git a/rig-core/examples/document_loaders.rs b/rig-core/examples/document_loaders.rs new file mode 100644 index 00000000..40a41712 --- /dev/null +++ b/rig-core/examples/document_loaders.rs @@ -0,0 +1,87 @@ +use rig::{ + completion::Prompt, + document_loaders::PdfLoader, + embeddings::EmbeddingsBuilder, + providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, + vector_store::{in_memory_store::InMemoryVectorStore, VectorStore}, +}; +use std::env; +use std::path::PathBuf; + +#[tokio::main] +async fn main() -> Result<(), anyhow::Error> { + // Print current working directory + println!("Current working directory: {:?}", env::current_dir()?); + + // Path to the PDF file + let pdf_path = PathBuf::from("rig-core/examples/sample_data/moores_law_for_everything.pdf"); + + // Print absolute path + println!( + "Attempting to access file at: {:?}", + pdf_path.canonicalize()? + ); + + // Check if the file exists + if !pdf_path.exists() { + eprintln!("Error: The file {} does not exist.", pdf_path.display()); + return Ok(()); + } + + println!("File found successfully!"); + + // Initialize OpenAI client + let openai = Client::from_env(); + let embedding_model = openai.embedding_model(TEXT_EMBEDDING_ADA_002); + + // Create vector store + let mut vector_store = InMemoryVectorStore::default(); + + // Build embeddings + let embeddings = EmbeddingsBuilder::new(embedding_model.clone()) + .add_loader(PdfLoader::new(pdf_path.to_str().unwrap())) + .build() + .await?; + + println!( + "Embeddings created successfully. Count: {}", + embeddings.len() + ); + for emb in &embeddings { + // println!("Document ID: {}", emb.id); + // println!("Document Content: {:?}", emb.document); + println!("Number of embeddings: {}", emb.embeddings.len()); + println!( + "First embedding vector length: {}", + emb.embeddings.first().map_or(0, |e| e.vec.len()) + ); + println!("--------------------"); + } + + // Add documents to vector store + vector_store.add_documents(embeddings).await?; + + // Create vector store index + let index = vector_store.index(embedding_model); + + // Create RAG agent + let rag_agent = openai + .agent("gpt-4") + .preamble( + " + You are a knowledgeable assistant. + Use the information provided to you to answer questions. + ", + ) + .dynamic_context(5, index) + .build(); + + // Prompt the agent and print the response + let response = rag_agent + .prompt("give me a summary of the document.") + .await?; + + println!("Agent Response:\n{}", response); + + Ok(()) +} diff --git a/rig-core/examples/rag_with_csv.rs b/rig-core/examples/rag_with_csv.rs new file mode 100644 index 00000000..82637eb8 --- /dev/null +++ b/rig-core/examples/rag_with_csv.rs @@ -0,0 +1,85 @@ +use rig::{ + completion::Prompt, + document_loaders::CsvLoader, + embeddings::EmbeddingsBuilder, + providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, + vector_store::{in_memory_store::InMemoryVectorStore, VectorStore}, +}; +use std::env; +use std::path::PathBuf; + +#[tokio::main] +async fn main() -> Result<(), anyhow::Error> { + // Print current working directory + println!("Current working directory: {:?}", env::current_dir()?); + + // Path to the CSV file + let csv_path = PathBuf::from("rig-core/examples/sample_data/top_rated_movies.csv"); + + // Print absolute path + println!( + "Attempting to access file at: {:?}", + csv_path.canonicalize()? + ); + + // Check if the file exists + if !csv_path.exists() { + eprintln!("Error: The file {} does not exist.", csv_path.display()); + return Ok(()); + } + + println!("File found successfully!"); + + // Initialize OpenAI client + let openai = Client::from_env(); + let embedding_model = openai.embedding_model(TEXT_EMBEDDING_ADA_002); + + // Create vector store + let mut vector_store = InMemoryVectorStore::default(); + + // Build embeddings + let embeddings = EmbeddingsBuilder::new(embedding_model.clone()) + .add_loader(CsvLoader::new(csv_path.to_str().unwrap())) + .build() + .await?; + + println!( + "Embeddings created successfully. Count: {}", + embeddings.len() + ); + for emb in &embeddings { + println!("Number of embeddings: {}", emb.embeddings.len()); + println!( + "First embedding vector length: {}", + emb.embeddings.first().map_or(0, |e| e.vec.len()) + ); + println!("--------------------"); + } + + // Add documents to vector store + vector_store.add_documents(embeddings).await?; + + // Create vector store index + let index = vector_store.index(embedding_model); + + // Create RAG agent + let rag_agent = openai + .agent("gpt-4") + .preamble( + " + You are a knowledgeable assistant. + Use the information provided to you to answer questions about the CSV data. + ", + ) + .dynamic_context(5, index) + .build(); + + // Prompt the agent and print the response + let response = rag_agent + .prompt("Give me a summary of the CSV data.") + .await?; + + println!("Agent Response:\n{}", response); + + Ok(()) +} diff --git a/rig-core/examples/sample_data/moores_law_for_everything.pdf b/rig-core/examples/sample_data/moores_law_for_everything.pdf new file mode 100644 index 00000000..d67afc51 Binary files /dev/null and b/rig-core/examples/sample_data/moores_law_for_everything.pdf differ diff --git a/rig-core/examples/sample_data/top_rated_movies.csv b/rig-core/examples/sample_data/top_rated_movies.csv new file mode 100644 index 00000000..d5c54344 --- /dev/null +++ b/rig-core/examples/sample_data/top_rated_movies.csv @@ -0,0 +1,110 @@ +popularity,release_date,title,vote_average +174.522,9/23/1994,The Shawshank Redemption,8.706 +165.677,3/14/1972,The Godfather,8.69 +174.522,9/23/1994,The Shawshank Redemption,8.706 +165.677,3/14/1972,The Godfather,8.69 +47.916,12/20/1997,Life Is Beautiful,8.449 +197.569,11/5/2014,Interstellar,8.44 +42.629,10/13/2023,TAYLOR SWIFT | THE ERAS TOUR,8.388 +21.39,11/19/2020,Gabriel's Inferno: Part III,8.4 +69.527,7/3/1985,Back to the Future,8.318 +49.507,6/2/1989,Dead Poets Society,8.312 +18.452,10/28/1998,The Legend of 1900,8.266 +14.67,8/22/2020,Given,8.3 +173.556,12/7/2022,Puss in Boots: The Last Wish,8.227 +29.933,10/26/2020,Wolfwalkers,8.22 +135.05,10/7/2016,Hacksaw Ridge,8.198 +46.126,12/19/1971,A Clockwork Orange,8.2 +20.423,1/28/2005,Innocent Voices,8.174 +27.808,11/3/1953,Tokyo Story,8.2 +83.731,3/18/2021,Zack Snyder's Justice League,8.148 +29.312,10/25/2019,Better Days,8.1 +89.436,7/3/1991,Terminator 2: Judgment Day,8.119 +14.482,7/6/1944,Double Indemnity,8.1 +55.332,9/19/2013,Prisoners,8.098 +14.629,3/31/1954,Sansho the Bailiff,8.098 +11.283,11/24/2021,Far from the Tree,8.074 +67.735,9/16/2005,Pride & Prejudice,8.075 +38.208,2/26/2014,The Grand Budapest Hotel,8.1 +42.511,12/3/2019,How to Train Your Dragon: Homecoming,8.048 +13.881,10/27/2022,Beyond the Universe,8.027 +26.21,10/18/2019,Jojo Rabbit,8.024 +12.038,11/27/2020,Black Beauty,8 +37.023,6/8/2009,Hachi: A Dog's Tale,8.008 +5.126,6/1/2017,In a Heartbeat,7.995 +33.577,12/23/2009,3 Idiots,7.995 +137.914,5/3/2023,Guardians of the Galaxy Vol. 3,8 +18.167,1/31/2009,Love Exposure,8 +38.425,8/6/1999,The Sixth Sense,7.957 +21.449,12/15/2004,Million Dollar Baby,7.957 +56.873,6/13/2007,No Country for Old Men,7.944 +35.17,10/18/2013,12 Years a Slave,7.942 +12.423,1/21/2022,My Father's Violin,7.926 +20.658,6/22/1954,On the Waterfront,7.9 +10.913,8/1/1997,Children of Heaven,7.914 +27.948,12/21/2016,Dangal,7.913 +30.657,1/15/2021,Wish Dragon,7.902 +28.602,11/4/2016,A Street Cat Named Bob,7.905 +11.461,6/10/2008,La Maison en Petits Cubes,7.893 +52.355,12/20/2017,The Greatest Showman,7.891 +19.146,9/20/2000,Yi Yi,7.875 +19.15,9/1/2000,Dancer in the Dark,7.875 +14.602,12/21/2011,My Way,7.858 +18.928,9/16/2004,Downfall,7.858 +13.954,3/15/1940,The Grapes of Wrath,7.8 +43.665,3/30/1990,Dances with Wolves,7.847 +16.392,5/1/1983,Nostalgia,7.838 +28.293,12/22/1960,Two Women,7.837 +42.994,12/3/2022,The First Slam Dunk,7.8 +92.789,6/21/2007,Ratatouille,7.824 +32.022,3/20/1972,Solaris,7.8 +22.13,6/16/2004,Before Sunset,7.818 +33.161,10/23/2009,Fantastic Mr. Fox,7.8 +141.3,7/9/2003,Pirates of the Caribbean: The Curse of the Black Pearl,7.804 +10.15,8/21/1988,A Short Film About Love,7.794 +45.622,9/20/2012,The Perks of Being a Wallflower,7.793 +16.517,8/31/2000,Nine Queens,7.784 +12.185,12/17/1993,The Wrong Trousers,7.784 +20.295,12/4/1990,Awakenings,7.768 +11.949,5/28/2009,Partly Cloudy,7.767 +18.816,11/18/1974,A Woman Under the Influence,7.8 +18.124,2/14/2008,The Chaser,7.758 +86.727,2/11/2016,Zootopia,7.749 +12.807,3/19/1980,The King and the Mockingbird,7.749 +20.424,9/28/2019,Marriage Story,7.738 +56.129,7/13/2022,The Killer,7.7 +16.337,11/20/2020,Sound of Metal,7.727 +14.088,9/20/1962,Vivre Sa Vie,7.727 +18.88,12/22/2004,Hotel Rwanda,7.7 +24.783,2/18/2017,Sword Art Online: The Movie – Ordinal Scale,7.718 +62.912,6/10/2005,Batman Begins,7.709 +28.113,12/9/1965,A Charlie Brown Christmas,7.707 +8.592,4/7/1966,For Love and Gold,7.698 +15.608,3/31/2011,The Turin Horse,7.7 +61.64,9/30/2015,The Martian,7.687 +13.588,1/8/2014,Boys,7.687 +12.97,6/1/1998,"Black Cat, White Cat",7.68 +12.521,9/25/1961,The Hustler,7.68 +14.101,6/27/1951,Strangers on a Train,7.671 +16.624,5/23/2019,The Traitor,7.672 +29.525,3/6/1996,Primal Fear,7.661 +26.983,3/31/2016,Hunt for the Wilderpeople,7.7 +13.112,9/2/1949,White Heat,7.6 +46.7,7/24/2020,The Kissing Booth 2,7.648 +20.064,9/28/2022,Entergalactic,7.641 +78.458,2/22/2024,Exhuma,7.64 +56.569,1/19/2017,A Dog's Purpose,7.632 +47.331,9/26/2008,Fireproof,7.632 +18.036,10/19/1970,Le Cercle Rouge,7.623 +65.228,2/24/2017,Get Out,7.623 +38.806,4/9/2015,The Longest Ride,7.614 +20.63,3/30/2005,Mysterious Skin,7.615 +24.423,3/24/1989,The Killer,7.6 +13.757,6/19/1969,The Wild Bunch,7.607 +25.542,10/24/2008,Changeling,7.6 +25.864,12/20/1991,JFK,7.6 +23.749,6/19/2020,Feel the Beat,7.59 +45.792,3/30/1999,10 Things I Hate About You,7.6 +24.649,8/24/2018,Searching,7.583 +79.872,12/15/2009,Avatar,7.583 +17.303,6/19/2014,What We Do in the Shadows,7.575 \ No newline at end of file diff --git a/rig-core/src/document_loaders/csv.rs b/rig-core/src/document_loaders/csv.rs new file mode 100644 index 00000000..ea91e7d8 --- /dev/null +++ b/rig-core/src/document_loaders/csv.rs @@ -0,0 +1,49 @@ +use async_trait::async_trait; +use csv::Reader; +use serde_json::json; +use std::error::Error as StdError; +use tokio::fs::File; +use tokio::io::AsyncReadExt; + +use super::DocumentLoader; +use crate::embeddings::DocumentEmbeddings; + +pub struct CsvLoader { + path: String, +} + +impl CsvLoader { + pub fn new(path: &str) -> Self { + Self { + path: path.to_string(), + } + } +} + +#[async_trait] +impl DocumentLoader for CsvLoader { + async fn load(&self) -> Result, Box> { + let mut file = File::open(&self.path).await?; + let mut contents = String::new(); + file.read_to_string(&mut contents).await?; + + let mut reader = Reader::from_reader(contents.as_bytes()); + let headers: Vec = reader.headers()?.iter().map(|h| h.to_string()).collect(); + + let mut csv_content = String::new(); + + for result in reader.records() { + let record = result?; + for (i, field) in record.iter().enumerate() { + csv_content.push_str(&format!("{}: {}\n", headers[i], field)); + } + csv_content.push('\n'); // Changed from push_str("\n") to push('\n') + } + + Ok(vec![DocumentEmbeddings { + id: self.path.clone(), + document: json!({"text": csv_content}), + embeddings: vec![], + }]) + } +} diff --git a/rig-core/src/document_loaders/mod.rs b/rig-core/src/document_loaders/mod.rs new file mode 100644 index 00000000..b25b15f5 --- /dev/null +++ b/rig-core/src/document_loaders/mod.rs @@ -0,0 +1,28 @@ +//! This module contains the implementation of document loaders for various file formats. +//! Currently, it includes loaders for CSV and PDF files. + +mod csv; +// mod directory; +// mod html; +// mod json; +// mod markdown; +// mod office; +mod pdf; + +use crate::embeddings::DocumentEmbeddings; +use async_trait::async_trait; +use std::error::Error as StdError; + +#[async_trait] +pub trait DocumentLoader { + /// Asynchronously loads the document and returns a vector of document embeddings. + async fn load(&self) -> Result, Box>; +} + +pub use csv::CsvLoader; +// pub use directory::DirectoryLoader; +// pub use html::HtmlLoader; +// pub use json::JsonLoader; +// pub use markdown::MarkdownLoader; +// pub use office::OfficeLoader; +pub use pdf::PdfLoader; diff --git a/rig-core/src/document_loaders/pdf.rs b/rig-core/src/document_loaders/pdf.rs new file mode 100644 index 00000000..6d073af1 --- /dev/null +++ b/rig-core/src/document_loaders/pdf.rs @@ -0,0 +1,49 @@ +// Import necessary dependencies +use super::DocumentLoader; +use crate::embeddings::DocumentEmbeddings; +use async_trait::async_trait; +use lopdf::Document; +use serde_json::json; + +// Define a struct for loading PDF documents +pub struct PdfLoader { + path: String, +} + +impl PdfLoader { + // Implement a constructor for the PdfLoader struct + pub fn new(path: &str) -> Self { + Self { + path: path.to_string(), + } + } +} + +#[async_trait] +impl DocumentLoader for PdfLoader { + // Implement the load function for the DocumentLoader trait + async fn load( + &self, + ) -> Result, Box> { + // Load the PDF document from the specified path + let doc = Document::load(&self.path)?; + + // Extract text from each page of the PDF document + let mut text = String::new(); + for page in doc.get_pages() { + if let Ok(content) = doc.extract_text(&[page.0]) { + text.push_str(&content); + } + } + + // Print the extracted text for debugging purposes + println!("Extracted text from PDF: {}", text); + + // Create a DocumentEmbeddings object with the extracted text + Ok(vec![DocumentEmbeddings { + id: self.path.clone(), + document: json!({"text": text}), + embeddings: vec![], // Empty vector, embeddings will be generated later + }]) + } +} diff --git a/rig-core/src/embeddings.rs b/rig-core/src/embeddings.rs index 2d40bbc5..a9b0f1b2 100644 --- a/rig-core/src/embeddings.rs +++ b/rig-core/src/embeddings.rs @@ -43,6 +43,7 @@ use std::{cmp::max, collections::HashMap}; use futures::{stream, StreamExt, TryStreamExt}; use serde::{Deserialize, Serialize}; +use crate::document_loaders::DocumentLoader; use crate::tool::{ToolEmbedding, ToolSet, ToolType}; #[derive(Debug, thiserror::Error)] @@ -66,6 +67,10 @@ pub enum EmbeddingError { /// Error returned by the embedding model provider #[error("ProviderError: {0}")] ProviderError(String), + + /// Error loading documents + #[error("LoaderError: {0}")] + LoaderError(String), } /// Trait for embedding models that can generate embeddings for documents. @@ -153,6 +158,7 @@ type Embeddings = Vec; pub struct EmbeddingsBuilder { model: M, documents: Vec<(String, serde_json::Value, Vec)>, + loaders: Vec>, // New field for document loaders } impl EmbeddingsBuilder { @@ -161,6 +167,7 @@ impl EmbeddingsBuilder { Self { model, documents: vec![], + loaders: vec![], // Initialize loaders } } @@ -275,11 +282,43 @@ impl EmbeddingsBuilder { self } + /// Add a new document loader + pub fn add_loader(mut self, loader: impl DocumentLoader + 'static) -> Self { + self.loaders.push(Box::new(loader)); // Add loader to the collection + self + } + /// Generate the embeddings for the given documents pub async fn build(self) -> Result { // Create a temporary store for the documents - let documents_map = self - .documents + let mut all_documents = self.documents; + + // Load documents from loaders and merge them with existing ones + for loader in self.loaders { + let loaded_docs = loader + .load() + .await + .map_err(|e| EmbeddingError::LoaderError(e.to_string()))?; + for doc in loaded_docs { + // Extract the text content from the document + let text = match &doc.document { + serde_json::Value::String(s) => s.clone(), + serde_json::Value::Object(obj) => obj + .get("text") + .and_then(|v| v.as_str()) + .unwrap_or_default() + .to_string(), + _ => { + return Err(EmbeddingError::DocumentError( + "Invalid document format".to_string(), + )) + } + }; + all_documents.push((doc.id, doc.document, vec![text])); + } + } + + let documents_map = all_documents .into_iter() .map(|(id, document, docs)| (id, (document, docs))) .collect::>(); @@ -289,7 +328,7 @@ impl EmbeddingsBuilder { .flat_map(|(id, (_, docs))| { stream::iter(docs.iter().map(|doc| (id.clone(), doc.clone()))) }) - // Chunk them into N (the emebdding API limit per request) + // Chunk them into N (the embedding API limit per request) .chunks(M::MAX_DOCUMENTS) // Generate the embeddings .map(|docs| async { diff --git a/rig-core/src/lib.rs b/rig-core/src/lib.rs index c480272f..3ef157ba 100644 --- a/rig-core/src/lib.rs +++ b/rig-core/src/lib.rs @@ -77,6 +77,7 @@ pub mod agent; pub mod cli_chatbot; pub mod completion; +pub mod document_loaders; pub mod embeddings; pub mod extractor; pub mod json_utils;