From 3605437439a92f5bfc00d691a14a7a90095608d0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mathieu=20B=C3=A9langer?= Date: Fri, 11 Oct 2024 19:54:35 +0200 Subject: [PATCH 01/18] feat(provider-gemini): add gemini API client --- rig-core/src/providers/gemini/client.rs | 117 ++++++++++++++++++++++++ 1 file changed, 117 insertions(+) create mode 100644 rig-core/src/providers/gemini/client.rs diff --git a/rig-core/src/providers/gemini/client.rs b/rig-core/src/providers/gemini/client.rs new file mode 100644 index 00000000..d8f28088 --- /dev/null +++ b/rig-core/src/providers/gemini/client.rs @@ -0,0 +1,117 @@ +use crate::embeddings::{self}; +use serde::Deserialize; + +use super::{completion::CompletionModel, embedding::EmbeddingModel}; + +// ================================================================ +// Google Gemini Client +// ================================================================ +const GEMINI_API_BASE_URL: &str = "https://generativelanguage.googleapis.com"; + +#[derive(Clone)] +pub struct Client { + base_url: String, + api_key: String, + http_client: reqwest::Client, +} + +impl Client { + pub fn new(api_key: &str) -> Self { + Self::from_url(api_key, GEMINI_API_BASE_URL) + } + fn from_url(api_key: &str, base_url: &str) -> Self { + Self { + base_url: base_url.to_string(), + api_key: api_key.to_string(), + http_client: reqwest::Client::builder() + .default_headers({ + let mut headers = reqwest::header::HeaderMap::new(); + headers.insert( + reqwest::header::CONTENT_TYPE, + "application/json".parse().unwrap(), + ); + headers + }) + .build() + .expect("Gemini reqwest client should build"), + } + } + pub fn post(&self, path: &str) -> reqwest::RequestBuilder { + let url = format!("{}/{}?key={}", self.base_url, path, self.api_key).replace("//", "/"); + + tracing::info!("POST {}", url); + self.http_client.post(url) + } + + /// Create an embedding model with the given name. + /// Note: default embedding dimension of 0 will be used if model is not known. 
+ /// If this is the case, it's better to use function `embedding_model_with_ndims` + /// + /// # Example + /// ``` + /// use rig::providers::gemini::{Client, self}; + /// + /// // Initialize the Google Gemini client + /// let gemini = Client::new("your-google-gemini-api-key"); + /// + /// let embedding_model = gemini.embedding_model(gemini::embedding::EMBEDDING_GECKO_001); + /// ``` + pub fn embedding_model(&self, model: &str) -> EmbeddingModel { + EmbeddingModel::new(self.clone(), model, None) + } + + /// Create an embedding model with the given name and the number of dimensions in the embedding generated by the model. + /// + /// # Example + /// ``` + /// use rig::providers::gemini::{Client, self}; + /// + /// // Initialize the Google Gemini client + /// let gemini = Client::new("your-google-gemini-api-key"); + /// + /// let embedding_model = gemini.embedding_model_with_ndims("model-unknown-to-rig", 1024); + /// ``` + pub fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> EmbeddingModel { + EmbeddingModel::new(self.clone(), model, Some(ndims)) + } + + /// Create an embedding builder with the given embedding model. 
+ /// + /// # Example + /// ``` + /// use rig::providers::gemini::{Client, self}; + /// + /// // Initialize the Google Gemini client + /// let gemini = Client::new("your-google-gemini-api-key"); + /// + /// let embeddings = gemini.embeddings(gemini::embedding::EMBEDDING_GECKO_001) + /// .simple_document("doc0", "Hello, world!") + /// .simple_document("doc1", "Goodbye, world!") + /// .build() + /// .await + /// .expect("Failed to embed documents"); + /// ``` + pub fn embeddings(&self, model: &str) -> embeddings::EmbeddingsBuilder { + embeddings::EmbeddingsBuilder::new(self.embedding_model(model)) + } + + pub fn completion_model(&self, model: &str) -> CompletionModel { + CompletionModel::new(self.clone(), model) + } + + // pub fn agent(&self, model: &str) -> AgentBuilder { + // AgentBuilder::new(self.completion_model(model)) + // } +} + +#[derive(Debug, Deserialize)] +pub struct ApiErrorResponse { + pub message: String, +} + +#[derive(Debug, Deserialize)] +#[serde(untagged)] +pub enum ApiResponse { + Ok(T), + Err(ApiErrorResponse), +} From 3aed9804107252cb49b94ab414d093d3cb045927 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mathieu=20B=C3=A9langer?= Date: Fri, 11 Oct 2024 19:55:02 +0200 Subject: [PATCH 02/18] feat(provider-gemini): add gemini support for basic completion --- rig-core/src/providers/gemini/completion.rs | 374 ++++++++++++++++++++ 1 file changed, 374 insertions(+) create mode 100644 rig-core/src/providers/gemini/completion.rs diff --git a/rig-core/src/providers/gemini/completion.rs b/rig-core/src/providers/gemini/completion.rs new file mode 100644 index 00000000..5fd0bb89 --- /dev/null +++ b/rig-core/src/providers/gemini/completion.rs @@ -0,0 +1,374 @@ +// ================================================================ +// Google Gemini Completion API +// ================================================================ +// +// https://ai.google.dev/api/generate-conten + +/// `gemini-1.5-flash` completion model +pub const GEMINI_1_5_FLASH: &str = 
"gemini-1.5-flash"; +/// `gemini-1.5-pro` completion model +pub const GEMINI_1_5_PRO: &str = "gemini-1.5-pro"; +/// `gemini-1.5-pro-8b` completion model +pub const GEMINI_1_5_PRO_8B: &str = "gemini-1.5-pro-8b"; +/// `gemini-1.0-pro` completion model +pub const GEMINI_1_0_PRO: &str = "gemini-1.0-pro"; + +use std::collections::HashMap; + +use serde::{Deserialize, Serialize}; + +use crate::{completion::{self, CompletionError, CompletionRequest}, providers::gemini::client::ApiResponse}; + +use super::Client; + + +// +// Gemini API Response Types +// + +// Define the struct for the GenerateContentResponse +#[derive(Debug, Deserialize)] +pub struct GenerateContentResponse { + pub candidates: Vec, + pub prompt_feedback: Option, + pub usage_metadata: Option, +} + +// Define the struct for a Candidate +#[derive(Debug, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ContentCandidate { + pub content: Content, + pub finish_reason: Option, + pub safety_ratings: Option>, + pub citation_metadata: Option, + pub token_count: Option, + pub avg_logprobs: Option, + pub logprobs_result: Option, + pub index: Option, +} + +#[derive(Debug, Deserialize, Serialize)] +pub struct Content { + pub parts: Vec, + pub role: Option, +} + +#[derive(Debug, Deserialize, Serialize)] +pub struct Part { + pub text: String, +} + +#[derive(Debug, Deserialize, Serialize)] +pub struct SafetyRating { + pub category: HarmCategory, + pub probability: HarmProbability, +} + +#[derive(Debug, Deserialize, Serialize)] +#[serde(rename_all = "SCREAMING_SNAKE_CASE")] +pub enum HarmProbability { + HarmProbabilityUnspecified, + Negligible, + Low, + Medium, + High, +} + +#[derive(Debug, Deserialize, Serialize)] +#[serde(rename_all = "SCREAMING_SNAKE_CASE")] +pub enum HarmCategory { + HarmCategoryUnspecified, + HarmCategoryDerogatory, + HarmCategoryToxicity, + HarmCategoryViolence, + HarmCategorySexually, + HarmCategoryMedical, + HarmCategoryDangerous, + HarmCategoryHarassment, + HarmCategoryHateSpeech, + 
HarmCategorySexuallyExplicit, + HarmCategoryDangerousContent, + HarmCategoryCivicIntegrity, +} + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct UsageMetadata { + pub prompt_token_count: i32, + pub cached_content_token_count: i32, + pub candidates_token_count: i32, + pub total_token_count: i32, +} + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct PromptFeedback { + pub block_reason: Option, + pub safety_ratings: Option>, +} + +#[derive(Debug, Deserialize)] +pub enum BlockReason { + BlockReasonUnspecified, + Safety, + Other, + Blocklist, + ProhibitedContent, +} + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "SCREAMING_SNAKE_CASE")] +pub enum FinishReason { + FinishReasonUnspecified, + Stop, + MaxTokens, + Safety, + Recitation, + Language, + Other, + Blocklist, + ProhibitedContent, +} + +#[derive(Debug, Deserialize)] +pub struct CitationMetadata { + pub citation_sources: Vec, +} + +#[derive(Debug, Deserialize)] +pub struct CitationSource { + pub uri: Option, + pub start_index: Option, + pub end_index: Option, + pub license: Option, +} + +#[derive(Debug, Deserialize)] +pub struct LogprobsResult { + pub top_candidate: Vec, + pub chosen_candidate: Vec, +} + +#[derive(Debug, Deserialize)] +pub struct TopCandidate { + pub candidates: Vec, +} + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct LogProbCandidate { + pub token: String, + pub token_id: String, + pub log_probability: f64, +} + +#[derive(Debug, Deserialize, Serialize)] +pub struct GenerationConfig { + pub stop_sequences: Option>, + pub response_mime_type: Option, + pub response_schema: Option, + pub candidate_count: Option, + pub max_output_tokens: Option, + pub temperature: Option, + pub top_p: Option, + pub top_k: Option, + pub presence_penalty: Option, + pub frequency_penalty: Option, + pub response_logprobs: Option, + pub logprobs: Option, +} + +/// The Schema object allows the definition of input and output 
data types. These types can be objects, but also primitives and arrays. Represents a select subset of an OpenAPI 3.0 schema object. +/// https://ai.google.dev/api/caching#Schema +#[derive(Debug, Deserialize, Serialize)] +pub struct Schema { + pub r#type: String, + pub format: Option, + pub description: Option, + pub nullable: Option, + pub r#enum: Option>, + pub max_items: Option, + pub min_items: Option, + pub properties: Option>, + pub required: Option>, + pub items: Option>, +} + +#[derive(Debug, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct GenerateContentRequest { + pub contents: Vec, + pub tools: Option>, + pub tool_config: Option, + pub generation_config: Option, + pub safety_settings: Option>, + pub system_instruction: Option, + // cachedContent: Optional +} + +#[derive(Debug, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct Tool { + pub function_declaration: FunctionDeclaration, + pub code_execution: Option, +} + +#[derive(Debug, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct FunctionDeclaration { + pub name: String, + pub description: String, + pub parameters: Option>, +} + +#[derive(Debug, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct ToolConfig { + pub schema: Option, +} + +#[derive(Debug, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct CodeExecution {} + +#[derive(Debug, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct SafetySetting { + pub category: HarmCategory, + pub threshold: HarmBlockThreshold, +} + +#[derive(Debug, Serialize)] +#[serde(rename_all = "SCREAMING_SNAKE_CASE")] +pub enum HarmBlockThreshold { + HarmBlockThresholdUnspecified, + BlockLowAndAbove, + BlockMediumAndAbove, + BlockOnlyHigh, + BlockNone, + Off +} + +// ================================================================= +// Rig Implementation Types +// ================================================================= + +impl From for Tool { + fn from(tool: completion::ToolDefinition) -> Self { + Self { + 
function_declaration: FunctionDeclaration { + name: tool.name, + description: tool.description, + parameters: None // tool.parameters, TODO: Map Gemini + }, + code_execution: None, + } + } +} + +#[derive(Clone)] +pub struct CompletionModel { + client: Client, + pub model: String, +} + +impl CompletionModel { + pub fn new(client: Client, model: &str) -> Self { + Self { + client, + model: model.to_string(), + } + } +} + + +impl TryFrom for completion::CompletionResponse { + type Error = CompletionError; + + fn try_from(response: GenerateContentResponse) -> Result { + match response.candidates.as_slice() { + [ContentCandidate { content, .. }, ..] => { + Ok(completion::CompletionResponse { + choice: completion::ModelChoice::Message(content.parts.first().unwrap().text.clone()), + raw_response: response, + }) + } + _ => Err(CompletionError::ResponseError("No candidates found in response".into())), + } + } +} + + +impl completion::CompletionModel for CompletionModel { + type Response = GenerateContentResponse; + + async fn completion( + &self, + mut completion_request: CompletionRequest, + ) -> Result, CompletionError> { + + // QUESTION: Why do Anthropic/openAi implementation differ here? OpenAI adds the preamble but Anthropic does not. 
+ + let mut full_history = if let Some(preamble) = &completion_request.preamble { + vec![completion::Message { + role: "system".into(), + content: preamble.clone(), + }] + } else { vec![] }; + + full_history.append(&mut completion_request.chat_history); + + let prompt_with_context = completion_request.prompt_with_context(); + + full_history.push(completion::Message { + role: "user".into(), + content: prompt_with_context, + }); + + let request = GenerateContentRequest { + contents: full_history.into_iter().map(|msg| Content { + parts: vec![Part { text: msg.content }], + role: match msg.role.as_str() { + "system" => Some("model".to_string()), + "user" => Some("user".to_string()), + "assistant" => Some("model".to_string()), + _ => None, + }, + }).collect(), + // QUESTION: How to handle API config specifics? + generation_config: Some(GenerationConfig { + temperature: completion_request.temperature, + max_output_tokens: completion_request.max_tokens, + top_p: None, + top_k: None, + candidate_count: None, + stop_sequences: None, + response_mime_type: None, + response_schema: None, + presence_penalty: None, + frequency_penalty: None, + response_logprobs: None, + logprobs: None, + }), + safety_settings: None, + tools: Some(completion_request.tools.into_iter().map(Tool::from).collect()), + tool_config: None, + system_instruction: None, + }; + + let response = self.client.post(&format!("/v1beta/models/{}:generateContent", self.model)) + .json(&request) + .send() + .await? + .error_for_status()? 
+ .json::>() + .await?; + + + match response { + ApiResponse::Ok(response) => Ok(response.try_into()?), + ApiResponse::Err(err) => Err(CompletionError::ResponseError(err.message)), + } + } +} \ No newline at end of file From 45132c23a4624c2cca932ad0e48c0a3159d33ab9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mathieu=20B=C3=A9langer?= Date: Fri, 11 Oct 2024 19:55:38 +0200 Subject: [PATCH 03/18] feat(provider-gemini): add gemini embedding support --- rig-core/src/providers/gemini/embedding.rs | 95 ++++++++++++++++++++++ rig-core/src/providers/gemini/mod.rs | 16 ++++ 2 files changed, 111 insertions(+) create mode 100644 rig-core/src/providers/gemini/embedding.rs create mode 100644 rig-core/src/providers/gemini/mod.rs diff --git a/rig-core/src/providers/gemini/embedding.rs b/rig-core/src/providers/gemini/embedding.rs new file mode 100644 index 00000000..4b4ea2dc --- /dev/null +++ b/rig-core/src/providers/gemini/embedding.rs @@ -0,0 +1,95 @@ + +// ================================================================ +// Google Gemini Embeddings +// ================================================================ + +use serde::Deserialize; +use serde_json::json; + +use crate::embeddings::{self, EmbeddingError}; + +use super::{client::ApiResponse, Client}; + +/// `embedding-gecko-001` embedding model +pub const EMBEDDING_GECKO_001: &str = "embedding-gecko-001"; +/// `embedding-001` embedding model +pub const EMBEDDING_001: &str = "embedding-001"; +/// `text-embedding-004` embedding model +pub const EMBEDDING_004: &str = "text-embedding-004"; + +#[derive(Debug, Deserialize)] +pub struct EmbeddingResponse { + pub embedding: EmbeddingValues, +} + +#[derive(Debug, Deserialize)] +pub struct EmbeddingValues { + pub values: Vec, +} + +#[derive(Clone)] +pub struct EmbeddingModel { + client: Client, + model: String, + ndims: Option, +} + +impl EmbeddingModel { + pub fn new(client: Client, model: &str, ndims: Option) -> Self { + Self { client, model: model.to_string(), ndims } + } + + +} + 
+impl embeddings::EmbeddingModel for EmbeddingModel { + const MAX_DOCUMENTS: usize = 1024; + + fn ndims(&self) -> usize { + match self.model.as_str() { + EMBEDDING_GECKO_001 | EMBEDDING_001 => 768, + EMBEDDING_004 => 1024, + _ => 0, // Default to 0 for unknown models + } + } + + + async fn embed_documents( + &self, + documents: Vec, + ) -> Result, EmbeddingError> { + let mut request_body = json!({ + "model": format!("models/{}", self.model), + "content": { + "parts": documents.iter().map(|doc| json!({ "text": doc })).collect::>(), + }, + }); + + if let Some(ndims) = self.ndims { + request_body["output_dimensionality"] = json!(ndims); + } + + let response = self + .client + .post(&format!("/v1beta/models/{}:embedContent", self.model)) + .json(&request_body) + .send() + .await? + .error_for_status()? + .json::>() + .await?; + + match response { + ApiResponse::Ok(response) => { + let chunk_size = self.ndims.unwrap_or_else(|| self.ndims()); + Ok(documents.into_iter().zip(response.embedding.values.chunks(chunk_size)).map(|(document, embedding)| { + embeddings::Embedding { + document, + vec: embedding.to_vec(), + } + }).collect()) + } + ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)), + } + } +} diff --git a/rig-core/src/providers/gemini/mod.rs b/rig-core/src/providers/gemini/mod.rs new file mode 100644 index 00000000..994fca36 --- /dev/null +++ b/rig-core/src/providers/gemini/mod.rs @@ -0,0 +1,16 @@ +//! Google API client and Rig integration +//! +//! # Example +//! ``` +//! use rig::providers::google; +//! +//! let client = google::Client::new("YOUR_API_KEY"); +//! +//! let gemini_embedding_model = client.embedding_model(google::EMBEDDING_001); +//! 
``` + +pub mod client; +pub mod completion; +pub mod embedding; + +pub use client::Client; From 8b69e9042f6827f4c86f523908fffcf9aff9d8da Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mathieu=20B=C3=A9langer?= Date: Mon, 14 Oct 2024 15:06:13 +0200 Subject: [PATCH 04/18] feat(provider-gemini): add agent support in client --- rig-core/src/providers/gemini/client.rs | 41 ++++++++++++++++++++++--- 1 file changed, 36 insertions(+), 5 deletions(-) diff --git a/rig-core/src/providers/gemini/client.rs b/rig-core/src/providers/gemini/client.rs index d8f28088..41f1595b 100644 --- a/rig-core/src/providers/gemini/client.rs +++ b/rig-core/src/providers/gemini/client.rs @@ -1,5 +1,6 @@ -use crate::embeddings::{self}; -use serde::Deserialize; +use crate::{agent::AgentBuilder, embeddings::{self}, extractor::ExtractorBuilder}; +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; use super::{completion::CompletionModel, embedding::EmbeddingModel}; @@ -36,6 +37,14 @@ impl Client { .expect("Gemini reqwest client should build"), } } + + /// Create a new Google Gemini client from the `GEMINI_API_KEY` environment variable. + /// Panics if the environment variable is not set. + pub fn from_env() -> Self { + let api_key = std::env::var("GEMINI_API_KEY").expect("GEMINI_API_KEY not set"); + Self::new(&api_key) + } + pub fn post(&self, path: &str) -> reqwest::RequestBuilder { let url = format!("{}/{}?key={}", self.base_url, path, self.api_key).replace("//", "/"); @@ -99,9 +108,31 @@ impl Client { CompletionModel::new(self.clone(), model) } - // pub fn agent(&self, model: &str) -> AgentBuilder { - // AgentBuilder::new(self.completion_model(model)) - // } + /// Create an agent builder with the given completion model. 
+ /// + /// # Example + /// ``` + /// use rig::providers::gemini::{Client, self}; + /// + /// // Initialize the Google Gemini client + /// let gemini = Client::new("your-google-gemini-api-key"); + /// + /// let agent = gemini.agent(gemini::completion::GEMINI_1_5_PRO) + /// .preamble("You are comedian AI with a mission to make people laugh.") + /// .temperature(0.0) + /// .build(); + /// ``` + pub fn agent(&self, model: &str) -> AgentBuilder { + AgentBuilder::new(self.completion_model(model)) + } + + /// Create an extractor builder with the given completion model. + pub fn extractor Deserialize<'a> + Serialize + Send + Sync>( + &self, + model: &str, + ) -> ExtractorBuilder { + ExtractorBuilder::new(self.completion_model(model)) + } } #[derive(Debug, Deserialize)] From 700a196973d71f4e38d94e4eeddf955db7402154 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mathieu=20B=C3=A9langer?= Date: Mon, 14 Oct 2024 15:07:30 +0200 Subject: [PATCH 05/18] docs(provider-gemini): Update readme entries, add gemini agent example --- README.md | 4 ++-- rig-core/README.md | 4 ++-- rig-core/examples/gemini_agent.rs | 25 +++++++++++++++++++++++++ rig-core/src/providers/mod.rs | 4 +++- 4 files changed, 32 insertions(+), 5 deletions(-) create mode 100644 rig-core/examples/gemini_agent.rs diff --git a/README.md b/README.md index fc00ad07..3b014446 100644 --- a/README.md +++ b/README.md @@ -34,7 +34,7 @@ We'd love your feedback. Please take a moment to let us know what you think usin ## High-level features - Full support for LLM completion and embedding workflows -- Simple but powerful common abstractions over LLM providers (e.g. OpenAI, Cohere) and vector stores (e.g. MongoDB, in-memory) +- Simple but powerful common abstractions over LLM providers (e.g. OpenAI, Cohere, Google Gemini) and vector stores (e.g. 
MongoDB, in-memory) - Integrate LLMs in your app with minimal boilerplate ## Installation @@ -70,6 +70,6 @@ or just `full` to enable all features (`cargo add tokio --features macros,rt-mul Rig supports the following LLM providers natively: - OpenAI - Cohere - +- Google Gemini Additionally, Rig currently has the following integration sub-libraries: - MongoDB vector store: `rig-mongodb` diff --git a/rig-core/README.md b/rig-core/README.md index 1f0d9d69..4bf40c15 100644 --- a/rig-core/README.md +++ b/rig-core/README.md @@ -11,7 +11,7 @@ More information about this crate can be found in the [crate documentation](http ## High-level features - Full support for LLM completion and embedding workflows -- Simple but powerful common abstractions over LLM providers (e.g. OpenAI, Cohere) and vector stores (e.g. MongoDB, in-memory) +- Simple but powerful common abstractions over LLM providers (e.g. OpenAI, Cohere, Google Gemini) and vector stores (e.g. MongoDB, in-memory) - Integrate LLMs in your app with minimal boilerplate ## Installation @@ -47,6 +47,6 @@ or just `full` to enable all features (`cargo add tokio --features macros,rt-mul Rig supports the following LLM providers natively: - OpenAI - Cohere - +- Google Gemini Additionally, Rig currently has the following integration sub-libraries: - MongoDB vector store: `rig-mongodb` diff --git a/rig-core/examples/gemini_agent.rs b/rig-core/examples/gemini_agent.rs new file mode 100644 index 00000000..f8e5af83 --- /dev/null +++ b/rig-core/examples/gemini_agent.rs @@ -0,0 +1,25 @@ +use rig::{completion::Prompt, providers::gemini}; + +#[tokio::main] +async fn main() -> Result<(), anyhow::Error> { + // Initialize the Google Gemini client + // Create OpenAI client + let client = gemini::Client::from_env(); + + // Create agent with a single context prompt + let agent = client + .agent(gemini::completion::GEMINI_1_5_PRO) + .preamble("Be precise and concise.") + .temperature(0.5) + .max_tokens(8192) + .build(); + + // Prompt the agent 
and print the response + let response = agent + .prompt("When and where and what type is the next solar eclipse?") + .await?; + println!("{}", response); + + Ok(()) + +} \ No newline at end of file diff --git a/rig-core/src/providers/mod.rs b/rig-core/src/providers/mod.rs index 84bee958..da7429a8 100644 --- a/rig-core/src/providers/mod.rs +++ b/rig-core/src/providers/mod.rs @@ -5,7 +5,8 @@ //! - OpenAI //! - Perplexity //! - Anthropic -//! +//! - Google Gemini +//! //! Each provider has its own module, which contains a `Client` implementation that can //! be used to initialize completion and embedding models and execute requests to those models. //! @@ -43,3 +44,4 @@ pub mod anthropic; pub mod cohere; pub mod openai; pub mod perplexity; +pub mod gemini; \ No newline at end of file From cc74aecd88fca9028c84b602e74cf7712ca3afd1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mathieu=20B=C3=A9langer?= Date: Mon, 14 Oct 2024 15:34:46 +0200 Subject: [PATCH 06/18] style(provider-gemini): test pre-commits --- rig-core/examples/gemini_agent.rs | 3 +- rig-core/src/providers/gemini/client.rs | 6 +- rig-core/src/providers/gemini/completion.rs | 71 ++++++++++++--------- rig-core/src/providers/gemini/embedding.rs | 20 +++--- rig-core/src/providers/gemini/mod.rs | 2 +- rig-core/src/providers/mod.rs | 4 +- 6 files changed, 62 insertions(+), 44 deletions(-) diff --git a/rig-core/examples/gemini_agent.rs b/rig-core/examples/gemini_agent.rs index f8e5af83..071fed23 100644 --- a/rig-core/examples/gemini_agent.rs +++ b/rig-core/examples/gemini_agent.rs @@ -21,5 +21,4 @@ async fn main() -> Result<(), anyhow::Error> { println!("{}", response); Ok(()) - -} \ No newline at end of file +} diff --git a/rig-core/src/providers/gemini/client.rs b/rig-core/src/providers/gemini/client.rs index 41f1595b..dd762f69 100644 --- a/rig-core/src/providers/gemini/client.rs +++ b/rig-core/src/providers/gemini/client.rs @@ -1,4 +1,8 @@ -use crate::{agent::AgentBuilder, embeddings::{self}, 
extractor::ExtractorBuilder}; +use crate::{ + agent::AgentBuilder, + embeddings::{self}, + extractor::ExtractorBuilder, +}; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; diff --git a/rig-core/src/providers/gemini/completion.rs b/rig-core/src/providers/gemini/completion.rs index 5fd0bb89..ae716d8d 100644 --- a/rig-core/src/providers/gemini/completion.rs +++ b/rig-core/src/providers/gemini/completion.rs @@ -17,11 +17,13 @@ use std::collections::HashMap; use serde::{Deserialize, Serialize}; -use crate::{completion::{self, CompletionError, CompletionRequest}, providers::gemini::client::ApiResponse}; +use crate::{ + completion::{self, CompletionError, CompletionRequest}, + providers::gemini::client::ApiResponse, +}; use super::Client; - // // Gemini API Response Types // @@ -247,7 +249,7 @@ pub enum HarmBlockThreshold { BlockMediumAndAbove, BlockOnlyHigh, BlockNone, - Off + Off, } // ================================================================= @@ -260,7 +262,7 @@ impl From for Tool { function_declaration: FunctionDeclaration { name: tool.name, description: tool.description, - parameters: None // tool.parameters, TODO: Map Gemini + parameters: None, // tool.parameters, TODO: Map Gemini }, code_execution: None, } @@ -282,24 +284,24 @@ impl CompletionModel { } } - impl TryFrom for completion::CompletionResponse { type Error = CompletionError; fn try_from(response: GenerateContentResponse) -> Result { match response.candidates.as_slice() { - [ContentCandidate { content, .. }, ..] => { - Ok(completion::CompletionResponse { - choice: completion::ModelChoice::Message(content.parts.first().unwrap().text.clone()), - raw_response: response, - }) - } - _ => Err(CompletionError::ResponseError("No candidates found in response".into())), + [ContentCandidate { content, .. }, ..] 
=> Ok(completion::CompletionResponse { + choice: completion::ModelChoice::Message( + content.parts.first().unwrap().text.clone(), + ), + raw_response: response, + }), + _ => Err(CompletionError::ResponseError( + "No candidates found in response".into(), + )), } } } - impl completion::CompletionModel for CompletionModel { type Response = GenerateContentResponse; @@ -307,15 +309,16 @@ impl completion::CompletionModel for CompletionModel { &self, mut completion_request: CompletionRequest, ) -> Result, CompletionError> { - // QUESTION: Why do Anthropic/openAi implementation differ here? OpenAI adds the preamble but Anthropic does not. - + let mut full_history = if let Some(preamble) = &completion_request.preamble { vec![completion::Message { role: "system".into(), content: preamble.clone(), }] - } else { vec![] }; + } else { + vec![] + }; full_history.append(&mut completion_request.chat_history); @@ -327,15 +330,18 @@ impl completion::CompletionModel for CompletionModel { }); let request = GenerateContentRequest { - contents: full_history.into_iter().map(|msg| Content { - parts: vec![Part { text: msg.content }], - role: match msg.role.as_str() { - "system" => Some("model".to_string()), - "user" => Some("user".to_string()), - "assistant" => Some("model".to_string()), - _ => None, - }, - }).collect(), + contents: full_history + .into_iter() + .map(|msg| Content { + parts: vec![Part { text: msg.content }], + role: match msg.role.as_str() { + "system" => Some("model".to_string()), + "user" => Some("user".to_string()), + "assistant" => Some("model".to_string()), + _ => None, + }, + }) + .collect(), // QUESTION: How to handle API config specifics? 
generation_config: Some(GenerationConfig { temperature: completion_request.temperature, @@ -352,12 +358,20 @@ impl completion::CompletionModel for CompletionModel { logprobs: None, }), safety_settings: None, - tools: Some(completion_request.tools.into_iter().map(Tool::from).collect()), + tools: Some( + completion_request + .tools + .into_iter() + .map(Tool::from) + .collect(), + ), tool_config: None, system_instruction: None, }; - let response = self.client.post(&format!("/v1beta/models/{}:generateContent", self.model)) + let response = self + .client + .post(&format!("/v1beta/models/{}:generateContent", self.model)) .json(&request) .send() .await? @@ -365,10 +379,9 @@ impl completion::CompletionModel for CompletionModel { .json::>() .await?; - match response { ApiResponse::Ok(response) => Ok(response.try_into()?), ApiResponse::Err(err) => Err(CompletionError::ResponseError(err.message)), } } -} \ No newline at end of file +} diff --git a/rig-core/src/providers/gemini/embedding.rs b/rig-core/src/providers/gemini/embedding.rs index 4b4ea2dc..1c08ae52 100644 --- a/rig-core/src/providers/gemini/embedding.rs +++ b/rig-core/src/providers/gemini/embedding.rs @@ -1,4 +1,3 @@ - // ================================================================ // Google Gemini Embeddings // ================================================================ @@ -36,10 +35,12 @@ pub struct EmbeddingModel { impl EmbeddingModel { pub fn new(client: Client, model: &str, ndims: Option) -> Self { - Self { client, model: model.to_string(), ndims } + Self { + client, + model: model.to_string(), + ndims, + } } - - } impl embeddings::EmbeddingModel for EmbeddingModel { @@ -53,7 +54,6 @@ impl embeddings::EmbeddingModel for EmbeddingModel { } } - async fn embed_documents( &self, documents: Vec, @@ -82,12 +82,14 @@ impl embeddings::EmbeddingModel for EmbeddingModel { match response { ApiResponse::Ok(response) => { let chunk_size = self.ndims.unwrap_or_else(|| self.ndims()); - 
Ok(documents.into_iter().zip(response.embedding.values.chunks(chunk_size)).map(|(document, embedding)| { - embeddings::Embedding { + Ok(documents + .into_iter() + .zip(response.embedding.values.chunks(chunk_size)) + .map(|(document, embedding)| embeddings::Embedding { document, vec: embedding.to_vec(), - } - }).collect()) + }) + .collect()) } ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)), } diff --git a/rig-core/src/providers/gemini/mod.rs b/rig-core/src/providers/gemini/mod.rs index 994fca36..c3be2294 100644 --- a/rig-core/src/providers/gemini/mod.rs +++ b/rig-core/src/providers/gemini/mod.rs @@ -5,7 +5,7 @@ //! use rig::providers::google; //! //! let client = google::Client::new("YOUR_API_KEY"); -//! +//! //! let gemini_embedding_model = client.embedding_model(google::EMBEDDING_001); //! ``` diff --git a/rig-core/src/providers/mod.rs b/rig-core/src/providers/mod.rs index da7429a8..9d774459 100644 --- a/rig-core/src/providers/mod.rs +++ b/rig-core/src/providers/mod.rs @@ -6,7 +6,7 @@ //! - Perplexity //! - Anthropic //! - Google Gemini -//! +//! //! Each provider has its own module, which contains a `Client` implementation that can //! be used to initialize completion and embedding models and execute requests to those models. //! @@ -42,6 +42,6 @@ //! be used with the Cohere provider client. 
pub mod anthropic; pub mod cohere; +pub mod gemini; pub mod openai; pub mod perplexity; -pub mod gemini; \ No newline at end of file From d0e0ade139552648232c962377686f64c7486ac8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mathieu=20B=C3=A9langer?= Date: Mon, 14 Oct 2024 17:09:00 +0200 Subject: [PATCH 07/18] feat(provider-gemini): add support for gemini specific completion parameters --- rig-core/examples/gemini_agent.rs | 16 +- rig-core/src/providers/gemini/client.rs | 6 +- rig-core/src/providers/gemini/completion.rs | 322 ++++++++++++++++---- 3 files changed, 290 insertions(+), 54 deletions(-) diff --git a/rig-core/examples/gemini_agent.rs b/rig-core/examples/gemini_agent.rs index 071fed23..f221886a 100644 --- a/rig-core/examples/gemini_agent.rs +++ b/rig-core/examples/gemini_agent.rs @@ -1,4 +1,7 @@ -use rig::{completion::Prompt, providers::gemini}; +use rig::{ + completion::Prompt, + providers::gemini::{self, completion::GenerationConfig}, +}; #[tokio::main] async fn main() -> Result<(), anyhow::Error> { @@ -12,11 +15,20 @@ async fn main() -> Result<(), anyhow::Error> { .preamble("Be precise and concise.") .temperature(0.5) .max_tokens(8192) + .additional_params( + serde_json::to_value(GenerationConfig { + top_k: Some(1), + top_p: Some(0.95), + candidate_count: Some(1), + ..Default::default() + }) + .unwrap(), + ) // Unwrap the Result to get the Value .build(); // Prompt the agent and print the response let response = agent - .prompt("When and where and what type is the next solar eclipse?") + .prompt("How much wood would a woodchuck chuck if a woodchuck could chuck wood?") .await?; println!("{}", response); diff --git a/rig-core/src/providers/gemini/client.rs b/rig-core/src/providers/gemini/client.rs index dd762f69..c5d58cfa 100644 --- a/rig-core/src/providers/gemini/client.rs +++ b/rig-core/src/providers/gemini/client.rs @@ -108,12 +108,16 @@ impl Client { embeddings::EmbeddingsBuilder::new(self.embedding_model(model)) } + /// Create a completion model with the 
given name. + /// Gemini-specific parameters can be set using the [GenerationConfig](crate::providers::gemini::completion::GenerationConfig) struct. + /// https://ai.google.dev/api/generate-content#generationconfig pub fn completion_model(&self, model: &str) -> CompletionModel { CompletionModel::new(self.clone(), model) } /// Create an agent builder with the given completion model. - /// + /// Gemini-specific parameters can be set using the [GenerationConfig](crate::providers::gemini::completion::GenerationConfig) struct. + /// https://ai.google.dev/api/generate-content#generationconfig /// # Example /// ``` /// use rig::providers::gemini::{Client, self}; diff --git a/rig-core/src/providers/gemini/completion.rs b/rig-core/src/providers/gemini/completion.rs index ae716d8d..8dae9a97 100644 --- a/rig-core/src/providers/gemini/completion.rs +++ b/rig-core/src/providers/gemini/completion.rs @@ -13,9 +13,10 @@ pub const GEMINI_1_5_PRO_8B: &str = "gemini-1.5-pro-8b"; /// `gemini-1.0-pro` completion model pub const GEMINI_1_0_PRO: &str = "gemini-1.0-pro"; -use std::collections::HashMap; - use serde::{Deserialize, Serialize}; +use serde_json::Value; +use std::collections::HashMap; +use std::convert::TryFrom; use crate::{ completion::{self, CompletionError, CompletionRequest}, @@ -24,9 +25,9 @@ use crate::{ use super::Client; -// -// Gemini API Response Types -// +// ================================================================= +// Gemini API Response Structures +// ================================================================= // Define the struct for the GenerateContentResponse #[derive(Debug, Deserialize)] @@ -165,22 +166,54 @@ pub struct LogProbCandidate { pub log_probability: f64, } +/// Gemini API Configuration options for model generation and outputs. Not all parameters are configurable for every model. 
+/// https://ai.google.dev/api/generate-content#generationconfig #[derive(Debug, Deserialize, Serialize)] pub struct GenerationConfig { + /// The set of character sequences (up to 5) that will stop output generation. If specified, the API will stop at the first appearance of a stop_sequence. The stop sequence will not be included as part of the response. pub stop_sequences: Option>, + /// MIME type of the generated candidate text. Supported MIME types are: text/plain: (default) Text output. application/json: JSON response in the response candidates. text/x.enum: ENUM as a string response in the response candidates. Refer to the docs for a list of all supported text MIME types pub response_mime_type: Option, + /// Output schema of the generated candidate text. Schemas must be a subset of the OpenAPI schema and can be objects, primitives or arrays. If set, a compatible responseMimeType must also be set. Compatible MIME types: application/json: Schema for JSON response. Refer to the JSON text generation guide for more details. pub response_schema: Option, + /// Number of generated responses to return. Currently, this value can only be set to 1. If unset, this will default to 1. pub candidate_count: Option, + /// The maximum number of tokens to include in a response candidate. Note: The default value varies by model, see the Model.output_token_limit attribute of the Model returned from the getModel function. pub max_output_tokens: Option, + /// Controls the randomness of the output. Note: The default value varies by model, see the Model.temperature attribute of the Model returned from the getModel function. Values can range from [0.0, 2.0]. pub temperature: Option, + /// The maximum cumulative probability of tokens to consider when sampling. The model uses combined Top-k and Top-p (nucleus) sampling. Tokens are sorted based on their assigned probabilities so that only the most likely tokens are considered. 
Top-k sampling directly limits the maximum number of tokens to consider, while Nucleus sampling limits the number of tokens based on the cumulative probability. Note: The default value varies by Model and is specified by theModel.top_p attribute returned from the getModel function. An empty topK attribute indicates that the model doesn't apply top-k sampling and doesn't allow setting topK on requests. pub top_p: Option, + /// The maximum number of tokens to consider when sampling. Gemini models use Top-p (nucleus) sampling or a combination of Top-k and nucleus sampling. Top-k sampling considers the set of topK most probable tokens. Models running with nucleus sampling don't allow topK setting. Note: The default value varies by Model and is specified by theModel.top_p attribute returned from the getModel function. An empty topK attribute indicates that the model doesn't apply top-k sampling and doesn't allow setting topK on requests. pub top_k: Option, + /// Presence penalty applied to the next token's logprobs if the token has already been seen in the response. This penalty is binary on/off and not dependant on the number of times the token is used (after the first). Use frequencyPenalty for a penalty that increases with each use. A positive penalty will discourage the use of tokens that have already been used in the response, increasing the vocabulary. A negative penalty will encourage the use of tokens that have already been used in the response, decreasing the vocabulary. pub presence_penalty: Option, + /// Frequency penalty applied to the next token's logprobs, multiplied by the number of times each token has been seen in the respponse so far. A positive penalty will discourage the use of tokens that have already been used, proportional to the number of times the token has been used: The more a token is used, the more dificult it is for the model to use that token again increasing the vocabulary of responses. 
Caution: A negative penalty will encourage the model to reuse tokens proportional to the number of times the token has been used. Small negative values will reduce the vocabulary of a response. Larger negative values will cause the model to start repeating a common token until it hits the maxOutputTokens limit: "...the the the the the...". pub frequency_penalty: Option, + /// If true, export the logprobs results in response. pub response_logprobs: Option, + /// Only valid if responseLogprobs=True. This sets the number of top logprobs to return at each decoding step in the Candidate.logprobs_result. pub logprobs: Option, } +impl Default for GenerationConfig { + fn default() -> Self { + Self { + temperature: Some(1.0), + max_output_tokens: Some(4096), + stop_sequences: None, + response_mime_type: None, + response_schema: None, + candidate_count: None, + top_p: None, + top_k: None, + presence_penalty: None, + frequency_penalty: None, + response_logprobs: None, + logprobs: None, + } + } +} /// The Schema object allows the definition of input and output data types. These types can be objects, but also primitives and arrays. Represents a select subset of an OpenAPI 3.0 schema object. 
/// https://ai.google.dev/api/caching#Schema #[derive(Debug, Deserialize, Serialize)] @@ -197,6 +230,63 @@ pub struct Schema { pub items: Option>, } +impl TryFrom for Schema { + type Error = CompletionError; + + fn try_from(value: Value) -> Result { + if let Some(obj) = value.as_object() { + Ok(Schema { + r#type: obj + .get("type") + .and_then(|v| v.as_str()) + .unwrap_or_default() + .to_string(), + format: obj.get("format").and_then(|v| v.as_str()).map(String::from), + description: obj + .get("description") + .and_then(|v| v.as_str()) + .map(String::from), + nullable: obj.get("nullable").and_then(|v| v.as_bool()), + r#enum: obj.get("enum").and_then(|v| v.as_array()).map(|arr| { + arr.iter() + .filter_map(|v| v.as_str().map(String::from)) + .collect() + }), + max_items: obj + .get("maxItems") + .and_then(|v| v.as_i64()) + .map(|v| v as i32), + min_items: obj + .get("minItems") + .and_then(|v| v.as_i64()) + .map(|v| v as i32), + properties: obj + .get("properties") + .and_then(|v| v.as_object()) + .map(|map| { + map.iter() + .filter_map(|(k, v)| { + v.clone().try_into().ok().map(|schema| (k.clone(), schema)) + }) + .collect() + }), + required: obj.get("required").and_then(|v| v.as_array()).map(|arr| { + arr.iter() + .filter_map(|v| v.as_str().map(String::from)) + .collect() + }), + items: obj + .get("items") + .map(|v| Box::new(v.clone().try_into().unwrap())), + }) + } else { + Err(CompletionError::ResponseError( + "Expected a JSON object for Schema".into(), + )) + } + } +} + #[derive(Debug, Serialize)] #[serde(rename_all = "camelCase")] pub struct GenerateContentRequest { @@ -256,19 +346,6 @@ pub enum HarmBlockThreshold { // Rig Implementation Types // ================================================================= -impl From for Tool { - fn from(tool: completion::ToolDefinition) -> Self { - Self { - function_declaration: FunctionDeclaration { - name: tool.name, - description: tool.description, - parameters: None, // tool.parameters, TODO: Map Gemini - }, - 
code_execution: None, - } - } -} - #[derive(Clone)] pub struct CompletionModel { client: Client, @@ -284,24 +361,6 @@ impl CompletionModel { } } -impl TryFrom for completion::CompletionResponse { - type Error = CompletionError; - - fn try_from(response: GenerateContentResponse) -> Result { - match response.candidates.as_slice() { - [ContentCandidate { content, .. }, ..] => Ok(completion::CompletionResponse { - choice: completion::ModelChoice::Message( - content.parts.first().unwrap().text.clone(), - ), - raw_response: response, - }), - _ => Err(CompletionError::ResponseError( - "No candidates found in response".into(), - )), - } - } -} - impl completion::CompletionModel for CompletionModel { type Response = GenerateContentResponse; @@ -329,6 +388,20 @@ impl completion::CompletionModel for CompletionModel { content: prompt_with_context, }); + // Handle Gemini specific parameters + let mut generation_config = + GenerationConfig::try_from(completion_request.additional_params.unwrap_or_default())?; + + // Set temperature from completion_request or additional_params + if let Some(temp) = completion_request.temperature { + generation_config.temperature = Some(temp); + } + + // Set max_tokens from completion_request or additional_params + if let Some(max_tokens) = completion_request.max_tokens { + generation_config.max_output_tokens = Some(max_tokens); + } + let request = GenerateContentRequest { contents: full_history .into_iter() @@ -342,21 +415,7 @@ impl completion::CompletionModel for CompletionModel { }, }) .collect(), - // QUESTION: How to handle API config specifics? 
- generation_config: Some(GenerationConfig { - temperature: completion_request.temperature, - max_output_tokens: completion_request.max_tokens, - top_p: None, - top_k: None, - candidate_count: None, - stop_sequences: None, - response_mime_type: None, - response_schema: None, - presence_penalty: None, - frequency_penalty: None, - response_logprobs: None, - logprobs: None, - }), + generation_config: Some(generation_config), safety_settings: None, tools: Some( completion_request @@ -385,3 +444,164 @@ impl completion::CompletionModel for CompletionModel { } } } + +impl From for Tool { + fn from(tool: completion::ToolDefinition) -> Self { + Self { + function_declaration: FunctionDeclaration { + name: tool.name, + description: tool.description, + parameters: None, // tool.parameters, TODO: Map Gemini + }, + code_execution: None, + } + } +} + +impl TryFrom for completion::CompletionResponse { + type Error = CompletionError; + + fn try_from(response: GenerateContentResponse) -> Result { + match response.candidates.as_slice() { + [ContentCandidate { content, .. }, ..] 
=> Ok(completion::CompletionResponse { + choice: completion::ModelChoice::Message( + content.parts.first().unwrap().text.clone(), + ), + raw_response: response, + }), + _ => Err(CompletionError::ResponseError( + "No candidates found in response".into(), + )), + } + } +} + +impl TryFrom for GenerationConfig { + type Error = CompletionError; + + fn try_from(value: serde_json::Value) -> Result { + let mut config = GenerationConfig { + temperature: None, + max_output_tokens: None, + stop_sequences: None, + response_mime_type: None, + response_schema: None, + candidate_count: None, + top_p: None, + top_k: None, + presence_penalty: None, + frequency_penalty: None, + response_logprobs: None, + logprobs: None, + }; + + fn unexpected_type_error(field: &str) -> CompletionError { + CompletionError::ResponseError(format!("Unexpected type for field '{}'", field)) + } + + if let Some(obj) = value.as_object() { + for (key, value) in obj.iter().filter(|(_, v)| !v.is_null()) { + match key.as_str() { + "temperature" => { + if !value.is_null() { + if let Some(v) = value.as_f64() { + config.temperature = Some(v); + } else { + return Err(unexpected_type_error("temperature")); + } + } + } + "max_output_tokens" => { + if let Some(v) = value.as_u64() { + config.max_output_tokens = Some(v); + } else { + return Err(unexpected_type_error("max_output_tokens")); + } + } + "top_p" => { + if let Some(v) = value.as_f64() { + config.top_p = Some(v); + } else { + return Err(unexpected_type_error("top_p")); + } + } + "top_k" => { + if let Some(v) = value.as_i64() { + config.top_k = Some(v as i32); + } else { + return Err(unexpected_type_error("top_k")); + } + } + "candidate_count" => { + if let Some(v) = value.as_i64() { + config.candidate_count = Some(v as i32); + } else { + return Err(unexpected_type_error("candidate_count")); + } + } + "stop_sequences" => { + if let Some(v) = value.as_array() { + config.stop_sequences = Some( + v.iter() + .filter_map(|s| s.as_str().map(String::from)) + 
.collect(), + ); + } else { + return Err(unexpected_type_error("stop_sequences")); + } + } + "response_mime_type" => { + if let Some(v) = value.as_str() { + config.response_mime_type = Some(v.to_string()); + } else { + return Err(unexpected_type_error("response_mime_type")); + } + } + "response_schema" => { + config.response_schema = Some(value.clone().try_into()?); + } + "presence_penalty" => { + if let Some(v) = value.as_f64() { + config.presence_penalty = Some(v); + } else { + return Err(unexpected_type_error("presence_penalty")); + } + } + "frequency_penalty" => { + if let Some(v) = value.as_f64() { + config.frequency_penalty = Some(v); + } else { + return Err(unexpected_type_error("frequency_penalty")); + } + } + "response_logprobs" => { + if let Some(v) = value.as_bool() { + config.response_logprobs = Some(v); + } else { + return Err(unexpected_type_error("response_logprobs")); + } + } + "logprobs" => { + if let Some(v) = value.as_i64() { + config.logprobs = Some(v as i32); + } else { + return Err(unexpected_type_error("logprobs")); + } + } + _ => { + tracing::warn!( + "Unknown GenerationConfig parameter, will be ignored: {}", + key + ); + } + } + } + } else { + return Err(CompletionError::ResponseError( + "Expected a JSON object for GenerationConfig".into(), + )); + } + + Ok(config) + } +} From e5e763ef717ea6c3223cee66f35171f0b504d449 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mathieu=20B=C3=A9langer?= Date: Mon, 14 Oct 2024 18:58:56 +0200 Subject: [PATCH 08/18] docs(gemini): add addtionnal types from the official documentation, add embeddings example --- rig-core/examples/gemini_embeddings.rs | 20 +++ rig-core/src/providers/gemini/completion.rs | 82 +++++++--- rig-core/src/providers/gemini/embedding.rs | 167 ++++++++++++++++++-- rig-core/src/providers/gemini/mod.rs | 2 +- 4 files changed, 238 insertions(+), 33 deletions(-) create mode 100644 rig-core/examples/gemini_embeddings.rs diff --git a/rig-core/examples/gemini_embeddings.rs 
b/rig-core/examples/gemini_embeddings.rs new file mode 100644 index 00000000..beeb6a20 --- /dev/null +++ b/rig-core/examples/gemini_embeddings.rs @@ -0,0 +1,20 @@ +use rig::providers::gemini::{self}; + +#[tokio::main] +async fn main() -> Result<(), anyhow::Error> { + // Initialize the Google Gemini client + // Create OpenAI client + let client = gemini::Client::from_env(); + + let embeddings = client + .embeddings(gemini::embedding::EMBEDDING_001) + .simple_document("doc0", "Hello, world!") + .simple_document("doc1", "Goodbye, world!") + .build() + .await + .expect("Failed to embed documents"); + + println!("{:?}", embeddings); + + Ok(()) +} diff --git a/rig-core/src/providers/gemini/completion.rs b/rig-core/src/providers/gemini/completion.rs index 8dae9a97..12cd87be 100644 --- a/rig-core/src/providers/gemini/completion.rs +++ b/rig-core/src/providers/gemini/completion.rs @@ -1,8 +1,7 @@ // ================================================================ -// Google Gemini Completion API +//! Google Gemini Completion Integration +//! https://ai.google.dev/api/generate-content // ================================================================ -// -// https://ai.google.dev/api/generate-conten /// `gemini-1.5-flash` completion model pub const GEMINI_1_5_FLASH: &str = "gemini-1.5-flash"; @@ -26,7 +25,7 @@ use crate::{ use super::Client; // ================================================================= -// Gemini API Response Structures +// Gemini API Types // ================================================================= // Define the struct for the GenerateContentResponse @@ -166,33 +165,63 @@ pub struct LogProbCandidate { pub log_probability: f64, } -/// Gemini API Configuration options for model generation and outputs. Not all parameters are configurable for every model. -/// https://ai.google.dev/api/generate-content#generationconfig +/// Gemini API Configuration options for model generation and outputs. 
Not all parameters are +/// configurable for every model. https://ai.google.dev/api/generate-content#generationconfig #[derive(Debug, Deserialize, Serialize)] pub struct GenerationConfig { - /// The set of character sequences (up to 5) that will stop output generation. If specified, the API will stop at the first appearance of a stop_sequence. The stop sequence will not be included as part of the response. + /// The set of character sequences (up to 5) that will stop output generation. If specified, the API will stop + /// at the first appearance of a stop_sequence. The stop sequence will not be included as part of the response. pub stop_sequences: Option>, - /// MIME type of the generated candidate text. Supported MIME types are: text/plain: (default) Text output. application/json: JSON response in the response candidates. text/x.enum: ENUM as a string response in the response candidates. Refer to the docs for a list of all supported text MIME types + /// MIME type of the generated candidate text. Supported MIME types are: + /// - text/plain: (default) Text output + /// - application/json: JSON response in the response candidates. + /// - text/x.enum: ENUM as a string response in the response candidates. + /// Refer to the docs for a list of all supported text MIME types pub response_mime_type: Option, - /// Output schema of the generated candidate text. Schemas must be a subset of the OpenAPI schema and can be objects, primitives or arrays. If set, a compatible responseMimeType must also be set. Compatible MIME types: application/json: Schema for JSON response. Refer to the JSON text generation guide for more details. + /// Output schema of the generated candidate text. Schemas must be a subset of the OpenAPI schema and can be + /// objects, primitives or arrays. If set, a compatible responseMimeType must also be set. Compatible MIME + /// types: application/json: Schema for JSON response. Refer to the JSON text generation guide for more details. 
pub response_schema: Option, - /// Number of generated responses to return. Currently, this value can only be set to 1. If unset, this will default to 1. + /// Number of generated responses to return. Currently, this value can only be set to 1. If + /// unset, this will default to 1. pub candidate_count: Option, - /// The maximum number of tokens to include in a response candidate. Note: The default value varies by model, see the Model.output_token_limit attribute of the Model returned from the getModel function. + /// The maximum number of tokens to include in a response candidate. Note: The default value varies by model, see + /// the Model.output_token_limit attribute of the Model returned from the getModel function. pub max_output_tokens: Option, - /// Controls the randomness of the output. Note: The default value varies by model, see the Model.temperature attribute of the Model returned from the getModel function. Values can range from [0.0, 2.0]. + /// Controls the randomness of the output. Note: The default value varies by model, see the Model.temperature + /// attribute of the Model returned from the getModel function. Values can range from [0.0, 2.0]. pub temperature: Option, - /// The maximum cumulative probability of tokens to consider when sampling. The model uses combined Top-k and Top-p (nucleus) sampling. Tokens are sorted based on their assigned probabilities so that only the most likely tokens are considered. Top-k sampling directly limits the maximum number of tokens to consider, while Nucleus sampling limits the number of tokens based on the cumulative probability. Note: The default value varies by Model and is specified by theModel.top_p attribute returned from the getModel function. An empty topK attribute indicates that the model doesn't apply top-k sampling and doesn't allow setting topK on requests. + /// The maximum cumulative probability of tokens to consider when sampling. The model uses combined Top-k and + /// Top-p (nucleus) sampling. 
Tokens are sorted based on their assigned probabilities so that only the most + /// likely tokens are considered. Top-k sampling directly limits the maximum number of tokens to consider, while + /// Nucleus sampling limits the number of tokens based on the cumulative probability. Note: The default value + /// varies by Model and is specified by theModel.top_p attribute returned from the getModel function. An empty + /// topK attribute indicates that the model doesn't apply top-k sampling and doesn't allow setting topK on requests. pub top_p: Option, - /// The maximum number of tokens to consider when sampling. Gemini models use Top-p (nucleus) sampling or a combination of Top-k and nucleus sampling. Top-k sampling considers the set of topK most probable tokens. Models running with nucleus sampling don't allow topK setting. Note: The default value varies by Model and is specified by theModel.top_p attribute returned from the getModel function. An empty topK attribute indicates that the model doesn't apply top-k sampling and doesn't allow setting topK on requests. + /// The maximum number of tokens to consider when sampling. Gemini models use Top-p (nucleus) sampling or a + /// combination of Top-k and nucleus sampling. Top-k sampling considers the set of topK most probable tokens. + /// Models running with nucleus sampling don't allow topK setting. Note: The default value varies by Model and is + /// specified by theModel.top_p attribute returned from the getModel function. An empty topK attribute indicates + /// that the model doesn't apply top-k sampling and doesn't allow setting topK on requests. pub top_k: Option, - /// Presence penalty applied to the next token's logprobs if the token has already been seen in the response. This penalty is binary on/off and not dependant on the number of times the token is used (after the first). Use frequencyPenalty for a penalty that increases with each use. 
A positive penalty will discourage the use of tokens that have already been used in the response, increasing the vocabulary. A negative penalty will encourage the use of tokens that have already been used in the response, decreasing the vocabulary. + /// Presence penalty applied to the next token's logprobs if the token has already been seen in the response. + /// This penalty is binary on/off and not dependant on the number of times the token is used (after the first). + /// Use frequencyPenalty for a penalty that increases with each use. A positive penalty will discourage the use + /// of tokens that have already been used in the response, increasing the vocabulary. A negative penalty will + /// encourage the use of tokens that have already been used in the response, decreasing the vocabulary. pub presence_penalty: Option, - /// Frequency penalty applied to the next token's logprobs, multiplied by the number of times each token has been seen in the respponse so far. A positive penalty will discourage the use of tokens that have already been used, proportional to the number of times the token has been used: The more a token is used, the more dificult it is for the model to use that token again increasing the vocabulary of responses. Caution: A negative penalty will encourage the model to reuse tokens proportional to the number of times the token has been used. Small negative values will reduce the vocabulary of a response. Larger negative values will cause the model to start repeating a common token until it hits the maxOutputTokens limit: "...the the the the the...". + /// Frequency penalty applied to the next token's logprobs, multiplied by the number of times each token has been + /// seen in the respponse so far. 
A positive penalty will discourage the use of tokens that have already been + /// used, proportional to the number of times the token has been used: The more a token is used, the more + /// dificult it is for the model to use that token again increasing the vocabulary of responses. Caution: A + /// negative penalty will encourage the model to reuse tokens proportional to the number of times the token has + /// been used. Small negative values will reduce the vocabulary of a response. Larger negative values will cause + /// the model to repeating a common token until it hits the maxOutputTokens limit: "...the the the the the...". pub frequency_penalty: Option, /// If true, export the logprobs results in response. pub response_logprobs: Option, - /// Only valid if responseLogprobs=True. This sets the number of top logprobs to return at each decoding step in the Candidate.logprobs_result. + /// Only valid if responseLogprobs=True. This sets the number of top logprobs to return at each decoding step in + /// [Candidate.logprobs_result]. pub logprobs: Option, } @@ -214,7 +243,8 @@ impl Default for GenerationConfig { } } } -/// The Schema object allows the definition of input and output data types. These types can be objects, but also primitives and arrays. Represents a select subset of an OpenAPI 3.0 schema object. +/// The Schema object allows the definition of input and output data types. These types can be objects, but also +/// primitives and arrays. Represents a select subset of an OpenAPI 3.0 schema object. /// https://ai.google.dev/api/caching#Schema #[derive(Debug, Deserialize, Serialize)] pub struct Schema { @@ -293,8 +323,24 @@ pub struct GenerateContentRequest { pub contents: Vec, pub tools: Option>, pub tool_config: Option, + /// Optional. Configuration options for model generation and outputs. pub generation_config: Option, + /// Optional. A list of unique SafetySetting instances for blocking unsafe content. 
This will be enforced on the + /// [GenerateContentRequest.contents] and [GenerateContentResponse.candidates]. There should not be more than one + /// setting for each SafetyCategory type. The API will block any contents and responses that fail to meet the + /// thresholds set by these settings. This list overrides the default settings for each SafetyCategory specified + /// in the safetySettings. If there is no SafetySetting for a given SafetyCategory provided in the list, the API + /// will use the default safety setting for that category. Harm categories: + /// - HARM_CATEGORY_HATE_SPEECH, + /// - HARM_CATEGORY_SEXUALLY_EXPLICIT + /// - HARM_CATEGORY_DANGEROUS_CONTENT + /// - HARM_CATEGORY_HARASSMENT + /// are supported. + /// Refer to the guide for detailed information on available safety settings. Also refer to the Safety guidance + /// to learn how to incorporate safety considerations in your AI applications. pub safety_settings: Option>, + /// Optional. Developer set system instruction(s). Currently, text only. + /// https://ai.google.dev/gemini-api/docs/system-instructions?lang=rest pub system_instruction: Option, // cachedContent: Optional } diff --git a/rig-core/src/providers/gemini/embedding.rs b/rig-core/src/providers/gemini/embedding.rs index 1c08ae52..b0956286 100644 --- a/rig-core/src/providers/gemini/embedding.rs +++ b/rig-core/src/providers/gemini/embedding.rs @@ -1,31 +1,170 @@ // ================================================================ -// Google Gemini Embeddings +//! Google Gemini Embeddings Integration +//! 
https://ai.google.dev/api/embeddings // ================================================================ -use serde::Deserialize; use serde_json::json; use crate::embeddings::{self, EmbeddingError}; use super::{client::ApiResponse, Client}; -/// `embedding-gecko-001` embedding model -pub const EMBEDDING_GECKO_001: &str = "embedding-gecko-001"; /// `embedding-001` embedding model pub const EMBEDDING_001: &str = "embedding-001"; /// `text-embedding-004` embedding model pub const EMBEDDING_004: &str = "text-embedding-004"; -#[derive(Debug, Deserialize)] -pub struct EmbeddingResponse { - pub embedding: EmbeddingValues, -} +#[allow(dead_code)] +mod gemini_api_types { + use serde::{Deserialize, Serialize}; + use serde_json::Value; -#[derive(Debug, Deserialize)] -pub struct EmbeddingValues { - pub values: Vec, -} + #[derive(Serialize)] + #[serde(rename_all = "camelCase")] + pub struct EmbedContentRequest { + model: String, + content: EmbeddingContent, + task_type: TaskType, + title: String, + output_dimensionality: i32, + } + + #[derive(Serialize)] + pub struct EmbeddingContent { + parts: Vec, + /// Optional. The producer of the content. Must be either 'user' or 'model'. Useful to set for multi-turn + /// conversations, otherwise can be left blank or unset. + role: Option, + } + + /// A datatype containing media that is part of a multi-part Content message. + /// - A Part consists of data which has an associated datatype. A Part can only contain one of the accepted types in Part.data. + /// - A Part must have a fixed IANA MIME type identifying the type and subtype of the media if the inlineData field is filled with raw bytes. + #[derive(Serialize)] + pub struct EmbeddingContentPart { + /// Inline text. + text: String, + /// Inline media bytes. + inline_data: Option, + /// A predicted FunctionCall returned from the model that contains a string representing the [FunctionDeclaration.name] + /// with the arguments and their values. 
+ function_call: Option, + /// The result output of a FunctionCall that contains a string representing the [FunctionDeclaration.name] and a structured + /// JSON object containing any output from the function is used as context to the model. + function_response: Option, + /// URI based data. + file_data: Option, + /// Code generated by the model that is meant to be executed. + executable_code: Option, + /// Result of executing the ExecutableCode. + code_execution_result: Option, + } + + /// Raw media bytes. + /// Text should not be sent as raw bytes, use the 'text' field. + #[derive(Serialize)] + pub struct Blob { + /// Raw bytes for media formats.A base64-encoded string. + data: String, + /// The IANA standard MIME type of the source data. Examples: - image/png - image/jpeg If an unsupported MIME type is + /// provided, an error will be returned. For a complete list of supported types, see Supported file formats. + mime_type: String, + } + + #[derive(Serialize)] + pub struct FunctionCall { + /// The name of the function to call. Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum length of 63. + name: String, + /// The function parameters and values in JSON object format. + args: Option, + } + + #[derive(Serialize)] + pub struct FunctionResponse { + /// The name of the function to call. Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum length of 63. + name: String, + /// The result of the function call in JSON object format. + result: Value, + } + + #[derive(Serialize)] + #[serde(rename_all = "camelCase")] + pub struct FileData { + /// The URI of the file. + file_uri: String, + /// The IANA standard MIME type of the source data. + mime_type: String, + } + #[derive(Serialize)] + pub struct ExecutableCode { + /// The language of the code. + language: ExecutionLanguage, + /// The code to execute. 
+ code: String, + } + + #[derive(Serialize)] + #[serde(rename_all = "SCREAMING_SNAKE_CASE")] + pub enum ExecutionLanguage { + /// Unspecified language. This value should not be used. + LanguageUnspecified, + /// Python >= 3.10, with numpy and simpy available. + Python, + } + + #[derive(Serialize)] + pub struct CodeExecutionResult { + /// Outcome of the code execution. + outcome: CodeExecutionOutcome, + /// Contains stdout when code execution is successful, stderr or other description otherwise. + output: Option, + } + + #[derive(Serialize)] + #[serde(rename_all = "SCREAMING_SNAKE_CASE")] + pub enum CodeExecutionOutcome { + /// Unspecified status. This value should not be used. + Unspecified, + /// Code execution completed successfully. + Ok, + /// Code execution finished but with a failure. stderr should contain the reason. + Failed, + /// Code execution ran for too long, and was cancelled. There may or may not be a partial output present. + DeadlineExceeded, + } + + #[derive(Serialize)] + #[serde(rename_all = "SCREAMING_SNAKE_CASE")] + pub enum TaskType { + /// Unset value, which will default to one of the other enum values. + Unspecified, + /// Specifies the given text is a query in a search/retrieval setting. + RetrievalQuery, + /// Specifies the given text is a document from the corpus being searched. + RetrievalDocument, + /// Specifies the given text will be used for STS. + SemanticSimilarity, + /// Specifies that the given text will be classified. + Classification, + /// Specifies that the embeddings will be used for clustering. + Clustering, + /// Specifies that the given text will be used for question answering. + QuestionAnswering, + /// Specifies that the given text will be used for fact verification. 
+ FactVerification, + } + + #[derive(Debug, Deserialize)] + pub struct EmbeddingResponse { + pub embedding: EmbeddingValues, + } + + #[derive(Debug, Deserialize)] + pub struct EmbeddingValues { + pub values: Vec, + } +} #[derive(Clone)] pub struct EmbeddingModel { client: Client, @@ -48,7 +187,7 @@ impl embeddings::EmbeddingModel for EmbeddingModel { fn ndims(&self) -> usize { match self.model.as_str() { - EMBEDDING_GECKO_001 | EMBEDDING_001 => 768, + EMBEDDING_001 => 768, EMBEDDING_004 => 1024, _ => 0, // Default to 0 for unknown models } @@ -76,7 +215,7 @@ impl embeddings::EmbeddingModel for EmbeddingModel { .send() .await? .error_for_status()? - .json::>() + .json::>() .await?; match response { diff --git a/rig-core/src/providers/gemini/mod.rs b/rig-core/src/providers/gemini/mod.rs index c3be2294..cd3d1b91 100644 --- a/rig-core/src/providers/gemini/mod.rs +++ b/rig-core/src/providers/gemini/mod.rs @@ -1,4 +1,4 @@ -//! Google API client and Rig integration +//! Google Gemini API client and Rig integration //! //! # Example //! ``` From 1dca1dac62844c70b0e3f431ce2249c5bfc9a41f Mon Sep 17 00:00:00 2001 From: Mathieu Date: Tue, 15 Oct 2024 15:38:52 +0200 Subject: [PATCH 09/18] feat(gemini): move system prompt to correct request field --- rig-core/examples/gemini_agent.rs | 21 +++++++++--------- rig-core/src/providers/gemini/completion.rs | 24 +++++++++------------ 2 files changed, 20 insertions(+), 25 deletions(-) diff --git a/rig-core/examples/gemini_agent.rs b/rig-core/examples/gemini_agent.rs index f221886a..978a7cff 100644 --- a/rig-core/examples/gemini_agent.rs +++ b/rig-core/examples/gemini_agent.rs @@ -12,23 +12,22 @@ async fn main() -> Result<(), anyhow::Error> { // Create agent with a single context prompt let agent = client .agent(gemini::completion::GEMINI_1_5_PRO) - .preamble("Be precise and concise.") + .preamble("Be creative and concise. 
Answer directly and clearly.") .temperature(0.5) .max_tokens(8192) - .additional_params( - serde_json::to_value(GenerationConfig { - top_k: Some(1), - top_p: Some(0.95), - candidate_count: Some(1), - ..Default::default() - }) - .unwrap(), - ) // Unwrap the Result to get the Value + .additional_params(serde_json::to_value(GenerationConfig { + top_k: Some(1), + top_p: Some(0.95), + candidate_count: Some(1), + ..Default::default() + })?) // Unwrap the Result to get the Value .build(); + tracing::info!("Prompting the agent..."); + // Prompt the agent and print the response let response = agent - .prompt("How much wood would a woodchuck chuck if a woodchuck could chuck wood?") + .prompt("How much wood would a woodchuck chuck if a woodchuck could chuck wood? Infer an answer.") .await?; println!("{}", response); diff --git a/rig-core/src/providers/gemini/completion.rs b/rig-core/src/providers/gemini/completion.rs index 12cd87be..c535a56c 100644 --- a/rig-core/src/providers/gemini/completion.rs +++ b/rig-core/src/providers/gemini/completion.rs @@ -36,7 +36,6 @@ pub struct GenerateContentResponse { pub usage_metadata: Option, } -// Define the struct for a Candidate #[derive(Debug, Deserialize)] #[serde(rename_all = "camelCase")] pub struct ContentCandidate { @@ -341,7 +340,7 @@ pub struct GenerateContentRequest { pub safety_settings: Option>, /// Optional. Developer set system instruction(s). Currently, text only. /// https://ai.google.dev/gemini-api/docs/system-instructions?lang=rest - pub system_instruction: Option, + pub system_instruction: Option, // cachedContent: Optional } @@ -414,17 +413,7 @@ impl completion::CompletionModel for CompletionModel { &self, mut completion_request: CompletionRequest, ) -> Result, CompletionError> { - // QUESTION: Why do Anthropic/openAi implementation differ here? OpenAI adds the preamble but Anthropic does not. 
- - let mut full_history = if let Some(preamble) = &completion_request.preamble { - vec![completion::Message { - role: "system".into(), - content: preamble.clone(), - }] - } else { - vec![] - }; - + let mut full_history = Vec::new(); full_history.append(&mut completion_request.chat_history); let prompt_with_context = completion_request.prompt_with_context(); @@ -471,9 +460,16 @@ impl completion::CompletionModel for CompletionModel { .collect(), ), tool_config: None, - system_instruction: None, + system_instruction: Some(Content { + parts: vec![Part { + text: "system".to_string(), + }], + role: Some("system".to_string()), + }), }; + tracing::info!("Request: {:?}", request); + let response = self .client .post(&format!("/v1beta/models/{}:generateContent", self.model)) From 05b5df13ecc95aca71f50766ea68803871e63fb6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mathieu=20B=C3=A9langer?= Date: Tue, 15 Oct 2024 15:52:56 +0200 Subject: [PATCH 10/18] docs(readme): remove gemini mention in non-exhaustive list --- README.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 3b014446..c3f8aaee 100644 --- a/README.md +++ b/README.md @@ -34,7 +34,7 @@ We'd love your feedback. Please take a moment to let us know what you think usin ## High-level features - Full support for LLM completion and embedding workflows -- Simple but powerful common abstractions over LLM providers (e.g. OpenAI, Cohere, Google Gemini) and vector stores (e.g. MongoDB, in-memory) +- Simple but powerful common abstractions over LLM providers (e.g. OpenAI, Cohere) and vector stores (e.g. 
MongoDB, in-memory) - Integrate LLMs in your app with minimal boilerplate ## Installation @@ -71,5 +71,6 @@ Rig supports the following LLM providers natively: - OpenAI - Cohere - Google Gemini + Additionally, Rig currently has the following integration sub-libraries: - MongoDB vector store: `rig-mongodb` From 5b45c5cfe85b21e00dcb0e62d1756b4a26de22f8 Mon Sep 17 00:00:00 2001 From: Mathieu Date: Tue, 15 Oct 2024 15:55:27 +0200 Subject: [PATCH 11/18] chore: add debug trait to embedding struct --- rig-core/src/embeddings.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/rig-core/src/embeddings.rs b/rig-core/src/embeddings.rs index e805d6d3..7a34e822 100644 --- a/rig-core/src/embeddings.rs +++ b/rig-core/src/embeddings.rs @@ -28,12 +28,12 @@ //! //! // Create an embeddings builder and add documents //! let embeddings = EmbeddingsBuilder::new(embedding_model) -//! .simple_document("doc1", "This is the first document.") +//! .simple_document("doc1", "This is the first document.") //! .simple_document("doc2", "This is the second document.") //! .build() //! .await //! .expect("Failed to build embeddings."); -//! +//! //! // Use the generated embeddings //! // ... //! ``` @@ -102,7 +102,7 @@ pub trait EmbeddingModel: Clone + Sync + Send { } /// Struct that holds a single document and its embedding. -#[derive(Clone, Default, Deserialize, Serialize)] +#[derive(Clone, Default, Deserialize, Serialize, Debug)] pub struct Embedding { /// The document that was embedded pub document: String, @@ -142,7 +142,7 @@ impl Embedding { /// large document to be retrieved from a query that matches multiple smaller and /// distinct text documents. For example, if the document is a textbook, a summary of /// each chapter could serve as the book's embeddings. 
-#[derive(Clone, Eq, PartialEq, Serialize, Deserialize)] +#[derive(Clone, Eq, PartialEq, Serialize, Deserialize, Debug)] pub struct DocumentEmbeddings { #[serde(rename = "_id")] pub id: String, From 9ae5e33b75d3843d01696ce5bb0784c30ad54146 Mon Sep 17 00:00:00 2001 From: Mathieu Date: Fri, 1 Nov 2024 00:05:10 +0100 Subject: [PATCH 12/18] refactor(gemini): separate gemini api types module, fix pr comments --- rig-core/README.md | 2 +- rig-core/examples/gemini_agent.rs | 3 +- rig-core/examples/gemini_embeddings.rs | 2 +- rig-core/src/providers/gemini/client.rs | 2 +- rig-core/src/providers/gemini/completion.rs | 894 +++++++++++--------- rig-core/src/providers/gemini/embedding.rs | 182 ++-- rig-core/src/providers/gemini/mod.rs | 45 +- 7 files changed, 640 insertions(+), 490 deletions(-) diff --git a/rig-core/README.md b/rig-core/README.md index 4bf40c15..567bb9e4 100644 --- a/rig-core/README.md +++ b/rig-core/README.md @@ -11,7 +11,7 @@ More information about this crate can be found in the [crate documentation](http ## High-level features - Full support for LLM completion and embedding workflows -- Simple but powerful common abstractions over LLM providers (e.g. OpenAI, Cohere, Google Gemini) and vector stores (e.g. MongoDB, in-memory) +- Simple but powerful common abstractions over LLM providers (e.g. OpenAI, Cohere) and vector stores (e.g. 
MongoDB, in-memory) - Integrate LLMs in your app with minimal boilerplate ## Installation diff --git a/rig-core/examples/gemini_agent.rs b/rig-core/examples/gemini_agent.rs index 978a7cff..3e8e66e9 100644 --- a/rig-core/examples/gemini_agent.rs +++ b/rig-core/examples/gemini_agent.rs @@ -1,12 +1,11 @@ use rig::{ completion::Prompt, - providers::gemini::{self, completion::GenerationConfig}, + providers::gemini::{self, completion::gemini_api_types::GenerationConfig}, }; #[tokio::main] async fn main() -> Result<(), anyhow::Error> { // Initialize the Google Gemini client - // Create OpenAI client let client = gemini::Client::from_env(); // Create agent with a single context prompt diff --git a/rig-core/examples/gemini_embeddings.rs b/rig-core/examples/gemini_embeddings.rs index beeb6a20..4ce24636 100644 --- a/rig-core/examples/gemini_embeddings.rs +++ b/rig-core/examples/gemini_embeddings.rs @@ -1,4 +1,4 @@ -use rig::providers::gemini::{self}; +use rig::providers::gemini; #[tokio::main] async fn main() -> Result<(), anyhow::Error> { diff --git a/rig-core/src/providers/gemini/client.rs b/rig-core/src/providers/gemini/client.rs index c5d58cfa..0055af33 100644 --- a/rig-core/src/providers/gemini/client.rs +++ b/rig-core/src/providers/gemini/client.rs @@ -52,7 +52,7 @@ impl Client { pub fn post(&self, path: &str) -> reqwest::RequestBuilder { let url = format!("{}/{}?key={}", self.base_url, path, self.api_key).replace("//", "/"); - tracing::info!("POST {}", url); + tracing::debug!("POST {}", url); self.http_client.post(url) } diff --git a/rig-core/src/providers/gemini/completion.rs b/rig-core/src/providers/gemini/completion.rs index c535a56c..f8bee2ec 100644 --- a/rig-core/src/providers/gemini/completion.rs +++ b/rig-core/src/providers/gemini/completion.rs @@ -12,9 +12,10 @@ pub const GEMINI_1_5_PRO_8B: &str = "gemini-1.5-pro-8b"; /// `gemini-1.0-pro` completion model pub const GEMINI_1_0_PRO: &str = "gemini-1.0-pro"; -use serde::{Deserialize, Serialize}; -use 
serde_json::Value; -use std::collections::HashMap; +use gemini_api_types::{ + Content, ContentCandidate, FunctionDeclaration, GenerateContentRequest, + GenerateContentResponse, GenerationConfig, Part, Role, Tool, +}; use std::convert::TryFrom; use crate::{ @@ -24,369 +25,6 @@ use crate::{ use super::Client; -// ================================================================= -// Gemini API Types -// ================================================================= - -// Define the struct for the GenerateContentResponse -#[derive(Debug, Deserialize)] -pub struct GenerateContentResponse { - pub candidates: Vec, - pub prompt_feedback: Option, - pub usage_metadata: Option, -} - -#[derive(Debug, Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct ContentCandidate { - pub content: Content, - pub finish_reason: Option, - pub safety_ratings: Option>, - pub citation_metadata: Option, - pub token_count: Option, - pub avg_logprobs: Option, - pub logprobs_result: Option, - pub index: Option, -} - -#[derive(Debug, Deserialize, Serialize)] -pub struct Content { - pub parts: Vec, - pub role: Option, -} - -#[derive(Debug, Deserialize, Serialize)] -pub struct Part { - pub text: String, -} - -#[derive(Debug, Deserialize, Serialize)] -pub struct SafetyRating { - pub category: HarmCategory, - pub probability: HarmProbability, -} - -#[derive(Debug, Deserialize, Serialize)] -#[serde(rename_all = "SCREAMING_SNAKE_CASE")] -pub enum HarmProbability { - HarmProbabilityUnspecified, - Negligible, - Low, - Medium, - High, -} - -#[derive(Debug, Deserialize, Serialize)] -#[serde(rename_all = "SCREAMING_SNAKE_CASE")] -pub enum HarmCategory { - HarmCategoryUnspecified, - HarmCategoryDerogatory, - HarmCategoryToxicity, - HarmCategoryViolence, - HarmCategorySexually, - HarmCategoryMedical, - HarmCategoryDangerous, - HarmCategoryHarassment, - HarmCategoryHateSpeech, - HarmCategorySexuallyExplicit, - HarmCategoryDangerousContent, - HarmCategoryCivicIntegrity, -} - -#[derive(Debug, 
Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct UsageMetadata { - pub prompt_token_count: i32, - pub cached_content_token_count: i32, - pub candidates_token_count: i32, - pub total_token_count: i32, -} - -#[derive(Debug, Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct PromptFeedback { - pub block_reason: Option, - pub safety_ratings: Option>, -} - -#[derive(Debug, Deserialize)] -pub enum BlockReason { - BlockReasonUnspecified, - Safety, - Other, - Blocklist, - ProhibitedContent, -} - -#[derive(Debug, Deserialize)] -#[serde(rename_all = "SCREAMING_SNAKE_CASE")] -pub enum FinishReason { - FinishReasonUnspecified, - Stop, - MaxTokens, - Safety, - Recitation, - Language, - Other, - Blocklist, - ProhibitedContent, -} - -#[derive(Debug, Deserialize)] -pub struct CitationMetadata { - pub citation_sources: Vec, -} - -#[derive(Debug, Deserialize)] -pub struct CitationSource { - pub uri: Option, - pub start_index: Option, - pub end_index: Option, - pub license: Option, -} - -#[derive(Debug, Deserialize)] -pub struct LogprobsResult { - pub top_candidate: Vec, - pub chosen_candidate: Vec, -} - -#[derive(Debug, Deserialize)] -pub struct TopCandidate { - pub candidates: Vec, -} - -#[derive(Debug, Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct LogProbCandidate { - pub token: String, - pub token_id: String, - pub log_probability: f64, -} - -/// Gemini API Configuration options for model generation and outputs. Not all parameters are -/// configurable for every model. https://ai.google.dev/api/generate-content#generationconfig -#[derive(Debug, Deserialize, Serialize)] -pub struct GenerationConfig { - /// The set of character sequences (up to 5) that will stop output generation. If specified, the API will stop - /// at the first appearance of a stop_sequence. The stop sequence will not be included as part of the response. - pub stop_sequences: Option>, - /// MIME type of the generated candidate text. 
Supported MIME types are: - /// - text/plain: (default) Text output - /// - application/json: JSON response in the response candidates. - /// - text/x.enum: ENUM as a string response in the response candidates. - /// Refer to the docs for a list of all supported text MIME types - pub response_mime_type: Option, - /// Output schema of the generated candidate text. Schemas must be a subset of the OpenAPI schema and can be - /// objects, primitives or arrays. If set, a compatible responseMimeType must also be set. Compatible MIME - /// types: application/json: Schema for JSON response. Refer to the JSON text generation guide for more details. - pub response_schema: Option, - /// Number of generated responses to return. Currently, this value can only be set to 1. If - /// unset, this will default to 1. - pub candidate_count: Option, - /// The maximum number of tokens to include in a response candidate. Note: The default value varies by model, see - /// the Model.output_token_limit attribute of the Model returned from the getModel function. - pub max_output_tokens: Option, - /// Controls the randomness of the output. Note: The default value varies by model, see the Model.temperature - /// attribute of the Model returned from the getModel function. Values can range from [0.0, 2.0]. - pub temperature: Option, - /// The maximum cumulative probability of tokens to consider when sampling. The model uses combined Top-k and - /// Top-p (nucleus) sampling. Tokens are sorted based on their assigned probabilities so that only the most - /// likely tokens are considered. Top-k sampling directly limits the maximum number of tokens to consider, while - /// Nucleus sampling limits the number of tokens based on the cumulative probability. Note: The default value - /// varies by Model and is specified by theModel.top_p attribute returned from the getModel function. 
An empty - /// topK attribute indicates that the model doesn't apply top-k sampling and doesn't allow setting topK on requests. - pub top_p: Option, - /// The maximum number of tokens to consider when sampling. Gemini models use Top-p (nucleus) sampling or a - /// combination of Top-k and nucleus sampling. Top-k sampling considers the set of topK most probable tokens. - /// Models running with nucleus sampling don't allow topK setting. Note: The default value varies by Model and is - /// specified by theModel.top_p attribute returned from the getModel function. An empty topK attribute indicates - /// that the model doesn't apply top-k sampling and doesn't allow setting topK on requests. - pub top_k: Option, - /// Presence penalty applied to the next token's logprobs if the token has already been seen in the response. - /// This penalty is binary on/off and not dependant on the number of times the token is used (after the first). - /// Use frequencyPenalty for a penalty that increases with each use. A positive penalty will discourage the use - /// of tokens that have already been used in the response, increasing the vocabulary. A negative penalty will - /// encourage the use of tokens that have already been used in the response, decreasing the vocabulary. - pub presence_penalty: Option, - /// Frequency penalty applied to the next token's logprobs, multiplied by the number of times each token has been - /// seen in the respponse so far. A positive penalty will discourage the use of tokens that have already been - /// used, proportional to the number of times the token has been used: The more a token is used, the more - /// dificult it is for the model to use that token again increasing the vocabulary of responses. Caution: A - /// negative penalty will encourage the model to reuse tokens proportional to the number of times the token has - /// been used. Small negative values will reduce the vocabulary of a response. 
Larger negative values will cause - /// the model to repeating a common token until it hits the maxOutputTokens limit: "...the the the the the...". - pub frequency_penalty: Option, - /// If true, export the logprobs results in response. - pub response_logprobs: Option, - /// Only valid if responseLogprobs=True. This sets the number of top logprobs to return at each decoding step in - /// [Candidate.logprobs_result]. - pub logprobs: Option, -} - -impl Default for GenerationConfig { - fn default() -> Self { - Self { - temperature: Some(1.0), - max_output_tokens: Some(4096), - stop_sequences: None, - response_mime_type: None, - response_schema: None, - candidate_count: None, - top_p: None, - top_k: None, - presence_penalty: None, - frequency_penalty: None, - response_logprobs: None, - logprobs: None, - } - } -} -/// The Schema object allows the definition of input and output data types. These types can be objects, but also -/// primitives and arrays. Represents a select subset of an OpenAPI 3.0 schema object. 
-/// https://ai.google.dev/api/caching#Schema -#[derive(Debug, Deserialize, Serialize)] -pub struct Schema { - pub r#type: String, - pub format: Option, - pub description: Option, - pub nullable: Option, - pub r#enum: Option>, - pub max_items: Option, - pub min_items: Option, - pub properties: Option>, - pub required: Option>, - pub items: Option>, -} - -impl TryFrom for Schema { - type Error = CompletionError; - - fn try_from(value: Value) -> Result { - if let Some(obj) = value.as_object() { - Ok(Schema { - r#type: obj - .get("type") - .and_then(|v| v.as_str()) - .unwrap_or_default() - .to_string(), - format: obj.get("format").and_then(|v| v.as_str()).map(String::from), - description: obj - .get("description") - .and_then(|v| v.as_str()) - .map(String::from), - nullable: obj.get("nullable").and_then(|v| v.as_bool()), - r#enum: obj.get("enum").and_then(|v| v.as_array()).map(|arr| { - arr.iter() - .filter_map(|v| v.as_str().map(String::from)) - .collect() - }), - max_items: obj - .get("maxItems") - .and_then(|v| v.as_i64()) - .map(|v| v as i32), - min_items: obj - .get("minItems") - .and_then(|v| v.as_i64()) - .map(|v| v as i32), - properties: obj - .get("properties") - .and_then(|v| v.as_object()) - .map(|map| { - map.iter() - .filter_map(|(k, v)| { - v.clone().try_into().ok().map(|schema| (k.clone(), schema)) - }) - .collect() - }), - required: obj.get("required").and_then(|v| v.as_array()).map(|arr| { - arr.iter() - .filter_map(|v| v.as_str().map(String::from)) - .collect() - }), - items: obj - .get("items") - .map(|v| Box::new(v.clone().try_into().unwrap())), - }) - } else { - Err(CompletionError::ResponseError( - "Expected a JSON object for Schema".into(), - )) - } - } -} - -#[derive(Debug, Serialize)] -#[serde(rename_all = "camelCase")] -pub struct GenerateContentRequest { - pub contents: Vec, - pub tools: Option>, - pub tool_config: Option, - /// Optional. Configuration options for model generation and outputs. - pub generation_config: Option, - /// Optional. 
A list of unique SafetySetting instances for blocking unsafe content. This will be enforced on the - /// [GenerateContentRequest.contents] and [GenerateContentResponse.candidates]. There should not be more than one - /// setting for each SafetyCategory type. The API will block any contents and responses that fail to meet the - /// thresholds set by these settings. This list overrides the default settings for each SafetyCategory specified - /// in the safetySettings. If there is no SafetySetting for a given SafetyCategory provided in the list, the API - /// will use the default safety setting for that category. Harm categories: - /// - HARM_CATEGORY_HATE_SPEECH, - /// - HARM_CATEGORY_SEXUALLY_EXPLICIT - /// - HARM_CATEGORY_DANGEROUS_CONTENT - /// - HARM_CATEGORY_HARASSMENT - /// are supported. - /// Refer to the guide for detailed information on available safety settings. Also refer to the Safety guidance - /// to learn how to incorporate safety considerations in your AI applications. - pub safety_settings: Option>, - /// Optional. Developer set system instruction(s). Currently, text only. 
- /// https://ai.google.dev/gemini-api/docs/system-instructions?lang=rest - pub system_instruction: Option, - // cachedContent: Optional -} - -#[derive(Debug, Serialize)] -#[serde(rename_all = "camelCase")] -pub struct Tool { - pub function_declaration: FunctionDeclaration, - pub code_execution: Option, -} - -#[derive(Debug, Serialize)] -#[serde(rename_all = "camelCase")] -pub struct FunctionDeclaration { - pub name: String, - pub description: String, - pub parameters: Option>, -} - -#[derive(Debug, Serialize)] -#[serde(rename_all = "camelCase")] -pub struct ToolConfig { - pub schema: Option, -} - -#[derive(Debug, Serialize)] -#[serde(rename_all = "camelCase")] -pub struct CodeExecution {} - -#[derive(Debug, Serialize)] -#[serde(rename_all = "camelCase")] -pub struct SafetySetting { - pub category: HarmCategory, - pub threshold: HarmBlockThreshold, -} - -#[derive(Debug, Serialize)] -#[serde(rename_all = "SCREAMING_SNAKE_CASE")] -pub enum HarmBlockThreshold { - HarmBlockThresholdUnspecified, - BlockLowAndAbove, - BlockMediumAndAbove, - BlockOnlyHigh, - BlockNone, - Off, -} - // ================================================================= // Rig Implementation Types // ================================================================= @@ -437,15 +75,23 @@ impl completion::CompletionModel for CompletionModel { generation_config.max_output_tokens = Some(max_tokens); } + /* + serde_json::to_value(GenerationConfig { + top_k: Some(1), + top_p: Some(0.95), + candidate_count: Some(1), + ..Default::default() + })? 
*/ + let request = GenerateContentRequest { contents: full_history .into_iter() .map(|msg| Content { - parts: vec![Part { text: msg.content }], + parts: vec![Part::Text(msg.content)], role: match msg.role.as_str() { - "system" => Some("model".to_string()), - "user" => Some("user".to_string()), - "assistant" => Some("model".to_string()), + "system" => Some(Role::Model), + "user" => Some(Role::User), + "assistant" => Some(Role::Model), _ => None, }, }) @@ -461,10 +107,8 @@ impl completion::CompletionModel for CompletionModel { ), tool_config: None, system_instruction: Some(Content { - parts: vec![Part { - text: "system".to_string(), - }], - role: Some("system".to_string()), + parts: vec![Part::Text("system".to_string())], + role: Some(Role::Model), }), }; @@ -506,9 +150,20 @@ impl TryFrom for completion::CompletionResponse Result { match response.candidates.as_slice() { [ContentCandidate { content, .. }, ..] => Ok(completion::CompletionResponse { - choice: completion::ModelChoice::Message( - content.parts.first().unwrap().text.clone(), - ), + choice: match content.parts.first().unwrap() { + Part::Text(text) => completion::ModelChoice::Message(text.clone()), + Part::FunctionCall(function_call) => { + let args_value = serde_json::Value::Object( + function_call.args.clone().unwrap_or_default(), + ); + completion::ModelChoice::ToolCall(function_call.name.clone(), args_value) + } + _ => { + return Err(CompletionError::ResponseError( + "Unsupported response by the model of type ".into(), + )) + } + }, raw_response: response, }), _ => Err(CompletionError::ResponseError( @@ -647,3 +302,488 @@ impl TryFrom for GenerationConfig { Ok(config) } } + +pub mod gemini_api_types { + + use std::collections::HashMap; + + // ================================================================= + // Gemini API Types + // ================================================================= + use serde::{Deserialize, Serialize}; + use serde_json::{Map, Value}; + + use crate::{ + 
completion::CompletionError, + providers::gemini::gemini_api_types::{CodeExecutionResult, ExecutableCode}, + }; + + /// Response from the model supporting multiple candidate responses. + /// Safety ratings and content filtering are reported for both prompt in GenerateContentResponse.prompt_feedback + /// and for each candidate in finishReason and in safetyRatings. + /// The API: + /// - Returns either all requested candidates or none of them + /// - Returns no candidates at all only if there was something wrong with the prompt (check promptFeedback) + /// - Reports feedback on each candidate in finishReason and safetyRatings. + #[derive(Debug, Deserialize)] + #[serde(rename_all = "camelCase")] + pub struct GenerateContentResponse { + /// Candidate responses from the model. + pub candidates: Vec, + /// Returns the prompt's feedback related to the content filters. + pub prompt_feedback: Option, + /// Output only. Metadata on the generation requests' token usage. + pub usage_metadata: Option, + } + + /// A response candidate generated from the model. + #[derive(Debug, Deserialize)] + #[serde(rename_all = "camelCase")] + pub struct ContentCandidate { + /// Output only. Generated content returned from the model. + pub content: Content, + /// Optional. Output only. The reason why the model stopped generating tokens. + /// If empty, the model has not stopped generating tokens. + pub finish_reason: Option, + /// List of ratings for the safety of a response candidate. + /// There is at most one rating per category. + pub safety_ratings: Option>, + /// Output only. Citation information for model-generated candidate. + /// This field may be populated with recitation information for any text included in the content. + /// These are passages that are "recited" from copyrighted material in the foundational LLM's training data. + pub citation_metadata: Option, + /// Output only. Token count for this candidate. + pub token_count: Option, + /// Output only. 
+ pub avg_logprobs: Option, + /// Output only. Log-likelihood scores for the response tokens and top tokens + pub logprobs_result: Option, + /// Output only. Index of the candidate in the list of response candidates. + pub index: Option, + } + + #[derive(Debug, Deserialize, Serialize)] + pub struct Content { + /// Ordered Parts that constitute a single message. Parts may have different MIME types. + pub parts: Vec, + /// The producer of the content. Must be either 'user' or 'model'. + /// Useful to set for multi-turn conversations, otherwise can be left blank or unset. + pub role: Option, + } + + #[derive(Debug, Deserialize, Serialize)] + pub enum Role { + User, + Model, + } + + /// A datatype containing media that is part of a multi-part [Content](Content) message. + /// A Part consists of data which has an associated datatype. A Part can only contain one of the accepted types in Part.data. + /// A Part must have a fixed IANA MIME type identifying the type and subtype of the media if the inlineData field is filled with raw bytes. + #[derive(Debug, Deserialize, Serialize)] + #[serde(rename_all = "camelCase")] + pub enum Part { + Text(String), + InlineData(Blob), + FunctionCall(FunctionCall), + FunctionResponse(FunctionResponse), + FileData(FileData), + ExecutableCode(ExecutableCode), + CodeExecutionResult(CodeExecutionResult), + } + + /// Raw media bytes. + /// Text should not be sent as raw bytes, use the 'text' field. + #[derive(Debug, Deserialize, Serialize)] + #[serde(rename_all = "camelCase")] + pub struct Blob { + /// The IANA standard MIME type of the source data. Examples: - image/png - image/jpeg + /// If an unsupported MIME type is provided, an error will be returned. + pub mime_type: String, + /// Raw bytes for media formats. A base64-encoded string. + pub data: String, + } + + /// A predicted FunctionCall returned from the model that contains a string representing the + /// FunctionDeclaration.name with the arguments and their values. 
+ /// #[derive(Debug, Deserialize, Serialize)] + #[derive(Debug, Deserialize, Serialize)] + pub struct FunctionCall { + /// Required. The name of the function to call. Must be a-z, A-Z, 0-9, or contain underscores + /// and dashes, with a maximum length of 63. + pub name: String, + /// Optional. The function parameters and values in JSON object format. + pub args: Option>, + } + + /// The result output from a FunctionCall that contains a string representing the FunctionDeclaration.name + /// and a structured JSON object containing any output from the function is used as context to the model. + /// This should contain the result of aFunctionCall made based on model prediction. + #[derive(Debug, Deserialize, Serialize)] + pub struct FunctionResponse { + /// The name of the function to call. Must be a-z, A-Z, 0-9, or contain underscores and dashes, + /// with a maximum length of 63. + pub name: String, + /// The function response in JSON object format. + pub response: Option>, + } + + /// URI based data. + #[derive(Debug, Deserialize, Serialize)] + pub struct FileData { + /// Optional. The IANA standard MIME type of the source data. + pub mime_type: Option, + /// Required. URI. 
+ pub file_uri: String, + } + + #[derive(Debug, Deserialize, Serialize)] + pub struct SafetyRating { + pub category: HarmCategory, + pub probability: HarmProbability, + } + + #[derive(Debug, Deserialize, Serialize)] + #[serde(rename_all = "SCREAMING_SNAKE_CASE")] + pub enum HarmProbability { + HarmProbabilityUnspecified, + Negligible, + Low, + Medium, + High, + } + + #[derive(Debug, Deserialize, Serialize)] + #[serde(rename_all = "SCREAMING_SNAKE_CASE")] + pub enum HarmCategory { + HarmCategoryUnspecified, + HarmCategoryDerogatory, + HarmCategoryToxicity, + HarmCategoryViolence, + HarmCategorySexually, + HarmCategoryMedical, + HarmCategoryDangerous, + HarmCategoryHarassment, + HarmCategoryHateSpeech, + HarmCategorySexuallyExplicit, + HarmCategoryDangerousContent, + HarmCategoryCivicIntegrity, + } + + #[derive(Debug, Deserialize)] + #[serde(rename_all = "camelCase")] + pub struct UsageMetadata { + pub prompt_token_count: i32, + pub cached_content_token_count: i32, + pub candidates_token_count: i32, + pub total_token_count: i32, + } + + /// A set of the feedback metadata the prompt specified in [GenerateContentRequest.contents](GenerateContentRequest). + #[derive(Debug, Deserialize)] + #[serde(rename_all = "camelCase")] + pub struct PromptFeedback { + /// Optional. If set, the prompt was blocked and no candidates are returned. Rephrase the prompt. + pub block_reason: Option, + /// Ratings for safety of the prompt. There is at most one rating per category. + pub safety_ratings: Option>, + } + + /// Reason why a prompt was blocked by the model + #[derive(Debug, Deserialize)] + #[serde(rename_all = "SCREAMING_SNAKE_CASE")] + pub enum BlockReason { + /// Default value. This value is unused. + BlockReasonUnspecified, + /// Prompt was blocked due to safety reasons. Inspect safetyRatings to understand which safety category blocked it. + Safety, + /// Prompt was blocked due to unknown reasons. 
+ Other, + /// Prompt was blocked due to the terms which are included from the terminology blocklist. + Blocklist, + /// Prompt was blocked due to prohibited content. + ProhibitedContent, + } + + #[derive(Debug, Deserialize)] + #[serde(rename_all = "SCREAMING_SNAKE_CASE")] + pub enum FinishReason { + /// Default value. This value is unused. + FinishReasonUnspecified, + /// Natural stop point of the model or provided stop sequence. + Stop, + /// The maximum number of tokens as specified in the request was reached. + MaxTokens, + /// The response candidate content was flagged for safety reasons. + Safety, + /// The response candidate content was flagged for recitation reasons. + Recitation, + /// The response candidate content was flagged for using an unsupported language. + Language, + /// Unknown reason. + Other, + /// Token generation stopped because the content contains forbidden terms. + Blocklist, + /// Token generation stopped for potentially containing prohibited content. + ProhibitedContent, + /// Token generation stopped because the content potentially contains Sensitive Personally Identifiable Information (SPII). + Spii, + /// The function call generated by the model is invalid. + MalformedFunctionCall, + } + + #[derive(Debug, Deserialize)] + pub struct CitationMetadata { + pub citation_sources: Vec, + } + + #[derive(Debug, Deserialize)] + pub struct CitationSource { + pub uri: Option, + pub start_index: Option, + pub end_index: Option, + pub license: Option, + } + + #[derive(Debug, Deserialize)] + pub struct LogprobsResult { + pub top_candidate: Vec, + pub chosen_candidate: Vec, + } + + #[derive(Debug, Deserialize)] + pub struct TopCandidate { + pub candidates: Vec, + } + + #[derive(Debug, Deserialize)] + #[serde(rename_all = "camelCase")] + pub struct LogProbCandidate { + pub token: String, + pub token_id: String, + pub log_probability: f64, + } + + /// Gemini API Configuration options for model generation and outputs. 
Not all parameters are + /// configurable for every model. https://ai.google.dev/api/generate-content#generationconfig + #[derive(Debug, Deserialize, Serialize)] + pub struct GenerationConfig { + /// The set of character sequences (up to 5) that will stop output generation. If specified, the API will stop + /// at the first appearance of a stop_sequence. The stop sequence will not be included as part of the response. + pub stop_sequences: Option>, + /// MIME type of the generated candidate text. Supported MIME types are: + /// - text/plain: (default) Text output + /// - application/json: JSON response in the response candidates. + /// - text/x.enum: ENUM as a string response in the response candidates. + /// Refer to the docs for a list of all supported text MIME types + pub response_mime_type: Option, + /// Output schema of the generated candidate text. Schemas must be a subset of the OpenAPI schema and can be + /// objects, primitives or arrays. If set, a compatible responseMimeType must also be set. Compatible MIME + /// types: application/json: Schema for JSON response. Refer to the JSON text generation guide for more details. + pub response_schema: Option, + /// Number of generated responses to return. Currently, this value can only be set to 1. If + /// unset, this will default to 1. + pub candidate_count: Option, + /// The maximum number of tokens to include in a response candidate. Note: The default value varies by model, see + /// the Model.output_token_limit attribute of the Model returned from the getModel function. + pub max_output_tokens: Option, + /// Controls the randomness of the output. Note: The default value varies by model, see the Model.temperature + /// attribute of the Model returned from the getModel function. Values can range from [0.0, 2.0]. + pub temperature: Option, + /// The maximum cumulative probability of tokens to consider when sampling. The model uses combined Top-k and + /// Top-p (nucleus) sampling. 
Tokens are sorted based on their assigned probabilities so that only the most + /// likely tokens are considered. Top-k sampling directly limits the maximum number of tokens to consider, while + /// Nucleus sampling limits the number of tokens based on the cumulative probability. Note: The default value + /// varies by Model and is specified by the Model.top_p attribute returned from the getModel function. An empty + /// topK attribute indicates that the model doesn't apply top-k sampling and doesn't allow setting topK on requests. + pub top_p: Option, + /// The maximum number of tokens to consider when sampling. Gemini models use Top-p (nucleus) sampling or a + /// combination of Top-k and nucleus sampling. Top-k sampling considers the set of topK most probable tokens. + /// Models running with nucleus sampling don't allow topK setting. Note: The default value varies by Model and is + /// specified by the Model.top_k attribute returned from the getModel function. An empty topK attribute indicates + /// that the model doesn't apply top-k sampling and doesn't allow setting topK on requests. + pub top_k: Option, + /// Presence penalty applied to the next token's logprobs if the token has already been seen in the response. + /// This penalty is binary on/off and not dependent on the number of times the token is used (after the first). + /// Use frequencyPenalty for a penalty that increases with each use. A positive penalty will discourage the use + /// of tokens that have already been used in the response, increasing the vocabulary. A negative penalty will + /// encourage the use of tokens that have already been used in the response, decreasing the vocabulary. + pub presence_penalty: Option, + /// Frequency penalty applied to the next token's logprobs, multiplied by the number of times each token has been + /// seen in the response so far. 
A positive penalty will discourage the use of tokens that have already been + /// used, proportional to the number of times the token has been used: The more a token is used, the more + /// difficult it is for the model to use that token again, increasing the vocabulary of responses. Caution: A + /// negative penalty will encourage the model to reuse tokens proportional to the number of times the token has + /// been used. Small negative values will reduce the vocabulary of a response. Larger negative values will cause + /// the model to start repeating a common token until it hits the maxOutputTokens limit: "...the the the the the...". + pub frequency_penalty: Option, + /// If true, export the logprobs results in response. + pub response_logprobs: Option, + /// Only valid if responseLogprobs=True. This sets the number of top logprobs to return at each decoding step in + /// [Candidate.logprobs_result]. + pub logprobs: Option, + } + + impl Default for GenerationConfig { + fn default() -> Self { + Self { + temperature: Some(1.0), + max_output_tokens: Some(4096), + stop_sequences: None, + response_mime_type: None, + response_schema: None, + candidate_count: None, + top_p: None, + top_k: None, + presence_penalty: None, + frequency_penalty: None, + response_logprobs: None, + logprobs: None, + } + } + } + /// The Schema object allows the definition of input and output data types. These types can be objects, but also + /// primitives and arrays. Represents a select subset of an OpenAPI 3.0 schema object. 
+ /// https://ai.google.dev/api/caching#Schema + #[derive(Debug, Deserialize, Serialize)] + pub struct Schema { + pub r#type: String, + pub format: Option, + pub description: Option, + pub nullable: Option, + pub r#enum: Option>, + pub max_items: Option, + pub min_items: Option, + pub properties: Option>, + pub required: Option>, + pub items: Option>, + } + + impl TryFrom for Schema { + type Error = CompletionError; + + fn try_from(value: Value) -> Result { + if let Some(obj) = value.as_object() { + Ok(Schema { + r#type: obj + .get("type") + .and_then(|v| v.as_str()) + .unwrap_or_default() + .to_string(), + format: obj.get("format").and_then(|v| v.as_str()).map(String::from), + description: obj + .get("description") + .and_then(|v| v.as_str()) + .map(String::from), + nullable: obj.get("nullable").and_then(|v| v.as_bool()), + r#enum: obj.get("enum").and_then(|v| v.as_array()).map(|arr| { + arr.iter() + .filter_map(|v| v.as_str().map(String::from)) + .collect() + }), + max_items: obj + .get("maxItems") + .and_then(|v| v.as_i64()) + .map(|v| v as i32), + min_items: obj + .get("minItems") + .and_then(|v| v.as_i64()) + .map(|v| v as i32), + properties: obj + .get("properties") + .and_then(|v| v.as_object()) + .map(|map| { + map.iter() + .filter_map(|(k, v)| { + v.clone().try_into().ok().map(|schema| (k.clone(), schema)) + }) + .collect() + }), + required: obj.get("required").and_then(|v| v.as_array()).map(|arr| { + arr.iter() + .filter_map(|v| v.as_str().map(String::from)) + .collect() + }), + items: obj + .get("items") + .map(|v| Box::new(v.clone().try_into().unwrap())), + }) + } else { + Err(CompletionError::ResponseError( + "Expected a JSON object for Schema".into(), + )) + } + } + } + + #[derive(Debug, Serialize)] + #[serde(rename_all = "camelCase")] + pub struct GenerateContentRequest { + pub contents: Vec, + pub tools: Option>, + pub tool_config: Option, + /// Optional. Configuration options for model generation and outputs. 
+ pub generation_config: Option, + /// Optional. A list of unique SafetySetting instances for blocking unsafe content. This will be enforced on the + /// [GenerateContentRequest.contents] and [GenerateContentResponse.candidates]. There should not be more than one + /// setting for each SafetyCategory type. The API will block any contents and responses that fail to meet the + /// thresholds set by these settings. This list overrides the default settings for each SafetyCategory specified + /// in the safetySettings. If there is no SafetySetting for a given SafetyCategory provided in the list, the API + /// will use the default safety setting for that category. Harm categories: + /// - HARM_CATEGORY_HATE_SPEECH, + /// - HARM_CATEGORY_SEXUALLY_EXPLICIT + /// - HARM_CATEGORY_DANGEROUS_CONTENT + /// - HARM_CATEGORY_HARASSMENT + /// are supported. + /// Refer to the guide for detailed information on available safety settings. Also refer to the Safety guidance + /// to learn how to incorporate safety considerations in your AI applications. + pub safety_settings: Option>, + /// Optional. Developer set system instruction(s). Currently, text only. 
+ /// https://ai.google.dev/gemini-api/docs/system-instructions?lang=rest + pub system_instruction: Option, + // cachedContent: Optional + } + + #[derive(Debug, Serialize)] + #[serde(rename_all = "camelCase")] + pub struct Tool { + pub function_declaration: FunctionDeclaration, + pub code_execution: Option, + } + + #[derive(Debug, Serialize)] + #[serde(rename_all = "camelCase")] + pub struct FunctionDeclaration { + pub name: String, + pub description: String, + pub parameters: Option>, + } + + #[derive(Debug, Serialize)] + #[serde(rename_all = "camelCase")] + pub struct ToolConfig { + pub schema: Option, + } + + #[derive(Debug, Serialize)] + #[serde(rename_all = "camelCase")] + pub struct CodeExecution {} + + #[derive(Debug, Serialize)] + #[serde(rename_all = "camelCase")] + pub struct SafetySetting { + pub category: HarmCategory, + pub threshold: HarmBlockThreshold, + } + + #[derive(Debug, Serialize)] + #[serde(rename_all = "SCREAMING_SNAKE_CASE")] + pub enum HarmBlockThreshold { + HarmBlockThresholdUnspecified, + BlockLowAndAbove, + BlockMediumAndAbove, + BlockOnlyHigh, + BlockNone, + Off, + } +} diff --git a/rig-core/src/providers/gemini/embedding.rs b/rig-core/src/providers/gemini/embedding.rs index b0956286..0a866bf0 100644 --- a/rig-core/src/providers/gemini/embedding.rs +++ b/rig-core/src/providers/gemini/embedding.rs @@ -13,12 +13,87 @@ use super::{client::ApiResponse, Client}; pub const EMBEDDING_001: &str = "embedding-001"; /// `text-embedding-004` embedding model pub const EMBEDDING_004: &str = "text-embedding-004"; +#[derive(Clone)] +pub struct EmbeddingModel { + client: Client, + model: String, + ndims: Option, +} + +impl EmbeddingModel { + pub fn new(client: Client, model: &str, ndims: Option) -> Self { + Self { + client, + model: model.to_string(), + ndims, + } + } +} +impl embeddings::EmbeddingModel for EmbeddingModel { + const MAX_DOCUMENTS: usize = 1024; + + fn ndims(&self) -> usize { + match self.model.as_str() { + EMBEDDING_001 => 768, + 
EMBEDDING_004 => 1024, + _ => 0, // Default to 0 for unknown models + } + } + + async fn embed_documents( + &self, + documents: Vec, + ) -> Result, EmbeddingError> { + let mut request_body = json!({ + "model": format!("models/{}", self.model), + "content": { + "parts": documents.iter().map(|doc| json!({ "text": doc })).collect::>(), + }, + }); + + if let Some(ndims) = self.ndims { + request_body["output_dimensionality"] = json!(ndims); + } + + let response = self + .client + .post(&format!("/v1beta/models/{}:embedContent", self.model)) + .json(&request_body) + .send() + .await? + .error_for_status()? + .json::>() + .await?; + + match response { + ApiResponse::Ok(response) => { + let chunk_size = self.ndims.unwrap_or_else(|| self.ndims()); + Ok(documents + .into_iter() + .zip(response.embedding.values.chunks(chunk_size)) + .map(|(document, embedding)| embeddings::Embedding { + document, + vec: embedding.to_vec(), + }) + .collect()) + } + ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)), + } + } +} + +// ================================================================= +// Gemini API Types +// ================================================================= +/// Rust Implementation of the Gemini Types from [Gemini API Reference](https://ai.google.dev/api/embeddings) #[allow(dead_code)] mod gemini_api_types { use serde::{Deserialize, Serialize}; use serde_json::Value; + use crate::providers::gemini::gemini_api_types::{CodeExecutionResult, ExecutableCode}; + #[derive(Serialize)] #[serde(rename_all = "camelCase")] pub struct EmbedContentRequest { @@ -96,44 +171,6 @@ mod gemini_api_types { mime_type: String, } - #[derive(Serialize)] - pub struct ExecutableCode { - /// The language of the code. - language: ExecutionLanguage, - /// The code to execute. - code: String, - } - - #[derive(Serialize)] - #[serde(rename_all = "SCREAMING_SNAKE_CASE")] - pub enum ExecutionLanguage { - /// Unspecified language. This value should not be used. 
- LanguageUnspecified, - /// Python >= 3.10, with numpy and simpy available. - Python, - } - - #[derive(Serialize)] - pub struct CodeExecutionResult { - /// Outcome of the code execution. - outcome: CodeExecutionOutcome, - /// Contains stdout when code execution is successful, stderr or other description otherwise. - output: Option, - } - - #[derive(Serialize)] - #[serde(rename_all = "SCREAMING_SNAKE_CASE")] - pub enum CodeExecutionOutcome { - /// Unspecified status. This value should not be used. - Unspecified, - /// Code execution completed successfully. - Ok, - /// Code execution finished but with a failure. stderr should contain the reason. - Failed, - /// Code execution ran for too long, and was cancelled. There may or may not be a partial output present. - DeadlineExceeded, - } - #[derive(Serialize)] #[serde(rename_all = "SCREAMING_SNAKE_CASE")] pub enum TaskType { @@ -165,72 +202,3 @@ mod gemini_api_types { pub values: Vec, } } -#[derive(Clone)] -pub struct EmbeddingModel { - client: Client, - model: String, - ndims: Option, -} - -impl EmbeddingModel { - pub fn new(client: Client, model: &str, ndims: Option) -> Self { - Self { - client, - model: model.to_string(), - ndims, - } - } -} - -impl embeddings::EmbeddingModel for EmbeddingModel { - const MAX_DOCUMENTS: usize = 1024; - - fn ndims(&self) -> usize { - match self.model.as_str() { - EMBEDDING_001 => 768, - EMBEDDING_004 => 1024, - _ => 0, // Default to 0 for unknown models - } - } - - async fn embed_documents( - &self, - documents: Vec, - ) -> Result, EmbeddingError> { - let mut request_body = json!({ - "model": format!("models/{}", self.model), - "content": { - "parts": documents.iter().map(|doc| json!({ "text": doc })).collect::>(), - }, - }); - - if let Some(ndims) = self.ndims { - request_body["output_dimensionality"] = json!(ndims); - } - - let response = self - .client - .post(&format!("/v1beta/models/{}:embedContent", self.model)) - .json(&request_body) - .send() - .await? - .error_for_status()? 
- .json::>() - .await?; - - match response { - ApiResponse::Ok(response) => { - let chunk_size = self.ndims.unwrap_or_else(|| self.ndims()); - Ok(documents - .into_iter() - .zip(response.embedding.values.chunks(chunk_size)) - .map(|(document, embedding)| embeddings::Embedding { - document, - vec: embedding.to_vec(), - }) - .collect()) - } - ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)), - } - } -} diff --git a/rig-core/src/providers/gemini/mod.rs b/rig-core/src/providers/gemini/mod.rs index cd3d1b91..108d2c6f 100644 --- a/rig-core/src/providers/gemini/mod.rs +++ b/rig-core/src/providers/gemini/mod.rs @@ -12,5 +12,48 @@ pub mod client; pub mod completion; pub mod embedding; - pub use client::Client; + +pub mod gemini_api_types { + use serde::{Deserialize, Serialize}; + + #[derive(Serialize, Deserialize, Debug)] + #[serde(rename_all = "SCREAMING_SNAKE_CASE")] + pub enum ExecutionLanguage { + /// Unspecified language. This value should not be used. + LanguageUnspecified, + /// Python >= 3.10, with numpy and simpy available. + Python, + } + + /// Code generated by the model that is meant to be executed, and the result returned to the model. + /// Only generated when using the CodeExecution tool, in which the code will be automatically executed, + /// and a corresponding CodeExecutionResult will also be generated. + #[derive(Debug, Deserialize, Serialize)] + pub struct ExecutableCode { + /// Programming language of the code. + pub language: ExecutionLanguage, + /// The code to be executed. + pub code: String, + } + #[derive(Serialize, Deserialize, Debug)] + pub struct CodeExecutionResult { + /// Outcome of the code execution. + pub outcome: CodeExecutionOutcome, + /// Contains stdout when code execution is successful, stderr or other description otherwise. + pub output: Option, + } + + #[derive(Serialize, Deserialize, Debug)] + #[serde(rename_all = "SCREAMING_SNAKE_CASE")] + pub enum CodeExecutionOutcome { + /// Unspecified status. 
This value should not be used. + Unspecified, + /// Code execution completed successfully. + Ok, + /// Code execution finished but with a failure. stderr should contain the reason. + Failed, + /// Code execution ran for too long, and was cancelled. There may or may not be a partial output present. + DeadlineExceeded, + } +} From caee495dbe97fcfa1a92a783f5d24c82aff2f860 Mon Sep 17 00:00:00 2001 From: Mathieu Date: Fri, 1 Nov 2024 12:49:46 +0100 Subject: [PATCH 13/18] Merge branch 'main' into feat/model-provider/16-add-gemini-completion-embedding-models --- .github/workflows/ci.yaml | 28 +- Cargo.lock | 919 +++++++++++------- README.md | 12 +- img/rig_logo.svg | 6 +- img/rig_logo_dark.svg | 20 + rig-core/CHANGELOG.md | 25 + rig-core/Cargo.toml | 8 +- rig-core/examples/agent_with_loaders.rs | 38 + rig-core/examples/agent_with_tools.rs | 31 +- rig-core/examples/loaders.rs | 14 + rig-core/src/completion.rs | 6 +- rig-core/src/embeddings.rs | 2 +- rig-core/src/json_utils.rs | 20 +- rig-core/src/lib.rs | 9 +- rig-core/src/loaders/file.rs | 273 ++++++ rig-core/src/loaders/mod.rs | 9 + rig-core/src/loaders/pdf.rs | 456 +++++++++ rig-core/src/providers/anthropic/client.rs | 7 + .../src/providers/anthropic/completion.rs | 81 +- rig-core/src/providers/cohere.rs | 69 +- rig-core/src/providers/openai.rs | 58 +- rig-core/src/providers/perplexity.rs | 13 +- rig-core/tests/data/dummy.pdf | Bin 0 -> 13078 bytes rig-core/tests/data/pages.pdf | Bin 0 -> 12414 bytes rig-lancedb/CHANGELOG.md | 55 ++ rig-lancedb/Cargo.toml | 6 +- rig-lancedb/LICENSE | 7 + rig-lancedb/README.md | 2 + rig-mongodb/CHANGELOG.md | 11 + rig-mongodb/Cargo.toml | 4 +- rig-mongodb/README.md | 36 +- rig-mongodb/examples/vector_search_mongodb.rs | 6 +- 32 files changed, 1708 insertions(+), 523 deletions(-) create mode 100644 img/rig_logo_dark.svg create mode 100644 rig-core/examples/agent_with_loaders.rs create mode 100644 rig-core/examples/loaders.rs create mode 100644 rig-core/src/loaders/file.rs create mode 
100644 rig-core/src/loaders/mod.rs create mode 100644 rig-core/src/loaders/pdf.rs create mode 100644 rig-core/tests/data/dummy.pdf create mode 100644 rig-core/tests/data/pages.pdf create mode 100644 rig-lancedb/CHANGELOG.md create mode 100644 rig-lancedb/LICENSE create mode 100644 rig-lancedb/README.md diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index e00d68cd..d6ec1161 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -5,9 +5,6 @@ on: pull_request: branches: - main - push: - branches: - - main workflow_call: env: @@ -55,6 +52,8 @@ jobs: - name: Run clippy action uses: clechasseur/rs-clippy-check@v3 + with: + args: --all-features test: name: stable / test @@ -79,4 +78,25 @@ jobs: uses: actions-rs/cargo@v1 with: command: nextest - args: run --all-features \ No newline at end of file + args: run --all-features + + doc: + name: stable / doc + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Install Rust stable + uses: actions-rust-lang/setup-rust-toolchain@v1 + with: + components: rust-docs + + # Required to compile rig-lancedb + - name: Install Protoc + uses: arduino/setup-protoc@v3 + + - name: Run cargo doc + run: cargo doc --no-deps --all-features + env: + RUSTDOCFLAGS: -D warnings diff --git a/Cargo.lock b/Cargo.lock index 4b71fbe4..14b29d9e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1,21 +1,21 @@ # This file is automatically @generated by Cargo. # It is not intended for manual editing. 
-version = 3 +version = 4 [[package]] name = "addr2line" -version = "0.21.0" +version = "0.24.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8a30b2e23b9e17a9f90641c7ab1549cd9b44f296d3ccbf309d2863cfe398a0cb" +checksum = "dfbe277e56a376000877090da837660b4427aad530e3028d44e0bffe4f89a1c1" dependencies = [ "gimli", ] [[package]] -name = "adler" -version = "1.0.2" +name = "adler2" +version = "2.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" +checksum = "512761e0bb2578dd7380c6baaa0f4ce03e84f95e960231d1dec8bf4d7d6e2627" [[package]] name = "ahash" @@ -61,6 +61,12 @@ dependencies = [ "libc", ] +[[package]] +name = "anstyle" +version = "1.0.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1bec1de6f59aedf83baf9ff929c98f2ad654b97c9510f4e70cf6f661d49fd5b1" + [[package]] name = "anyhow" version = "1.0.89" @@ -218,7 +224,7 @@ dependencies = [ "arrow-schema", "chrono", "half", - "indexmap 2.2.6", + "indexmap 2.6.0", "lexical-core", "num", "serde", @@ -260,7 +266,7 @@ version = "52.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9e972cd1ff4a4ccd22f86d3e53e835c2ed92e0eea6a3e8eadb72b4f1ac802cf8" dependencies = [ - "bitflags 2.5.0", + "bitflags 2.6.0", ] [[package]] @@ -294,6 +300,21 @@ dependencies = [ "regex-syntax", ] +[[package]] +name = "assert_fs" +version = "1.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7efdb1fdb47602827a342857666feb372712cbc64b414172bd6b167a02927674" +dependencies = [ + "anstyle", + "doc-comment", + "globwalk", + "predicates", + "predicates-core", + "predicates-tree", + "tempfile", +] + [[package]] name = "async-io" version = "1.13.0" @@ -340,18 +361,18 @@ checksum = "3b43422f69d8ff38f95f1b2bb76517c91589a924d1559a0e935d7c8ce0274c11" dependencies = [ "proc-macro2", "quote", - "syn 2.0.77", + "syn 2.0.79", ] [[package]] name 
= "async-trait" -version = "0.1.80" +version = "0.1.83" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c6fa2087f2753a7da8cc1c0dbfcf89579dd57458e36769de5ac750b4671737ca" +checksum = "721cae7de5c34fbb2acd27e21e6d2cf7b886dce0c27388d46c4e6c47ea4318dd" dependencies = [ "proc-macro2", "quote", - "syn 2.0.77", + "syn 2.0.79", ] [[package]] @@ -377,15 +398,15 @@ checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0" [[package]] name = "autocfg" -version = "1.3.0" +version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0c4b4d0bd25bd0b74681c0ad21497610ce1b7c91b1022cd21c80c6fbdd9476b0" +checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26" [[package]] name = "aws-config" -version = "1.5.6" +version = "1.5.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "848d7b9b605720989929279fa644ce8f244d0ce3146fcca5b70e4eb7b3c020fc" +checksum = "7198e6f03240fdceba36656d8be440297b6b82270325908c7381f37d826a74f6" dependencies = [ "aws-credential-types", "aws-runtime", @@ -400,7 +421,7 @@ dependencies = [ "aws-smithy-types", "aws-types", "bytes", - "fastrand 2.1.0", + "fastrand 2.1.1", "hex", "http 0.2.12", "ring", @@ -438,7 +459,7 @@ dependencies = [ "aws-smithy-types", "aws-types", "bytes", - "fastrand 2.1.0", + "fastrand 2.1.1", "http 0.2.12", "http-body 0.4.6", "once_cell", @@ -450,9 +471,9 @@ dependencies = [ [[package]] name = "aws-sdk-dynamodb" -version = "1.45.0" +version = "1.49.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c7f3d9e807092149e3df266e3f4d9760dac439b90f82d8438e5b2c0bbe62007f" +checksum = "ab0ade000608877169533a54326badd6b5a707d2faf876cfc3976a7f9d7e5329" dependencies = [ "aws-credential-types", "aws-runtime", @@ -464,7 +485,7 @@ dependencies = [ "aws-smithy-types", "aws-types", "bytes", - "fastrand 2.1.0", + "fastrand 2.1.1", "http 0.2.12", "once_cell", "regex-lite", @@ -473,9 +494,9 @@ 
dependencies = [ [[package]] name = "aws-sdk-sso" -version = "1.42.0" +version = "1.45.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "27bf24cd0d389daa923e974b0e7c38daf308fc21e963c049f57980235017175e" +checksum = "e33ae899566f3d395cbf42858e433930682cc9c1889fa89318896082fef45efb" dependencies = [ "aws-credential-types", "aws-runtime", @@ -495,9 +516,9 @@ dependencies = [ [[package]] name = "aws-sdk-ssooidc" -version = "1.43.0" +version = "1.46.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3b43b3220f1c46ac0e9dcc0a97d94b93305dacb36d1dd393996300c6b9b74364" +checksum = "f39c09e199ebd96b9f860b0fce4b6625f211e064ad7c8693b72ecf7ef03881e0" dependencies = [ "aws-credential-types", "aws-runtime", @@ -517,9 +538,9 @@ dependencies = [ [[package]] name = "aws-sdk-sts" -version = "1.42.0" +version = "1.45.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d1c46924fb1add65bba55636e12812cae2febf68c0f37361766f627ddcca91ce" +checksum = "3d95f93a98130389eb6233b9d615249e543f6c24a68ca1f109af9ca5164a8765" dependencies = [ "aws-credential-types", "aws-runtime", @@ -613,22 +634,22 @@ dependencies = [ [[package]] name = "aws-smithy-runtime" -version = "1.7.1" +version = "1.7.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d1ce695746394772e7000b39fe073095db6d45a862d0767dd5ad0ac0d7f8eb87" +checksum = "a065c0fe6fdbdf9f11817eb68582b2ab4aff9e9c39e986ae48f7ec576c6322db" dependencies = [ "aws-smithy-async", "aws-smithy-http", "aws-smithy-runtime-api", "aws-smithy-types", "bytes", - "fastrand 2.1.0", + "fastrand 2.1.1", "h2 0.3.26", "http 0.2.12", "http-body 0.4.6", "http-body 1.0.1", "httparse", - "hyper 0.14.28", + "hyper 0.14.30", "hyper-rustls 0.24.2", "once_cell", "pin-project-lite", @@ -657,9 +678,9 @@ dependencies = [ [[package]] name = "aws-smithy-types" -version = "1.2.6" +version = "1.2.7" source = "registry+https://github.com/rust-lang/crates.io-index" 
-checksum = "03701449087215b5369c7ea17fef0dd5d24cb93439ec5af0c7615f58c3f22605" +checksum = "147100a7bea70fa20ef224a6bad700358305f5dc0f84649c53769761395b355b" dependencies = [ "base64-simd", "bytes", @@ -700,23 +721,23 @@ dependencies = [ "aws-smithy-async", "aws-smithy-runtime-api", "aws-smithy-types", - "rustc_version 0.4.0", + "rustc_version 0.4.1", "tracing", ] [[package]] name = "backtrace" -version = "0.3.71" +version = "0.3.74" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "26b05800d2e817c8b3b4b54abd461726265fa9789ae34330622f2db9ee696f9d" +checksum = "8d82cb332cdfaed17ae235a638438ac4d4839913cc2af585c3c6746e8f8bee1a" dependencies = [ "addr2line", - "cc", "cfg-if", "libc", "miniz_oxide", "object", "rustc-demangle", + "windows-targets 0.52.6", ] [[package]] @@ -755,9 +776,9 @@ checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" [[package]] name = "bitflags" -version = "2.5.0" +version = "2.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cf4b9d6a944f767f8e5e0db018570623c85f3d925ac718db4e06d0187adb21c1" +checksum = "b048fb63fd8b5923fc5aa7b340d8e156aec7ec02f0c78fa8a6ddc2613f6f71de" [[package]] name = "bitpacking" @@ -791,15 +812,15 @@ dependencies = [ [[package]] name = "bson" -version = "2.10.0" +version = "2.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4d43b38e074cc0de2957f10947e376a1d88b9c4dbab340b590800cc1b2e066b2" +checksum = "068208f2b6fcfa27a7f1ee37488d2bb8ba2640f68f5475d08e1d9130696aba59" dependencies = [ "ahash", "base64 0.13.1", "bitvec", "hex", - "indexmap 2.2.6", + "indexmap 2.6.0", "js-sys", "once_cell", "rand", @@ -810,6 +831,16 @@ dependencies = [ "uuid", ] +[[package]] +name = "bstr" +version = "1.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "40723b8fb387abc38f4f4a37c09073622e41dd12327033091ef8950659e6dc0c" +dependencies = [ + "memchr", + "serde", +] + [[package]] name = "bumpalo" 
version = "3.16.0" @@ -836,9 +867,9 @@ checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" [[package]] name = "bytes" -version = "1.6.0" +version = "1.7.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "514de17de45fdb8dc022b1a7975556c53c86f9f0aa5f534b98977b171857c2c9" +checksum = "428d9aa8fbc0670b7b8d6030a7fadd0f86151cae55e4dbbece15f3780a3dfaf3" [[package]] name = "bytes-utils" @@ -883,13 +914,13 @@ dependencies = [ [[package]] name = "cc" -version = "1.0.98" +version = "1.1.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "41c270e7540d725e65ac7f1b212ac8ce349719624d7bcff99f8e2e488e8cf03f" +checksum = "2e80e3b6a3ab07840e1cae9b0666a63970dc28e8ed5ffbcdacbfc760c281bfc1" dependencies = [ "jobserver", "libc", - "once_cell", + "shlex", ] [[package]] @@ -916,7 +947,7 @@ dependencies = [ "num-traits", "serde", "wasm-bindgen", - "windows-targets 0.52.5", + "windows-targets 0.52.6", ] [[package]] @@ -999,15 +1030,15 @@ dependencies = [ [[package]] name = "core-foundation-sys" -version = "0.8.6" +version = "0.8.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "06ea2b9bc92be3c2baa9334a323ebca2d6f074ff852cd1d7b11064035cd3868f" +checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" [[package]] name = "cpufeatures" -version = "0.2.12" +version = "0.2.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "53fe5e26ff1b7aef8bca9c6080520cfb8d9333c7568e1829cef191a9723e5504" +checksum = "608697df725056feaccfa42cffdaeeec3fccc4ffc38358ecd19b243e716a78e0" dependencies = [ "libc", ] @@ -1146,7 +1177,7 @@ dependencies = [ "proc-macro2", "quote", "strsim 0.11.1", - "syn 2.0.77", + "syn 2.0.79", ] [[package]] @@ -1168,7 +1199,7 @@ checksum = "d336a2a514f6ccccaa3e09b02d41d35330c07ddf03a62165fcec10bb561c7806" dependencies = [ "darling_core 0.20.10", "quote", - "syn 2.0.77", + "syn 2.0.79", ] [[package]] @@ -1221,7 +1252,7 @@ 
dependencies = [ "glob", "half", "hashbrown 0.14.5", - "indexmap 2.2.6", + "indexmap 2.6.0", "itertools 0.12.1", "log", "num_cpus", @@ -1381,7 +1412,7 @@ dependencies = [ "datafusion-expr", "datafusion-physical-expr", "hashbrown 0.14.5", - "indexmap 2.2.6", + "indexmap 2.6.0", "itertools 0.12.1", "log", "paste", @@ -1410,7 +1441,7 @@ dependencies = [ "half", "hashbrown 0.14.5", "hex", - "indexmap 2.2.6", + "indexmap 2.6.0", "itertools 0.12.1", "log", "paste", @@ -1456,7 +1487,7 @@ dependencies = [ "futures", "half", "hashbrown 0.14.5", - "indexmap 2.2.6", + "indexmap 2.6.0", "itertools 0.12.1", "log", "once_cell", @@ -1526,17 +1557,23 @@ dependencies = [ [[package]] name = "derive_more" -version = "0.99.17" +version = "0.99.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4fb810d30a7c1953f91334de7244731fc3f3c10d7fe163338a35b9f640960321" +checksum = "5f33878137e4dafd7fa914ad4e259e18a4e8e532b9617a2d0150262bf53abfce" dependencies = [ "convert_case", "proc-macro2", "quote", - "rustc_version 0.4.0", - "syn 1.0.109", + "rustc_version 0.4.1", + "syn 2.0.79", ] +[[package]] +name = "difflib" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6184e33543162437515c2e2b48714794e37845ec9851711914eec9d308f6ebe8" + [[package]] name = "digest" version = "0.10.7" @@ -1673,15 +1710,9 @@ dependencies = [ [[package]] name = "fastrand" -version = "2.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9fc0510504f03c51ada170672ac806f1f105a88aa97a5281117e1ddc3368e51a" - -[[package]] -name = "finl_unicode" -version = "1.2.0" +version = "2.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8fcfdc7a0362c9f4444381a9e697c79d435fe65b52a37466fc2c1184cee9edc6" +checksum = "e8c02a5121d4ea3eb16a80748c74f5549a5665e4c21333c6098f283870fbdea6" [[package]] name = "fixedbitset" @@ -1696,7 +1727,17 @@ source = 
"registry+https://github.com/rust-lang/crates.io-index" checksum = "8add37afff2d4ffa83bc748a70b4b1370984f6980768554182424ef71447c35f" dependencies = [ "bitflags 1.3.2", - "rustc_version 0.4.0", + "rustc_version 0.4.1", +] + +[[package]] +name = "flate2" +version = "1.0.34" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a1b589b4dc103969ad3cf85c950899926ec64300a1a46d76c03a6072957036f0" +dependencies = [ + "crc32fast", + "miniz_oxide", ] [[package]] @@ -1705,6 +1746,12 @@ version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" +[[package]] +name = "foldhash" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f81ec6369c545a7d40e4589b5597581fa1c441fe1cce96dd1de43159910a36a2" + [[package]] name = "foreign-types" version = "0.3.2" @@ -1735,7 +1782,7 @@ version = "0.8.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f7e180ac76c23b45e767bd7ae9579bc0bb458618c4bc71835926e098e61d15f8" dependencies = [ - "rustix 0.38.34", + "rustix 0.38.37", "windows-sys 0.52.0", ] @@ -1756,9 +1803,9 @@ checksum = "e6d5a32815ae3f33302d95fdcb2ce17862f8c65363dcfd29360480ba1001fc9c" [[package]] name = "futures" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "645c6916888f6cb6350d2550b80fb63e734897a8498abe35cfb732b6487804b0" +checksum = "65bc07b1a8bc7c85c5f2e110c476c7389b4554ba72af57d8445ea63a576b0876" dependencies = [ "futures-channel", "futures-core", @@ -1771,9 +1818,9 @@ dependencies = [ [[package]] name = "futures-channel" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eac8f7d7865dcb88bd4373ab671c8cf4508703796caa2b1985a9ca867b3fcb78" +checksum = "2dff15bf788c671c1934e366d07e30c1814a8ef514e1af724a602e8a2fbe1b10" dependencies = [ "futures-core", 
"futures-sink", @@ -1781,15 +1828,15 @@ dependencies = [ [[package]] name = "futures-core" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dfc6580bb841c5a68e9ef15c77ccc837b40a7504914d52e47b8b0e9bbda25a1d" +checksum = "05f29059c0c2090612e8d742178b0580d2dc940c837851ad723096f87af6663e" [[package]] name = "futures-executor" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a576fc72ae164fca6b9db127eaa9a9dda0d61316034f33a0a0d4eda41f02b01d" +checksum = "1e28d1d997f585e54aebc3f97d39e72338912123a67330d723fdbb564d646c9f" dependencies = [ "futures-core", "futures-task", @@ -1798,9 +1845,9 @@ dependencies = [ [[package]] name = "futures-io" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a44623e20b9681a318efdd71c299b6b222ed6f231972bfe2f224ebad6311f0c1" +checksum = "9e5c1b78ca4aae1ac06c48a526a655760685149f0d465d21f37abfe57ce075c6" [[package]] name = "futures-lite" @@ -1819,32 +1866,32 @@ dependencies = [ [[package]] name = "futures-macro" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac" +checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650" dependencies = [ "proc-macro2", "quote", - "syn 2.0.77", + "syn 2.0.79", ] [[package]] name = "futures-sink" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9fb8e00e87438d937621c1c6269e53f536c14d3fbd6a042bb24879e57d474fb5" +checksum = "e575fab7d1e0dcb8d0c7bcf9a63ee213816ab51902e6d244a95819acacf1d4f7" [[package]] name = "futures-task" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"38d84fa142264698cdce1a9f9172cf383a0c82de1bddcf3092901442c4097004" +checksum = "f90f7dce0722e95104fcb095585910c0977252f286e354b5e3bd38902cd99988" [[package]] name = "futures-util" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3d6401deb83407ab3da39eba7e33987a73c3df0c82b4bb5813ee871c19c41d48" +checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81" dependencies = [ "futures-channel", "futures-core", @@ -1883,9 +1930,9 @@ dependencies = [ [[package]] name = "gimli" -version = "0.28.1" +version = "0.31.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4271d37baee1b8c7e4b708028c57d816cf9d2434acb33a549475f78c181f6253" +checksum = "07e28edb80900c19c28f1072f2e8aeca7fa06b23cd4169cefe1af5aa3260783f" [[package]] name = "glob" @@ -1893,6 +1940,30 @@ version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b" +[[package]] +name = "globset" +version = "0.4.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "15f1ce686646e7f1e19bf7d5533fe443a45dbfb990e00629110797578b42fb19" +dependencies = [ + "aho-corasick", + "bstr", + "log", + "regex-automata", + "regex-syntax", +] + +[[package]] +name = "globwalk" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0bf760ebf69878d9fd8f110c89703d90ce35095324d1f1edcb595c63945ee757" +dependencies = [ + "bitflags 2.6.0", + "ignore", + "walkdir", +] + [[package]] name = "h2" version = "0.3.26" @@ -1905,7 +1976,7 @@ dependencies = [ "futures-sink", "futures-util", "http 0.2.12", - "indexmap 2.2.6", + "indexmap 2.6.0", "slab", "tokio", "tokio-util", @@ -1924,7 +1995,7 @@ dependencies = [ "futures-core", "futures-sink", "http 1.1.0", - "indexmap 2.2.6", + "indexmap 2.6.0", "slab", "tokio", "tokio-util", @@ -1958,6 +2029,17 @@ dependencies = [ 
"allocator-api2", ] +[[package]] +name = "hashbrown" +version = "0.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e087f84d4f86bf4b218b927129862374b72199ae7d8657835f1e89000eea4fb" +dependencies = [ + "allocator-api2", + "equivalent", + "foldhash", +] + [[package]] name = "heck" version = "0.4.1" @@ -2066,9 +2148,9 @@ dependencies = [ [[package]] name = "httparse" -version = "1.8.0" +version = "1.9.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d897f394bad6a705d5f4104762e116a75639e470d80901eed05a860a95cb1904" +checksum = "7d71d3574edd2771538b901e6549113b4006ece66150fb69c0fb6d9a2adae946" [[package]] name = "httpdate" @@ -2084,9 +2166,9 @@ checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4" [[package]] name = "hyper" -version = "0.14.28" +version = "0.14.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bf96e135eb83a2a8ddf766e426a841d8ddd7449d5f00d34ea02b41d2f19eef80" +checksum = "a152ddd61dfaec7273fe8419ab357f33aee0d914c5f4efbf0d96fa749eea5ec9" dependencies = [ "bytes", "futures-channel", @@ -2134,7 +2216,7 @@ checksum = "ec3efd23720e2049821a693cbc7e65ea87c72f1c58ff2f9522ff332b1491e590" dependencies = [ "futures-util", "http 0.2.12", - "hyper 0.14.28", + "hyper 0.14.30", "log", "rustls 0.21.12", "rustls-native-certs 0.6.3", @@ -2152,7 +2234,7 @@ dependencies = [ "http 1.1.0", "hyper 1.4.1", "hyper-util", - "rustls 0.23.13", + "rustls 0.23.14", "rustls-native-certs 0.8.0", "rustls-pki-types", "tokio", @@ -2167,7 +2249,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d6183ddfa99b85da61a140bea0efc93fdf56ceaa041b37d553518030827f9905" dependencies = [ "bytes", - "hyper 0.14.28", + "hyper 0.14.30", "native-tls", "tokio", "tokio-native-tls", @@ -2175,9 +2257,9 @@ dependencies = [ [[package]] name = "hyper-util" -version = "0.1.7" +version = "0.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" 
-checksum = "cde7055719c54e36e95e8719f95883f22072a48ede39db7fc17a4e1d5281e9b9" +checksum = "41296eb09f183ac68eec06e03cdbea2e759633d4067b2f6552fc2e009bcad08b" dependencies = [ "bytes", "futures-channel", @@ -2188,7 +2270,6 @@ dependencies = [ "pin-project-lite", "socket2 0.5.7", "tokio", - "tower", "tower-service", "tracing", ] @@ -2204,9 +2285,9 @@ dependencies = [ [[package]] name = "iana-time-zone" -version = "0.1.60" +version = "0.1.61" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e7ffbb5a1b541ea2561f8c41c087286cc091e21e556a4f09a8f6cbf17b69b141" +checksum = "235e081f3925a06703c2d0117ea8b91f042756fd6e7a6e5d901e8ca1a996b220" dependencies = [ "android_system_properties", "core-foundation-sys", @@ -2252,6 +2333,22 @@ dependencies = [ "unicode-normalization", ] +[[package]] +name = "ignore" +version = "0.4.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6d89fd380afde86567dfba715db065673989d6253f42b88179abd3eae47bda4b" +dependencies = [ + "crossbeam-deque", + "globset", + "log", + "memchr", + "regex-automata", + "same-file", + "walkdir", + "winapi-util", +] + [[package]] name = "indexmap" version = "1.9.3" @@ -2265,12 +2362,12 @@ dependencies = [ [[package]] name = "indexmap" -version = "2.2.6" +version = "2.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "168fb715dda47215e360912c096649d23d58bf392ac62f73919e831745e40f26" +checksum = "707907fe3c25f5424cce2cb7e1cbcafee6bdbe735ca90ef77c29e84591e5b9da" dependencies = [ "equivalent", - "hashbrown 0.14.5", + "hashbrown 0.15.0", "serde", ] @@ -2306,14 +2403,14 @@ dependencies = [ "socket2 0.5.7", "widestring", "windows-sys 0.48.0", - "winreg 0.50.0", + "winreg", ] [[package]] name = "ipnet" -version = "2.9.0" +version = "2.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f518f335dce6725a761382244631d86cf0ccb2863413590b31338feb467f9c3" +checksum = 
"ddc24109865250148c2e0f3d25d4f0f479571723792d3802153c60922a4fb708" [[package]] name = "itertools" @@ -2350,9 +2447,9 @@ dependencies = [ [[package]] name = "js-sys" -version = "0.3.69" +version = "0.3.71" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "29c15563dc2726973df627357ce0c9ddddbea194836909d655df6a75d2cf296d" +checksum = "0cb94a0ffd3f3ee755c20f7d8752f45cac88605a4dcf808abcff72873296ec7b" dependencies = [ "wasm-bindgen", ] @@ -2776,7 +2873,7 @@ dependencies = [ "regex", "serde", "serde_json", - "serde_with 3.9.0", + "serde_with 3.11.0", "snafu", "tokio", "url", @@ -2784,9 +2881,9 @@ dependencies = [ [[package]] name = "lazy_static" -version = "1.4.0" +version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" +checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" [[package]] name = "levenshtein_automata" @@ -2860,9 +2957,9 @@ dependencies = [ [[package]] name = "libc" -version = "0.2.155" +version = "0.2.159" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "97b3888a4aecf77e811145cadf6eef5901f4782c53886191b2f693f24761847c" +checksum = "561d97a539a36e26a9a5fad1ea11a3039a67714694aaa379433e580854bc3dc5" [[package]] name = "libm" @@ -2876,7 +2973,7 @@ version = "0.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c0ff37bd590ca25063e35af745c343cb7a0271906fb7b37e4813e8f79f00268d" dependencies = [ - "bitflags 2.5.0", + "bitflags 2.6.0", "libc", ] @@ -2910,17 +3007,37 @@ dependencies = [ [[package]] name = "log" -version = "0.4.21" +version = "0.4.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24" + +[[package]] +name = "lopdf" +version = "0.34.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"90ed8c1e510134f979dbc4f070f87d4313098b704861a105fe34231c70a3901c" +checksum = "c5c8ecfc6c72051981c0459f75ccc585e7ff67c70829560cda8e647882a9abff" +dependencies = [ + "chrono", + "encoding_rs", + "flate2", + "indexmap 2.6.0", + "itoa", + "log", + "md-5", + "nom", + "rangemap", + "rayon", + "time", + "weezl", +] [[package]] name = "lru" -version = "0.12.4" +version = "0.12.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "37ee39891760e7d94734f6f63fedc29a2e4a152f836120753a72503f09fcf904" +checksum = "234cf4f4a04dc1f57e24b96cc0cd600cf2af460d4161ac5ecdd0af8e1f3b2a38" dependencies = [ - "hashbrown 0.14.5", + "hashbrown 0.15.0", ] [[package]] @@ -2984,9 +3101,9 @@ dependencies = [ [[package]] name = "memchr" -version = "2.7.2" +version = "2.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6c8640c5d730cb13ebd907d8d04b52f55ac9a2eec55b440c8892f40d56c76c1d" +checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3" [[package]] name = "memmap2" @@ -3011,11 +3128,11 @@ checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" [[package]] name = "miniz_oxide" -version = "0.7.3" +version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87dfd01fe195c66b572b37921ad8803d010623c0aca821bea2302239d155cdae" +checksum = "e2d80299ef12ff69b16a84bb182e3b9df68b5a91574d3d4fa6e41b65deec4df1" dependencies = [ - "adler", + "adler2", ] [[package]] @@ -3054,7 +3171,7 @@ dependencies = [ "once_cell", "parking_lot", "quanta", - "rustc_version 0.4.0", + "rustc_version 0.4.1", "scheduled-thread-pool", "skeptic", "smallvec", @@ -3125,11 +3242,10 @@ checksum = "2195bf6aa996a481483b29d62a7663eed3fe39600c460e323f8ff41e90bdd89b" [[package]] name = "native-tls" -version = "0.2.11" +version = "0.2.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "07226173c32f2926027b63cce4bcd8076c3552846cbe7925f3aaffeac0a3b92e" +checksum = 
"a8614eb2c83d59d1c8cc974dd3f920198647674a0a035e1af1fa58707e317466" dependencies = [ - "lazy_static", "libc", "log", "openssl", @@ -3253,9 +3369,9 @@ dependencies = [ [[package]] name = "object" -version = "0.32.2" +version = "0.36.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a6a622008b6e321afc04970976f62ee297fdbaa6f95318ca343e3eebb9648441" +checksum = "aedf0a2d09c573ed1d8d85b30c119153926a2b36dce0ab28322c09a117a4683e" dependencies = [ "memchr", ] @@ -3279,9 +3395,9 @@ dependencies = [ "percent-encoding", "quick-xml", "rand", - "reqwest 0.12.5", + "reqwest 0.12.8", "ring", - "rustls-pemfile 2.1.3", + "rustls-pemfile 2.2.0", "serde", "serde_json", "snafu", @@ -3293,9 +3409,9 @@ dependencies = [ [[package]] name = "once_cell" -version = "1.19.0" +version = "1.20.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" +checksum = "1261fe7e33c73b354eab43b1273a57c8f967d0391e80353e51f764ac02cf6775" [[package]] name = "oneshot" @@ -3305,11 +3421,11 @@ checksum = "e296cf87e61c9cfc1a61c3c63a0f7f286ed4554e0e22be84e8a38e1d264a2a29" [[package]] name = "openssl" -version = "0.10.64" +version = "0.10.66" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "95a0481286a310808298130d22dd1fef0fa571e05a8f44ec801801e84b216b1f" +checksum = "9529f4786b70a3e8c61e11179af17ab6188ad8d0ded78c5529441ed39d4bd9c1" dependencies = [ - "bitflags 2.5.0", + "bitflags 2.6.0", "cfg-if", "foreign-types", "libc", @@ -3326,7 +3442,7 @@ checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" dependencies = [ "proc-macro2", "quote", - "syn 2.0.77", + "syn 2.0.79", ] [[package]] @@ -3337,9 +3453,9 @@ checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf" [[package]] name = "openssl-sys" -version = "0.9.102" +version = "0.9.103" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"c597637d56fbc83893a35eb0dd04b2b8e7a50c91e64e9493e398b5df4fb45fa2" +checksum = "7f9e8deee91df40a943c71b917e5874b951d32a802526c85721ce3b776c929d6" dependencies = [ "cc", "libc", @@ -3355,9 +3471,9 @@ checksum = "04744f49eae99ab78e0d5c0b603ab218f515ea8cfe5a456d7629ad883a3b6e7d" [[package]] name = "ordered-float" -version = "4.2.0" +version = "4.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a76df7075c7d4d01fdcb46c912dd17fba5b60c78ea480b475f2b6ab6f666584e" +checksum = "44d501f1a72f71d3c063a6bbc8f7271fa73aa09fe5d6283b6571e2ed176a2537" dependencies = [ "num-traits", ] @@ -3391,9 +3507,9 @@ checksum = "f38d5652c16fde515bb1ecef450ab0f6a219d619a7274976324d5e377f7dceba" [[package]] name = "parking_lot" -version = "0.12.2" +version = "0.12.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7e4af0ca4f6caed20e900d564c242b8e5d4903fdacf31d3daf527b66fe6f42fb" +checksum = "f1bf18183cf54e8d6059647fc3063646a1801cf30896933ec2311622cc4b9a27" dependencies = [ "lock_api", "parking_lot_core", @@ -3409,7 +3525,7 @@ dependencies = [ "libc", "redox_syscall", "smallvec", - "windows-targets 0.52.5", + "windows-targets 0.52.6", ] [[package]] @@ -3461,7 +3577,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b4c5cc86750666a3ed20bdaf5ca2a0344f9c67674cae0515bec2da16fbaa47db" dependencies = [ "fixedbitset", - "indexmap 2.2.6", + "indexmap 2.6.0", ] [[package]] @@ -3504,22 +3620,22 @@ dependencies = [ [[package]] name = "pin-project" -version = "1.1.5" +version = "1.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b6bf43b791c5b9e34c3d182969b4abb522f9343702850a2e57f460d00d09b4b3" +checksum = "baf123a161dde1e524adf36f90bc5d8d3462824a9c43553ad07a8183161189ec" dependencies = [ "pin-project-internal", ] [[package]] name = "pin-project-internal" -version = "1.1.5" +version = "1.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"2f38a4412a78282e09a2cf38d195ea5420d15ba0602cb375210efbc877243965" +checksum = "a4502d8515ca9f32f1fb543d987f63d95a14934883db45bdb48060b6b69257f8" dependencies = [ "proc-macro2", "quote", - "syn 2.0.77", + "syn 2.0.79", ] [[package]] @@ -3536,9 +3652,9 @@ checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" [[package]] name = "pkg-config" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d231b230927b5e4ad203db57bbcbee2802f6bce620b1e4a9024a07d94e2907ec" +checksum = "953ec861398dccce10c670dfeaf3ec4911ca479e9c02154b3a215178c5f566f2" [[package]] name = "polling" @@ -3564,9 +3680,39 @@ checksum = "439ee305def115ba05938db6eb1644ff94165c5ab5e9420d1c1bcedbba909391" [[package]] name = "ppv-lite86" -version = "0.2.17" +version = "0.2.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77957b295656769bb8ad2b6a6b09d897d94f05c41b069aede1fcdaa675eaea04" +dependencies = [ + "zerocopy", +] + +[[package]] +name = "predicates" +version = "3.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7e9086cc7640c29a356d1a29fd134380bee9d8f79a17410aa76e7ad295f42c97" +dependencies = [ + "anstyle", + "difflib", + "predicates-core", +] + +[[package]] +name = "predicates-core" +version = "1.0.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" +checksum = "ae8177bee8e75d6846599c6b9ff679ed51e882816914eec639944d7c9aa11931" + +[[package]] +name = "predicates-tree" +version = "1.0.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41b740d195ed3166cd147c8047ec98db0e22ec019eb8eeb76d343b795304fb13" +dependencies = [ + "predicates-core", + "termtree", +] [[package]] name = "prettyplease" @@ -3575,14 +3721,14 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"479cf940fbbb3426c32c5d5176f62ad57549a0bb84773423ba8be9d089f5faba" dependencies = [ "proc-macro2", - "syn 2.0.77", + "syn 2.0.79", ] [[package]] name = "proc-macro2" -version = "1.0.83" +version = "1.0.87" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0b33eb56c327dec362a9e55b3ad14f9d2f0904fb5a5b03b513ab5465399e9f43" +checksum = "b3e4daa0dcf6feba26f985457cdf104d4b4256fc5a09547140f3631bb076b19a" dependencies = [ "unicode-ident", ] @@ -3614,7 +3760,7 @@ dependencies = [ "prost", "prost-types", "regex", - "syn 2.0.77", + "syn 2.0.79", "tempfile", ] @@ -3628,7 +3774,7 @@ dependencies = [ "itertools 0.12.1", "proc-macro2", "quote", - "syn 2.0.77", + "syn 2.0.79", ] [[package]] @@ -3646,7 +3792,7 @@ version = "0.9.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "57206b407293d2bcd3af849ce869d52068623f19e1b5ff8e8778e3309439682b" dependencies = [ - "bitflags 2.5.0", + "bitflags 2.6.0", "memchr", "unicase", ] @@ -3675,9 +3821,9 @@ checksum = "a1d01941d82fa2ab50be1e79e6714289dd7cde78eba4c074bc5a4374f650dfe0" [[package]] name = "quick-xml" -version = "0.36.1" +version = "0.36.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "96a05e2e8efddfa51a84ca47cec303fac86c8541b686d37cac5efc0e094417bc" +checksum = "f7649a7b4df05aed9ea7ec6f628c67c9953a43869b8bc50929569b2999d443fe" dependencies = [ "memchr", "serde", @@ -3694,7 +3840,7 @@ dependencies = [ "quinn-proto", "quinn-udp", "rustc-hash 2.0.0", - "rustls 0.23.13", + "rustls 0.23.14", "socket2 0.5.7", "thiserror", "tokio", @@ -3711,7 +3857,7 @@ dependencies = [ "rand", "ring", "rustc-hash 2.0.0", - "rustls 0.23.13", + "rustls 0.23.14", "slab", "thiserror", "tinyvec", @@ -3720,22 +3866,22 @@ dependencies = [ [[package]] name = "quinn-udp" -version = "0.5.4" +version = "0.5.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8bffec3605b73c6f1754535084a85229fa8a30f86014e6c81aeec4abb68b0285" +checksum = 
"4fe68c2e9e1a1234e218683dbdf9f9dfcb094113c5ac2b938dfcb9bab4c4140b" dependencies = [ "libc", "once_cell", "socket2 0.5.7", "tracing", - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] [[package]] name = "quote" -version = "1.0.36" +version = "1.0.37" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0fa76aaf39101c457836aec0ce2316dbdc3ab723cdda1c6bd4e6ad4208acaca7" +checksum = "b5b9d34b8991d19d98081b46eacdd8eb58c6f2b201139f7c5f643cc155a633af" dependencies = [ "proc-macro2", ] @@ -3823,11 +3969,11 @@ dependencies = [ [[package]] name = "redox_syscall" -version = "0.5.1" +version = "0.5.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "469052894dcb553421e483e4209ee581a45100d31b4018de03e5a7ad86374a7e" +checksum = "9b6dfecf2c74bce2466cabf93f6664d6998a69eb21e39f4207930065b27b771f" dependencies = [ - "bitflags 2.5.0", + "bitflags 2.6.0", ] [[package]] @@ -3843,9 +3989,9 @@ dependencies = [ [[package]] name = "regex" -version = "1.10.6" +version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4219d74c6b67a3654a9fbebc4b419e22126d13d2f3c4a07ee0cb61ff79a79619" +checksum = "38200e5ee88914975b69f657f0801b6f6dccafd44fd9326302a4aaeecfacb1d8" dependencies = [ "aho-corasick", "memchr", @@ -3855,9 +4001,9 @@ dependencies = [ [[package]] name = "regex-automata" -version = "0.4.7" +version = "0.4.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38caf58cc5ef2fed281f89292ef23f6365465ed9a41b7a7754eb4e26496c92df" +checksum = "368758f23274712b504848e9d5a6f010445cc8b87a7cdb4d7cbee666c1288da3" dependencies = [ "aho-corasick", "memchr", @@ -3872,9 +4018,9 @@ checksum = "53a49587ad06b26609c52e423de037e7f57f20d53535d66e08c695f347df952a" [[package]] name = "regex-syntax" -version = "0.8.4" +version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7a66a03ae7c801facd77a29370b4faec201768915ac14a721ba36f20bc9c209b" +checksum = 
"2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c" [[package]] name = "reqwest" @@ -3890,7 +4036,7 @@ dependencies = [ "h2 0.3.26", "http 0.2.12", "http-body 0.4.6", - "hyper 0.14.28", + "hyper 0.14.30", "hyper-tls", "ipnet", "js-sys", @@ -3913,14 +4059,14 @@ dependencies = [ "wasm-bindgen", "wasm-bindgen-futures", "web-sys", - "winreg 0.50.0", + "winreg", ] [[package]] name = "reqwest" -version = "0.12.5" +version = "0.12.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c7d6d2a27d57148378eb5e111173f4276ad26340ecc5c49a4a2152167a2d6a37" +checksum = "f713147fbe92361e52392c73b8c9e48c04c6625bce969ef54dc901e58e042a7b" dependencies = [ "base64 0.22.1", "bytes", @@ -3941,9 +4087,9 @@ dependencies = [ "percent-encoding", "pin-project-lite", "quinn", - "rustls 0.23.13", - "rustls-native-certs 0.7.3", - "rustls-pemfile 2.1.3", + "rustls 0.23.14", + "rustls-native-certs 0.8.0", + "rustls-pemfile 2.2.0", "rustls-pki-types", "serde", "serde_json", @@ -3958,7 +4104,7 @@ dependencies = [ "wasm-bindgen-futures", "wasm-streams", "web-sys", - "winreg 0.52.0", + "windows-registry", ] [[package]] @@ -3973,10 +4119,13 @@ dependencies = [ [[package]] name = "rig-core" -version = "0.2.1" +version = "0.3.0" dependencies = [ "anyhow", + "assert_fs", "futures", + "glob", + "lopdf", "ordered-float", "reqwest 0.11.27", "schemars", @@ -4004,7 +4153,7 @@ dependencies = [ [[package]] name = "rig-mongodb" -version = "0.1.2" +version = "0.1.3" dependencies = [ "anyhow", "futures", @@ -4080,9 +4229,9 @@ dependencies = [ [[package]] name = "rustc_version" -version = "0.4.0" +version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bfa0f585226d2e68097d4f95d113b15b83a82e819ab25717ec0590d9584ef366" +checksum = "cfcb3a22ef46e85b45de6ee7e79d063319ebb6594faafcf1c225ea92ab6e9b92" dependencies = [ "semver 1.0.23", ] @@ -4113,11 +4262,11 @@ dependencies = [ [[package]] name = "rustix" -version = "0.38.34" +version = 
"0.38.37" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "70dc5ec042f7a43c4a73241207cecc9873a06d45debb38b329f8541d85c2730f" +checksum = "8acb788b847c24f28525660c4d7758620a7210875711f79e7f663cc152726811" dependencies = [ - "bitflags 2.5.0", + "bitflags 2.6.0", "errno", "libc", "linux-raw-sys 0.4.14", @@ -4138,9 +4287,9 @@ dependencies = [ [[package]] name = "rustls" -version = "0.23.13" +version = "0.23.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f2dabaac7466917e566adb06783a81ca48944c6898a1b08b9374106dd671f4c8" +checksum = "415d9944693cb90382053259f89fbb077ea730ad7273047ec63b19bc9b160ba8" dependencies = [ "once_cell", "ring", @@ -4162,19 +4311,6 @@ dependencies = [ "security-framework", ] -[[package]] -name = "rustls-native-certs" -version = "0.7.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e5bfb394eeed242e909609f56089eecfe5fda225042e8b171791b9c95f5931e5" -dependencies = [ - "openssl-probe", - "rustls-pemfile 2.1.3", - "rustls-pki-types", - "schannel", - "security-framework", -] - [[package]] name = "rustls-native-certs" version = "0.8.0" @@ -4182,7 +4318,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fcaf18a4f2be7326cd874a5fa579fae794320a0f388d365dca7e480e55f83f8a" dependencies = [ "openssl-probe", - "rustls-pemfile 2.1.3", + "rustls-pemfile 2.2.0", "rustls-pki-types", "schannel", "security-framework", @@ -4199,19 +4335,18 @@ dependencies = [ [[package]] name = "rustls-pemfile" -version = "2.1.3" +version = "2.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "196fe16b00e106300d3e45ecfcb764fa292a535d7326a29a5875c579c7417425" +checksum = "dce314e5fee3f39953d46bb63bb8a46d40c2f8fb7cc5a3b6cab2bde9721d6e50" dependencies = [ - "base64 0.22.1", "rustls-pki-types", ] [[package]] name = "rustls-pki-types" -version = "1.8.0" +version = "1.9.0" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "fc0a2ce646f8655401bb81e7927b812614bd5d91dbc968696be50603510fcaf0" +checksum = "0e696e35370c65c9c541198af4543ccd580cf17fc25d8e05c5a242b202488c55" [[package]] name = "rustls-webpki" @@ -4257,11 +4392,11 @@ dependencies = [ [[package]] name = "schannel" -version = "0.1.23" +version = "0.1.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fbc91545643bcf3a0bbb6569265615222618bdf33ce4ffbbd13c4bbd4c093534" +checksum = "01227be5826fa0690321a2ba6c5cd57a19cf3f6a09e76973b58e61de6ab9d1c1" dependencies = [ - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] [[package]] @@ -4275,9 +4410,9 @@ dependencies = [ [[package]] name = "schemars" -version = "0.8.20" +version = "0.8.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b0218ceea14babe24a4a5836f86ade86c1effbc198164e619194cb5069187e29" +checksum = "09c024468a378b7e36765cd36702b7a90cc3cba11654f6685c8f233408e89e92" dependencies = [ "dyn-clone", "schemars_derive", @@ -4287,14 +4422,14 @@ dependencies = [ [[package]] name = "schemars_derive" -version = "0.8.20" +version = "0.8.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3ed5a1ccce8ff962e31a165d41f6e2a2dd1245099dc4d594f5574a86cd90f4d3" +checksum = "b1eee588578aff73f856ab961cd2f79e36bc45d7ded33a7562adba4667aecc0e" dependencies = [ "proc-macro2", "quote", "serde_derive_internals", - "syn 2.0.77", + "syn 2.0.79", ] [[package]] @@ -4315,11 +4450,11 @@ dependencies = [ [[package]] name = "security-framework" -version = "2.11.0" +version = "2.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c627723fd09706bacdb5cf41499e95098555af3c3c29d014dc3c458ef6be11c0" +checksum = "897b2245f0b511c87893af39b033e5ca9cce68824c4d7e7630b5a1d339658d02" dependencies = [ - "bitflags 2.5.0", + "bitflags 2.6.0", "core-foundation", "core-foundation-sys", "libc", @@ -4328,9 +4463,9 @@ dependencies = [ [[package]] 
name = "security-framework-sys" -version = "2.11.0" +version = "2.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "317936bbbd05227752583946b9e66d7ce3b489f84e11a94a510b4437fef407d7" +checksum = "ea4a292869320c0272d7bc55a5a6aafaff59b4f63404a003887b679a2e05b4b6" dependencies = [ "core-foundation-sys", "libc", @@ -4371,9 +4506,9 @@ dependencies = [ [[package]] name = "serde_bytes" -version = "0.11.14" +version = "0.11.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8b8497c313fd43ab992087548117643f6fcd935cbf36f176ffda0aacf9591734" +checksum = "387cc504cb06bb40a96c8e04e951fe01854cf6bc921053c954e4a606d9675c6a" dependencies = [ "serde", ] @@ -4386,7 +4521,7 @@ checksum = "243902eda00fad750862fc144cea25caca5e20d615af0a81bee94ca738f1df1f" dependencies = [ "proc-macro2", "quote", - "syn 2.0.77", + "syn 2.0.79", ] [[package]] @@ -4397,7 +4532,7 @@ checksum = "18d26a20a969b9e3fdf2fc2d9f21eda6c40e2de84c9408bb5d3b05d499aae711" dependencies = [ "proc-macro2", "quote", - "syn 2.0.77", + "syn 2.0.79", ] [[package]] @@ -4406,7 +4541,7 @@ version = "1.0.128" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6ff5456707a1de34e7e37f2a6fd3d3f808c318259cbd01ab6377795054b483d8" dependencies = [ - "indexmap 2.2.6", + "indexmap 2.6.0", "itoa", "memchr", "ryu", @@ -4437,19 +4572,19 @@ dependencies = [ [[package]] name = "serde_with" -version = "3.9.0" +version = "3.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "69cecfa94848272156ea67b2b1a53f20fc7bc638c4a46d2f8abde08f05f4b857" +checksum = "8e28bdad6db2b8340e449f7108f020b3b092e8583a9e3fb82713e1d4e71fe817" dependencies = [ "base64 0.22.1", "chrono", "hex", "indexmap 1.9.3", - "indexmap 2.2.6", + "indexmap 2.6.0", "serde", "serde_derive", "serde_json", - "serde_with_macros 3.9.0", + "serde_with_macros 3.11.0", "time", ] @@ -4467,14 +4602,14 @@ dependencies = [ [[package]] name = "serde_with_macros" -version = "3.9.0" 
+version = "3.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a8fee4991ef4f274617a51ad4af30519438dacb2f56ac773b08a1922ff743350" +checksum = "9d846214a9854ef724f3da161b426242d8de7c1fc7de2f89bb1efcb154dca79d" dependencies = [ "darling 0.20.10", "proc-macro2", "quote", - "syn 2.0.77", + "syn 2.0.79", ] [[package]] @@ -4517,6 +4652,12 @@ dependencies = [ "dirs", ] +[[package]] +name = "shlex" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" + [[package]] name = "signal-hook-registry" version = "1.4.2" @@ -4637,7 +4778,7 @@ checksum = "01b2e185515564f15375f593fb966b5718bc624ba77fe49fa4616ad619690554" dependencies = [ "proc-macro2", "quote", - "syn 2.0.77", + "syn 2.0.79", ] [[package]] @@ -4666,13 +4807,13 @@ checksum = "e51f1e89f093f99e7432c491c382b88a6860a5adbe6bf02574bf0a08efff1978" [[package]] name = "stringprep" -version = "0.1.4" +version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bb41d74e231a107a1b4ee36bd1214b11285b77768d2e3824aedafa988fd36ee6" +checksum = "7b4df3d392d81bd458a8a621b8bffbd2302a12ffe288a9d931670948749463b1" dependencies = [ - "finl_unicode", "unicode-bidi", "unicode-normalization", + "unicode-properties", ] [[package]] @@ -4706,14 +4847,14 @@ dependencies = [ "proc-macro2", "quote", "rustversion", - "syn 2.0.77", + "syn 2.0.79", ] [[package]] name = "subtle" -version = "2.5.0" +version = "2.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "81cdd64d312baedb58e21336b31bc043b77e01cc99033ce76ef539f78e965ebc" +checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" [[package]] name = "syn" @@ -4728,9 +4869,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.77" +version = "2.0.79" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"9f35bcdf61fd8e7be6caf75f429fdca8beb3ed76584befb503b1569faee373ed" +checksum = "89132cd0bf050864e1d38dc3bbc07a0eb8e7530af26344d3d2bbbef83499f590" dependencies = [ "proc-macro2", "quote", @@ -4748,6 +4889,9 @@ name = "sync_wrapper" version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a7065abeca94b6a8a577f9bd45aa0867a2238b74e8eb67cf10d492bc39351394" +dependencies = [ + "futures-core", +] [[package]] name = "system-configuration" @@ -4931,34 +5075,41 @@ checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369" [[package]] name = "tempfile" -version = "3.10.1" +version = "3.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "85b77fafb263dd9d05cbeac119526425676db3784113aa9295c88498cbf8bff1" +checksum = "f0f2c9fc62d0beef6951ccffd757e241266a2c833136efbe35af6cd2567dca5b" dependencies = [ "cfg-if", - "fastrand 2.1.0", - "rustix 0.38.34", - "windows-sys 0.52.0", + "fastrand 2.1.1", + "once_cell", + "rustix 0.38.37", + "windows-sys 0.59.0", ] +[[package]] +name = "termtree" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3369f5ac52d5eb6ab48c6b4ffdc8efbcad6b89c765749064ba298f2c68a16a76" + [[package]] name = "thiserror" -version = "1.0.63" +version = "1.0.64" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c0342370b38b6a11b6cc11d6a805569958d54cfa061a29969c3b5ce2ea405724" +checksum = "d50af8abc119fb8bb6dbabcfa89656f46f84aa0ac7688088608076ad2b459a84" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.63" +version = "1.0.64" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a4558b58466b9ad7ca0f102865eccc95938dca1a74a856f2b57b6629050da261" +checksum = "08904e7672f5eb876eaaf87e0ce17857500934f4981c4a0ab2b4aa98baac7fc3" dependencies = [ "proc-macro2", "quote", - "syn 2.0.77", + "syn 2.0.79", ] [[package]] @@ -5013,9 +5164,9 @@ dependencies = [ 
[[package]] name = "tinyvec" -version = "1.6.0" +version = "1.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87cc5ceb3875bb20c2890005a4e226a4651264a5c75edb2421b52861a0a0cb50" +checksum = "445e881f4f6d382d5f27c034e25eb92edd7c784ceab92a0937db7f2e9471b938" dependencies = [ "tinyvec_macros", ] @@ -5052,7 +5203,7 @@ checksum = "693d596312e88961bc67d7f1f97af8a70227d9f90c31bba5806eec004978d752" dependencies = [ "proc-macro2", "quote", - "syn 2.0.77", + "syn 2.0.79", ] [[package]] @@ -5081,7 +5232,7 @@ version = "0.26.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0c7bc40d0e5a97695bb96e27995cd3a08538541b0a846f65bba7a359f36700d4" dependencies = [ - "rustls 0.23.13", + "rustls 0.23.14", "rustls-pki-types", "tokio", ] @@ -5099,9 +5250,9 @@ dependencies = [ [[package]] name = "tokio-util" -version = "0.7.11" +version = "0.7.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9cf6b47b3771c49ac75ad09a6162f53ad4b8088b76ac60e8ec1455b31a189fe1" +checksum = "61e7c3654c13bcd040d4a03abee2c75b1d14a37b423cf5a813ceae1cc903ec6a" dependencies = [ "bytes", "futures-core", @@ -5111,32 +5262,11 @@ dependencies = [ "tokio", ] -[[package]] -name = "tower" -version = "0.4.13" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b8fa9be0de6cf49e536ce1851f987bd21a43b771b09473c3549a6c853db37c1c" -dependencies = [ - "futures-core", - "futures-util", - "pin-project", - "pin-project-lite", - "tokio", - "tower-layer", - "tower-service", -] - -[[package]] -name = "tower-layer" -version = "0.3.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "121c2a6cda46980bb0fcd1647ffaf6cd3fc79a013de288782836f6df9c48780e" - [[package]] name = "tower-service" -version = "0.3.2" +version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b6bc1c9ce2b5135ac7f93c72918fc37feb872bdc6a5533a8b85eb4b86bfdae52" +checksum = 
"8df9b6e13f2d32c91b9bd719c00d1958837bc7dec474d94952798cc8e69eeec3" [[package]] name = "tracing" @@ -5157,7 +5287,7 @@ checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.77", + "syn 2.0.79", ] [[package]] @@ -5197,9 +5327,9 @@ dependencies = [ [[package]] name = "triomphe" -version = "0.1.13" +version = "0.1.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e6631e42e10b40c0690bf92f404ebcfe6e1fdb480391d15f17cc8e96eeed5369" +checksum = "ef8f7726da4807b58ea5c96fdc122f80702030edc33b35aff9190a51148ccc85" [[package]] name = "trust-dns-proto" @@ -5290,25 +5420,31 @@ dependencies = [ [[package]] name = "unicode-bidi" -version = "0.3.15" +version = "0.3.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "08f95100a766bf4f8f28f90d77e0a5461bbdb219042e7679bebe79004fed8d75" +checksum = "5ab17db44d7388991a428b2ee655ce0c212e862eff1768a455c58f9aad6e7893" [[package]] name = "unicode-ident" -version = "1.0.12" +version = "1.0.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" +checksum = "e91b56cd4cadaeb79bbf1a5645f6b4f8dc5bde8834ad5894a8db35fda9efa1fe" [[package]] name = "unicode-normalization" -version = "0.1.23" +version = "0.1.24" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a56d1686db2308d901306f92a263857ef59ea39678a5458e7cb17f01415101f5" +checksum = "5033c97c4262335cded6d6fc3e5c18ab755e1a3dc96376350f3d8e9f009ad956" dependencies = [ "tinyvec", ] +[[package]] +name = "unicode-properties" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e70f2a8b45122e719eb623c01822704c4e0907e7e426a05927e1a1cfff5b75d0" + [[package]] name = "unicode-segmentation" version = "1.12.0" @@ -5317,9 +5453,9 @@ checksum = "f6ccf251212114b54433ec949fd6a7841275f9ada20dddd2f29e9ceea4501493" 
[[package]] name = "unicode-width" -version = "0.1.13" +version = "0.1.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0336d538f7abc86d282a4189614dfaa90810dfc2c6f6427eaf88e16311dd225d" +checksum = "7dd6e30e90baa6f72411720665d41d89b9a3d039dc45b8faea1ddd07f617f6af" [[package]] name = "untrusted" @@ -5329,9 +5465,9 @@ checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" [[package]] name = "url" -version = "2.5.0" +version = "2.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "31e6302e3bb753d46e83516cae55ae196fc0c309407cf11ab35cc51a4c2a4633" +checksum = "22784dbdf76fdde8af1aeda5622b546b422b6fc585325248a2bf9f5e41e94d6c" dependencies = [ "form_urlencoded", "idna 0.5.0", @@ -5352,9 +5488,9 @@ checksum = "7fcfc827f90e53a02eaef5e535ee14266c1d569214c6aa70133a624d8a3164ba" [[package]] name = "uuid" -version = "1.8.0" +version = "1.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a183cf7feeba97b4dd1c0d46788634f6221d87fa961b305bed08c851829efcc0" +checksum = "81dfa00651efa65069b0b6b651f4aaa31ba9e3c3ce0137aaad053604ee7e0314" dependencies = [ "getrandom", "serde", @@ -5374,9 +5510,9 @@ checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" [[package]] name = "version_check" -version = "0.9.4" +version = "0.9.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" +checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" [[package]] name = "vsimd" @@ -5417,34 +5553,35 @@ checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" [[package]] name = "wasm-bindgen" -version = "0.2.92" +version = "0.2.94" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4be2531df63900aeb2bca0daaaddec08491ee64ceecbee5076636a3b026795a8" +checksum = 
"ef073ced962d62984fb38a36e5fdc1a2b23c9e0e1fa0689bb97afa4202ef6887" dependencies = [ "cfg-if", + "once_cell", "wasm-bindgen-macro", ] [[package]] name = "wasm-bindgen-backend" -version = "0.2.92" +version = "0.2.94" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "614d787b966d3989fa7bb98a654e369c762374fd3213d212cfc0251257e747da" +checksum = "c4bfab14ef75323f4eb75fa52ee0a3fb59611977fd3240da19b2cf36ff85030e" dependencies = [ "bumpalo", "log", "once_cell", "proc-macro2", "quote", - "syn 2.0.77", + "syn 2.0.79", "wasm-bindgen-shared", ] [[package]] name = "wasm-bindgen-futures" -version = "0.4.42" +version = "0.4.44" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "76bc14366121efc8dbb487ab05bcc9d346b3b5ec0eaa76e46594cabbe51762c0" +checksum = "65471f79c1022ffa5291d33520cbbb53b7687b01c2f8e83b57d102eed7ed479d" dependencies = [ "cfg-if", "js-sys", @@ -5454,9 +5591,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro" -version = "0.2.92" +version = "0.2.94" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a1f8823de937b71b9460c0c34e25f3da88250760bec0ebac694b49997550d726" +checksum = "a7bec9830f60924d9ceb3ef99d55c155be8afa76954edffbb5936ff4509474e7" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -5464,28 +5601,28 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.92" +version = "0.2.94" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e94f17b526d0a461a191c78ea52bbce64071ed5c04c9ffe424dcb38f74171bb7" +checksum = "4c74f6e152a76a2ad448e223b0fc0b6b5747649c3d769cc6bf45737bf97d0ed6" dependencies = [ "proc-macro2", "quote", - "syn 2.0.77", + "syn 2.0.79", "wasm-bindgen-backend", "wasm-bindgen-shared", ] [[package]] name = "wasm-bindgen-shared" -version = "0.2.92" +version = "0.2.94" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"af190c94f2773fdb3729c55b007a722abb5384da03bc0986df4c289bf5567e96" +checksum = "a42f6c679374623f295a8623adfe63d9284091245c3504bde47c17a3ce2777d9" [[package]] name = "wasm-streams" -version = "0.4.0" +version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b65dc4c90b63b118468cf747d8bf3566c1913ef60be765b5730ead9e0a3ba129" +checksum = "4e072d4e72f700fb3443d8fe94a39315df013eef1104903cdb0a2abd322bbecd" dependencies = [ "futures-util", "js-sys", @@ -5496,9 +5633,9 @@ dependencies = [ [[package]] name = "web-sys" -version = "0.3.69" +version = "0.3.71" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "77afa9a11836342370f4817622a2f0f418b134426d91a82dfb48f532d2ec13ef" +checksum = "44188d185b5bdcae1052d08bcbcf9091a5524038d4572cc4f4f2bb9d5554ddd9" dependencies = [ "js-sys", "wasm-bindgen", @@ -5510,6 +5647,12 @@ version = "0.25.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5f20c57d8d7db6d3b86154206ae5d8fba62dd39573114de97c2cb0578251f8e1" +[[package]] +name = "weezl" +version = "0.1.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "53a85b86a771b1c87058196170769dd264f66c0782acf1ae6cc51bfd64b39082" + [[package]] name = "widestring" version = "1.1.0" @@ -5538,7 +5681,7 @@ version = "0.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cf221c93e13a30d793f7645a0e7762c55d169dbb0a49671918a2319d289b10bb" dependencies = [ - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] [[package]] @@ -5553,7 +5696,37 @@ version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "33ab640c8d7e35bf8ba19b884ba838ceb4fba93a4e8c65a9059d08afcfc683d9" dependencies = [ - "windows-targets 0.52.5", + "windows-targets 0.52.6", +] + +[[package]] +name = "windows-registry" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"e400001bb720a623c1c69032f8e3e4cf09984deec740f007dd2b03ec864804b0" +dependencies = [ + "windows-result", + "windows-strings", + "windows-targets 0.52.6", +] + +[[package]] +name = "windows-result" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d1043d8214f791817bab27572aaa8af63732e11bf84aa21a45a78d6c317ae0e" +dependencies = [ + "windows-targets 0.52.6", +] + +[[package]] +name = "windows-strings" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4cd9b125c486025df0eabcb585e62173c6c9eddcec5d117d3b6e8c30e2ee4d10" +dependencies = [ + "windows-result", + "windows-targets 0.52.6", ] [[package]] @@ -5571,7 +5744,16 @@ version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" dependencies = [ - "windows-targets 0.52.5", + "windows-targets 0.52.6", +] + +[[package]] +name = "windows-sys" +version = "0.59.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e38bc4d79ed67fd075bcc251a1c39b32a1776bbe92e5bef1f0bf1f8c531853b" +dependencies = [ + "windows-targets 0.52.6", ] [[package]] @@ -5591,18 +5773,18 @@ dependencies = [ [[package]] name = "windows-targets" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6f0713a46559409d202e70e28227288446bf7841d3211583a4b53e3f6d96e7eb" +checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973" dependencies = [ - "windows_aarch64_gnullvm 0.52.5", - "windows_aarch64_msvc 0.52.5", - "windows_i686_gnu 0.52.5", + "windows_aarch64_gnullvm 0.52.6", + "windows_aarch64_msvc 0.52.6", + "windows_i686_gnu 0.52.6", "windows_i686_gnullvm", - "windows_i686_msvc 0.52.5", - "windows_x86_64_gnu 0.52.5", - "windows_x86_64_gnullvm 0.52.5", - "windows_x86_64_msvc 0.52.5", + "windows_i686_msvc 0.52.6", + "windows_x86_64_gnu 0.52.6", + 
"windows_x86_64_gnullvm 0.52.6", + "windows_x86_64_msvc 0.52.6", ] [[package]] @@ -5613,9 +5795,9 @@ checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8" [[package]] name = "windows_aarch64_gnullvm" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7088eed71e8b8dda258ecc8bac5fb1153c5cffaf2578fc8ff5d61e23578d3263" +checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" [[package]] name = "windows_aarch64_msvc" @@ -5625,9 +5807,9 @@ checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc" [[package]] name = "windows_aarch64_msvc" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9985fd1504e250c615ca5f281c3f7a6da76213ebd5ccc9561496568a2752afb6" +checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" [[package]] name = "windows_i686_gnu" @@ -5637,15 +5819,15 @@ checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e" [[package]] name = "windows_i686_gnu" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "88ba073cf16d5372720ec942a8ccbf61626074c6d4dd2e745299726ce8b89670" +checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" [[package]] name = "windows_i686_gnullvm" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87f4261229030a858f36b459e748ae97545d6f1ec60e5e0d6a3d32e0dc232ee9" +checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" [[package]] name = "windows_i686_msvc" @@ -5655,9 +5837,9 @@ checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406" [[package]] name = "windows_i686_msvc" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"db3c2bf3d13d5b658be73463284eaf12830ac9a26a90c717b7f771dfe97487bf" +checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" [[package]] name = "windows_x86_64_gnu" @@ -5667,9 +5849,9 @@ checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e" [[package]] name = "windows_x86_64_gnu" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4e4246f76bdeff09eb48875a0fd3e2af6aada79d409d33011886d3e1581517d9" +checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" [[package]] name = "windows_x86_64_gnullvm" @@ -5679,9 +5861,9 @@ checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc" [[package]] name = "windows_x86_64_gnullvm" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "852298e482cd67c356ddd9570386e2862b5673c85bd5f88df9ab6802b334c596" +checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" [[package]] name = "windows_x86_64_msvc" @@ -5691,9 +5873,9 @@ checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" [[package]] name = "windows_x86_64_msvc" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bec47e5bfd1bff0eeaf6d8b485cc1074891a197ab4225d504cb7a1ab88b02bf0" +checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" [[package]] name = "winreg" @@ -5705,16 +5887,6 @@ dependencies = [ "windows-sys 0.48.0", ] -[[package]] -name = "winreg" -version = "0.52.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a277a57398d4bfa075df44f501a17cfdf8542d224f0d36095a2adc7aee4ef0a5" -dependencies = [ - "cfg-if", - "windows-sys 0.48.0", -] - [[package]] name = "wyz" version = "0.5.1" @@ -5732,22 +5904,23 @@ checksum = "66fee0b777b0f5ac1c69bb06d361268faafa61cd4682ae064a171c16c433e9e4" [[package]] name = 
"zerocopy" -version = "0.7.34" +version = "0.7.35" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ae87e3fcd617500e5d106f0380cf7b77f3c6092aae37191433159dda23cfb087" +checksum = "1b9b4fd18abc82b8136838da5d50bae7bdea537c574d8dc1a34ed098d6c166f0" dependencies = [ + "byteorder", "zerocopy-derive", ] [[package]] name = "zerocopy-derive" -version = "0.7.34" +version = "0.7.35" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "15e934569e47891f7d9411f1a451d947a60e000ab3bd24fbb970f000387d1b3b" +checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e" dependencies = [ "proc-macro2", "quote", - "syn 2.0.77", + "syn 2.0.79", ] [[package]] diff --git a/README.md b/README.md index c3f8aaee..7a65fa35 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,10 @@

-Rig Logo
+ + + + Rig logo + +
  @@ -12,7 +17,7 @@

  -> [!WARNING] +> [!WARNING] > Here be dragons! Rig is **alpha** software and **will** contain breaking changes as it evolves. We'll annotate them and highlight migration paths as we encounter them. @@ -70,7 +75,10 @@ or just `full` to enable all features (`cargo add tokio --features macros,rt-mul Rig supports the following LLM providers natively: - OpenAI - Cohere +- Anthropic +- Perplexity - Google Gemini Additionally, Rig currently has the following integration sub-libraries: - MongoDB vector store: `rig-mongodb` +- LanceDB vector store: `rig-lancedb` diff --git a/img/rig_logo.svg b/img/rig_logo.svg index be08ccc0..b4241dcd 100644 --- a/img/rig_logo.svg +++ b/img/rig_logo.svg @@ -3,8 +3,8 @@ @@ -17,4 +17,4 @@ - \ No newline at end of file + diff --git a/img/rig_logo_dark.svg b/img/rig_logo_dark.svg new file mode 100644 index 00000000..f5ef032b --- /dev/null +++ b/img/rig_logo_dark.svg @@ -0,0 +1,20 @@ + + + + + + + + + + + + diff --git a/rig-core/CHANGELOG.md b/rig-core/CHANGELOG.md index ae6e79da..0d5fa427 100644 --- a/rig-core/CHANGELOG.md +++ b/rig-core/CHANGELOG.md @@ -7,6 +7,31 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +## [0.3.0](https://github.com/0xPlaygrounds/rig/compare/rig-core-v0.2.1...rig-core-v0.3.0) - 2024-10-24 + +### Added + +- Generalize `EmbeddingModel::embed_documents` with `IntoIterator` +- Add `from_env` constructor to Cohere and Anthropic clients +- Small optimization to serde_json object merging +- Add better error handling for provider clients + +### Fixed + +- Bad Anthropic request/response handling +- *(vector-index)* In memory vector store index incorrect search + +### Other + +- Made internal `json_utils` module private +- Update lib docs +- Made CompletionRequest helper method private to crate +- lint + fmt +- Simplify `agent_with_tools` example +- Fix docstring links +- Add nextest test runner to CI +- Merge pull request [#42](https://github.com/0xPlaygrounds/rig/pull/42) 
from 0xPlaygrounds/refactor(vector-store)/update-vector-store-index-trait + ## [0.2.1](https://github.com/0xPlaygrounds/rig/compare/rig-core-v0.2.0...rig-core-v0.2.1) - 2024-10-01 ### Fixed diff --git a/rig-core/Cargo.toml b/rig-core/Cargo.toml index f0b2182b..4b1a3a01 100644 --- a/rig-core/Cargo.toml +++ b/rig-core/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "rig-core" -version = "0.2.1" +version = "0.3.0" edition = "2021" license = "MIT" readme = "README.md" @@ -23,8 +23,14 @@ futures = "0.3.29" ordered-float = "4.2.0" schemars = "0.8.16" thiserror = "1.0.61" +glob = "0.3.1" +lopdf = { version = "0.34.0", optional = true } [dev-dependencies] anyhow = "1.0.75" +assert_fs = "1.1.2" tokio = { version = "1.34.0", features = ["full"] } tracing-subscriber = "0.3.18" + +[features] +pdf = ["dep:lopdf"] diff --git a/rig-core/examples/agent_with_loaders.rs b/rig-core/examples/agent_with_loaders.rs new file mode 100644 index 00000000..9ddf22de --- /dev/null +++ b/rig-core/examples/agent_with_loaders.rs @@ -0,0 +1,38 @@ +use std::env; + +use rig::{ + agent::AgentBuilder, + completion::Prompt, + loaders::FileLoader, + providers::openai::{self, GPT_4O}, +}; + +#[tokio::main] +async fn main() -> Result<(), anyhow::Error> { + let openai_client = + openai::Client::new(&env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set")); + + let model = openai_client.completion_model(GPT_4O); + + // Load in all the rust examples + let examples = FileLoader::with_glob("rig-core/examples/*.rs")? 
+ .read_with_path() + .ignore_errors() + .into_iter(); + + // Create an agent with multiple context documents + let agent = examples + .fold(AgentBuilder::new(model), |builder, (path, content)| { + builder.context(format!("Rust Example {:?}:\n{}", path, content).as_str()) + }) + .build(); + + // Prompt the agent and print the response + let response = agent + .prompt("Which rust example is best suited for the operation 1 + 2") + .await?; + + println!("{}", response); + + Ok(()) +} diff --git a/rig-core/examples/agent_with_tools.rs b/rig-core/examples/agent_with_tools.rs index 3356a40a..915c4d52 100644 --- a/rig-core/examples/agent_with_tools.rs +++ b/rig-core/examples/agent_with_tools.rs @@ -6,7 +6,6 @@ use rig::{ }; use serde::{Deserialize, Serialize}; use serde_json::json; -use std::env; #[derive(Deserialize)] struct OperationArgs { @@ -92,25 +91,13 @@ impl Tool for Subtract { #[tokio::main] async fn main() -> Result<(), anyhow::Error> { // Create OpenAI client - let openai_api_key = env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set"); - let openai_client = providers::openai::Client::new(&openai_api_key); + let openai_client = providers::openai::Client::from_env(); // Create agent with a single context prompt and two tools - let gpt4_calculator_agent = openai_client - .agent("gpt-4") - .context("You are a calculator here to help the user perform arithmetic operations.") - .tool(Adder) - .tool(Subtract) - .build(); - - // Create OpenAI client - let cohere_api_key = env::var("COHERE_API_KEY").expect("COHERE_API_KEY not set"); - let cohere_client = providers::cohere::Client::new(&cohere_api_key); - - // Create agent with a single context prompt and two tools - let coral_calculator_agent = cohere_client - .agent("command-r") - .preamble("You are a calculator here to help the user perform arithmetic operations.") + let calculator_agent = openai_client + .agent(providers::openai::GPT_4O) + .preamble("You are a calculator here to help the user perform arithmetic 
operations. Use the tools provided to answer the user's question.") + .max_tokens(1024) .tool(Adder) .tool(Subtract) .build(); @@ -118,12 +105,8 @@ async fn main() -> Result<(), anyhow::Error> { // Prompt the agent and print the response println!("Calculate 2 - 5"); println!( - "GPT-4: {}", - gpt4_calculator_agent.prompt("Calculate 2 - 5").await? - ); - println!( - "Coral: {}", - coral_calculator_agent.prompt("Calculate 2 - 5").await? + "Calculator Agent: {}", + calculator_agent.prompt("Calculate 2 - 5").await? ); Ok(()) diff --git a/rig-core/examples/loaders.rs b/rig-core/examples/loaders.rs new file mode 100644 index 00000000..e4670a71 --- /dev/null +++ b/rig-core/examples/loaders.rs @@ -0,0 +1,14 @@ +use rig::loaders::FileLoader; + +#[tokio::main] +async fn main() -> Result<(), anyhow::Error> { + FileLoader::with_glob("cargo.toml")? + .read() + .into_iter() + .for_each(|result| match result { + Ok(content) => println!("{}", content), + Err(e) => eprintln!("Error reading file: {}", e), + }); + + Ok(()) +} diff --git a/rig-core/src/completion.rs b/rig-core/src/completion.rs index 4383e561..e766fb27 100644 --- a/rig-core/src/completion.rs +++ b/rig-core/src/completion.rs @@ -266,7 +266,7 @@ pub struct CompletionRequest { } impl CompletionRequest { - pub fn prompt_with_context(&self) -> String { + pub(crate) fn prompt_with_context(&self) -> String { if !self.documents.is_empty() { format!( "\n{}\n\n{}", @@ -439,14 +439,14 @@ impl CompletionRequestBuilder { } /// Sets the max tokens for the completion request. - /// Only required for: [ Anthropic ] + /// Note: This is required if using Anthropic pub fn max_tokens(mut self, max_tokens: u64) -> Self { self.max_tokens = Some(max_tokens); self } /// Sets the max tokens for the completion request. 
- /// Only required for: [ Anthropic ] + /// Note: This is required if using Anthropic pub fn max_tokens_opt(mut self, max_tokens: Option) -> Self { self.max_tokens = max_tokens; self diff --git a/rig-core/src/embeddings.rs b/rig-core/src/embeddings.rs index 7a34e822..eaced08b 100644 --- a/rig-core/src/embeddings.rs +++ b/rig-core/src/embeddings.rs @@ -97,7 +97,7 @@ pub trait EmbeddingModel: Clone + Sync + Send { /// Embed multiple documents in a single request fn embed_documents( &self, - documents: Vec, + documents: impl IntoIterator + Send, ) -> impl std::future::Future, EmbeddingError>> + Send; } diff --git a/rig-core/src/json_utils.rs b/rig-core/src/json_utils.rs index 09dee92a..25d3ef91 100644 --- a/rig-core/src/json_utils.rs +++ b/rig-core/src/json_utils.rs @@ -1,11 +1,19 @@ pub fn merge(a: serde_json::Value, b: serde_json::Value) -> serde_json::Value { - match (a.clone(), b) { - (serde_json::Value::Object(mut a), serde_json::Value::Object(b)) => { - b.into_iter().for_each(|(key, value)| { - a.insert(key.clone(), value.clone()); + match (a, b) { + (serde_json::Value::Object(mut a_map), serde_json::Value::Object(b_map)) => { + b_map.into_iter().for_each(|(key, value)| { + a_map.insert(key, value); }); - serde_json::Value::Object(a) + serde_json::Value::Object(a_map) } - _ => a, + (a, _) => a, + } +} + +pub fn merge_inplace(a: &mut serde_json::Value, b: serde_json::Value) { + if let (serde_json::Value::Object(a_map), serde_json::Value::Object(b_map)) = (a, b) { + b_map.into_iter().for_each(|(key, value)| { + a_map.insert(key, value); + }); } } diff --git a/rig-core/src/lib.rs b/rig-core/src/lib.rs index 86c25209..5f7f021d 100644 --- a/rig-core/src/lib.rs +++ b/rig-core/src/lib.rs @@ -54,24 +54,27 @@ //! Rig provides a common interface for working with vector stores and indexes. Specifically, the library //! provides the [VectorStore](crate::vector_store::VectorStore) and [VectorStoreIndex](crate::vector_store::VectorStoreIndex) //! 
traits, which can be implemented to define vector stores and indices respectively. -//! Those can then be used as the knowledgebase for a [RagAgent](crate::rag::RagAgent), or +//! Those can then be used as the knowledgebase for a RAG enabled [Agent](crate::agent::Agent), or //! as a source of context documents in a custom architecture that use multiple LLMs or agents. //! //! # Integrations //! Rig natively supports the following completion and embedding model providers: //! - OpenAI //! - Cohere +//! - Anthropic +//! - Perplexity //! //! Rig currently has the following integration companion crates: //! - `rig-mongodb`: Vector store implementation for MongoDB -//! +//! - `rig-lancedb`: Vector store implementation for LanceDB pub mod agent; pub mod cli_chatbot; pub mod completion; pub mod embeddings; pub mod extractor; -pub mod json_utils; +pub(crate) mod json_utils; +pub mod loaders; pub mod providers; pub mod tool; pub mod vector_store; diff --git a/rig-core/src/loaders/file.rs b/rig-core/src/loaders/file.rs new file mode 100644 index 00000000..17c2f1f3 --- /dev/null +++ b/rig-core/src/loaders/file.rs @@ -0,0 +1,273 @@ +use std::{fs, path::PathBuf}; + +use glob::glob; +use thiserror::Error; + +#[derive(Error, Debug)] +pub enum FileLoaderError { + #[error("Invalid glob pattern: {0}")] + InvalidGlobPattern(String), + + #[error("IO error: {0}")] + IoError(#[from] std::io::Error), + + #[error("Pattern error: {0}")] + PatternError(#[from] glob::PatternError), + + #[error("Glob error: {0}")] + GlobError(#[from] glob::GlobError), +} + +// ================================================================ +// Implementing Readable trait for reading file contents +// ================================================================ +pub(crate) trait Readable { + fn read(self) -> Result; + fn read_with_path(self) -> Result<(PathBuf, String), FileLoaderError>; +} + +impl<'a> FileLoader<'a, PathBuf> { + pub fn read(self) -> FileLoader<'a, Result> { + FileLoader { + iterator: 
Box::new(self.iterator.map(|res| res.read())), + } + } + pub fn read_with_path(self) -> FileLoader<'a, Result<(PathBuf, String), FileLoaderError>> { + FileLoader { + iterator: Box::new(self.iterator.map(|res| res.read_with_path())), + } + } +} + +impl Readable for PathBuf { + fn read(self) -> Result { + fs::read_to_string(self).map_err(FileLoaderError::IoError) + } + fn read_with_path(self) -> Result<(PathBuf, String), FileLoaderError> { + let contents = fs::read_to_string(&self); + Ok((self, contents?)) + } +} +impl Readable for Result { + fn read(self) -> Result { + self.map(|t| t.read())? + } + fn read_with_path(self) -> Result<(PathBuf, String), FileLoaderError> { + self.map(|t| t.read_with_path())? + } +} + +// ================================================================ +// FileLoader definitions and implementations +// ================================================================ + +/// [FileLoader] is a utility for loading files from the filesystem using glob patterns or directory +/// paths. It provides methods to read file contents and handle errors gracefully. +/// +/// # Errors +/// +/// This module defines a custom error type [FileLoaderError] which can represent various errors +/// that might occur during file loading operations, such as invalid glob patterns, IO errors, and +/// glob errors. 
+/// +/// # Example Usage +/// +/// ```rust +/// use rig:loaders::FileLoader; +/// +/// fn main() -> Result<(), Box> { +/// // Create a FileLoader using a glob pattern +/// let loader = FileLoader::with_glob("path/to/files/*.txt")?; +/// +/// // Read file contents, ignoring any errors +/// let contents: Vec = loader +/// .read() +/// .ignore_errors() +/// +/// for content in contents { +/// println!("{}", content); +/// } +/// +/// Ok(()) +/// } +/// ``` +/// +/// [FileLoader] uses strict typing between the iterator methods to ensure that transitions between +/// different implementations of the loaders and it's methods are handled properly by the compiler. +pub struct FileLoader<'a, T> { + iterator: Box + 'a>, +} + +impl<'a> FileLoader<'a, Result> { + /// Reads the contents of the files within the iterator returned by [FileLoader::with_glob] or + /// [FileLoader::with_dir]. + /// + /// # Example + /// Read files in directory "files/*.txt" and print the content for each file + /// + /// ```rust + /// let content = FileLoader::with_glob(...)?.read(); + /// for result in content { + /// match result { + /// Ok(content) => println!("{}", content), + /// Err(e) => eprintln!("Error reading file: {}", e), + /// } + /// } + /// ``` + pub fn read(self) -> FileLoader<'a, Result> { + FileLoader { + iterator: Box::new(self.iterator.map(|res| res.read())), + } + } + /// Reads the contents of the files within the iterator returned by [FileLoader::with_glob] or + /// [FileLoader::with_dir] and returns the path along with the content. + /// + /// # Example + /// Read files in directory "files/*.txt" and print the content for cooresponding path for each + /// file. 
+ /// + /// ```rust + /// let content = FileLoader::with_glob("files/*.txt")?.read(); + /// for (path, result) in content { + /// match result { + /// Ok((path, content)) => println!("{:?} {}", path, content), + /// Err(e) => eprintln!("Error reading file: {}", e), + /// } + /// } + /// ``` + pub fn read_with_path(self) -> FileLoader<'a, Result<(PathBuf, String), FileLoaderError>> { + FileLoader { + iterator: Box::new(self.iterator.map(|res| res.read_with_path())), + } + } +} + +impl<'a, T: 'a> FileLoader<'a, Result> { + /// Ignores errors in the iterator, returning only successful results. This can be used on any + /// [FileLoader] state of iterator whose items are results. + /// + /// # Example + /// Read files in directory "files/*.txt" and ignore errors from unreadable files. + /// + /// ```rust + /// let content = FileLoader::with_glob("files/*.txt")?.read().ignore_errors(); + /// for result in content { + /// println!("{}", content) + /// } + /// ``` + pub fn ignore_errors(self) -> FileLoader<'a, T> { + FileLoader { + iterator: Box::new(self.iterator.filter_map(|res| res.ok())), + } + } +} + +impl<'a> FileLoader<'a, Result> { + /// Creates a new [FileLoader] using a glob pattern to match files. + /// + /// # Example + /// Create a [FileLoader] for all `.txt` files that match the glob "files/*.txt". + /// + /// ```rust + /// let loader = FileLoader::with_glob("files/*.txt")?; + /// ``` + pub fn with_glob( + pattern: &str, + ) -> Result>, FileLoaderError> { + let paths = glob(pattern)?; + Ok(FileLoader { + iterator: Box::new( + paths + .into_iter() + .map(|path| path.map_err(FileLoaderError::GlobError)), + ), + }) + } + + /// Creates a new [FileLoader] on all files within a directory. + /// + /// # Example + /// Create a [FileLoader] for all files that are in the directory "files" (ignores subdirectories). 
+ /// + /// ```rust + /// let loader = FileLoader::with_dir("files")?; + /// ``` + pub fn with_dir( + directory: &str, + ) -> Result>, FileLoaderError> { + Ok(FileLoader { + iterator: Box::new(fs::read_dir(directory)?.filter_map(|entry| { + let path = entry.ok()?.path(); + if path.is_file() { + Some(Ok(path)) + } else { + None + } + })), + }) + } +} + +// ================================================================ +// Iterators for FileLoader +// ================================================================ + +pub struct IntoIter<'a, T> { + iterator: Box + 'a>, +} + +impl<'a, T> IntoIterator for FileLoader<'a, T> { + type Item = T; + type IntoIter = IntoIter<'a, T>; + + fn into_iter(self) -> Self::IntoIter { + IntoIter { + iterator: self.iterator, + } + } +} + +impl<'a, T> Iterator for IntoIter<'a, T> { + type Item = T; + + fn next(&mut self) -> Option { + self.iterator.next() + } +} + +#[cfg(test)] +mod tests { + use assert_fs::prelude::{FileTouch, FileWriteStr, PathChild}; + + use super::FileLoader; + + #[test] + fn test_file_loader() { + let temp = assert_fs::TempDir::new().expect("Failed to create temp dir"); + let foo_file = temp.child("foo.txt"); + let bar_file = temp.child("bar.txt"); + + foo_file.touch().expect("Failed to create foo.txt"); + bar_file.touch().expect("Failed to create bar.txt"); + + foo_file.write_str("foo").expect("Failed to write to foo"); + bar_file.write_str("bar").expect("Failed to write to bar"); + + let glob = temp.path().to_string_lossy().to_string() + "/*.txt"; + + let loader = FileLoader::with_glob(&glob).unwrap(); + let mut actual = loader + .ignore_errors() + .read() + .ignore_errors() + .into_iter() + .collect::>(); + let mut expected = vec!["foo".to_string(), "bar".to_string()]; + + actual.sort(); + expected.sort(); + + assert!(!actual.is_empty()); + assert!(expected == actual) + } +} diff --git a/rig-core/src/loaders/mod.rs b/rig-core/src/loaders/mod.rs new file mode 100644 index 00000000..ce87ee0f --- /dev/null +++ 
b/rig-core/src/loaders/mod.rs @@ -0,0 +1,9 @@ +pub mod file; + +pub use file::FileLoader; + +#[cfg(feature = "pdf")] +pub mod pdf; + +#[cfg(feature = "pdf")] +pub use pdf::PdfFileLoader; diff --git a/rig-core/src/loaders/pdf.rs b/rig-core/src/loaders/pdf.rs new file mode 100644 index 00000000..ea18e4e6 --- /dev/null +++ b/rig-core/src/loaders/pdf.rs @@ -0,0 +1,456 @@ +use std::{fs, path::PathBuf}; + +use glob::glob; +use lopdf::{Document, Error as LopdfError}; +use thiserror::Error; + +use super::file::FileLoaderError; + +#[derive(Error, Debug)] +pub enum PdfLoaderError { + #[error("{0}")] + FileLoaderError(#[from] FileLoaderError), + + #[error("UTF-8 conversion error: {0}")] + FromUtf8Error(#[from] std::string::FromUtf8Error), + + #[error("IO error: {0}")] + PdfError(#[from] LopdfError), +} + +// ================================================================ +// Implementing Loadable trait for loading pdfs +// ================================================================ + +pub(crate) trait Loadable { + fn load(self) -> Result; + fn load_with_path(self) -> Result<(PathBuf, Document), PdfLoaderError>; +} + +impl Loadable for PathBuf { + fn load(self) -> Result { + Document::load(self).map_err(PdfLoaderError::PdfError) + } + fn load_with_path(self) -> Result<(PathBuf, Document), PdfLoaderError> { + let contents = Document::load(&self); + Ok((self, contents?)) + } +} +impl Loadable for Result { + fn load(self) -> Result { + self.map(|t| t.load())? + } + fn load_with_path(self) -> Result<(PathBuf, Document), PdfLoaderError> { + self.map(|t| t.load_with_path())? + } +} + +// ================================================================ +// PdfFileLoader definitions and implementations +// ================================================================ + +/// [PdfFileLoader] is a utility for loading pdf files from the filesystem using glob patterns or +/// directory paths. It provides methods to read file contents and handle errors gracefully. 
+/// +/// # Errors +/// +/// This module defines a custom error type [PdfLoaderError] which can represent various errors +/// that might occur during file loading operations, such as any [FileLoaderError] alongside +/// specific PDF-related errors. +/// +/// # Example Usage +/// +/// ```rust +/// use rig:loaders::PdfileLoader; +/// +/// fn main() -> Result<(), Box> { +/// // Create a FileLoader using a glob pattern +/// let loader = PdfFileLoader::with_glob("tests/data/*.pdf")?; +/// +/// // Load pdf file contents by page, ignoring any errors +/// let contents: Vec = loader +/// .load_with_path() +/// .ignore_errors() +/// .by_page() +/// +/// for content in contents { +/// println!("{}", content); +/// } +/// +/// Ok(()) +/// } +/// ``` +/// +/// [PdfFileLoader] uses strict typing between the iterator methods to ensure that transitions +/// between different implementations of the loaders and it's methods are handled properly by +/// the compiler. +pub struct PdfFileLoader<'a, T> { + iterator: Box + 'a>, +} + +impl<'a> PdfFileLoader<'a, Result> { + /// Loads the contents of the pdfs within the iterator returned by [PdfFileLoader::with_glob] + /// or [PdfFileLoader::with_dir]. Loaded PDF documents are raw PDF instances that can be + /// further processed (by page, etc). + /// + /// # Example + /// Load pdfs in directory "tests/data/*.pdf" and return the loaded documents + /// + /// ```rust + /// let content = PdfFileLoader::with_glob("tests/data/*.pdf")?.load().into_iter(); + /// for result in content { + /// match result { + /// Ok((path, doc)) => println!("{:?} {}", path, doc), + /// Err(e) => eprintln!("Error reading pdf: {}", e), + /// } + /// } + /// ``` + pub fn load(self) -> PdfFileLoader<'a, Result> { + PdfFileLoader { + iterator: Box::new(self.iterator.map(|res| res.load())), + } + } + + /// Loads the contents of the pdfs within the iterator returned by [PdfFileLoader::with_glob] + /// or [PdfFileLoader::with_dir]. 
Loaded PDF documents are raw PDF instances with their path + /// that can be further processed. + /// + /// # Example + /// Load pdfs in directory "tests/data/*.pdf" and return the loaded documents + /// + /// ```rust + /// let content = PdfFileLoader::with_glob("tests/data/*.pdf")?.load_with_path().into_iter(); + /// for result in content { + /// match result { + /// Ok((path, doc)) => println!("{:?} {}", path, doc), + /// Err(e) => eprintln!("Error reading pdf: {}", e), + /// } + /// } + /// ``` + pub fn load_with_path(self) -> PdfFileLoader<'a, Result<(PathBuf, Document), PdfLoaderError>> { + PdfFileLoader { + iterator: Box::new(self.iterator.map(|res| res.load_with_path())), + } + } +} + +impl<'a> PdfFileLoader<'a, Result> { + /// Directly reads the contents of the pdfs within the iterator returned by + /// [PdfFileLoader::with_glob] or [PdfFileLoader::with_dir]. + /// + /// # Example + /// Read pdfs in directory "tests/data/*.pdf" and return the contents of the documents. + /// + /// ```rust + /// let content = PdfFileLoader::with_glob("tests/data/*.pdf")?.read_with_path().into_iter(); + /// for result in content { + /// match result { + /// Ok((path, content)) => println!("{}", content), + /// Err(e) => eprintln!("Error reading pdf: {}", e), + /// } + /// } + /// ``` + pub fn read(self) -> PdfFileLoader<'a, Result> { + PdfFileLoader { + iterator: Box::new(self.iterator.map(|res| { + let doc = res.load()?; + Ok(doc + .page_iter() + .enumerate() + .map(|(page_no, _)| { + doc.extract_text(&[page_no as u32 + 1]) + .map_err(PdfLoaderError::PdfError) + }) + .collect::, PdfLoaderError>>()? + .into_iter() + .collect::()) + })), + } + } + + /// Directly reads the contents of the pdfs within the iterator returned by + /// [PdfFileLoader::with_glob] or [PdfFileLoader::with_dir] and returns the path along with + /// the content. + /// + /// # Example + /// Read pdfs in directory "tests/data/*.pdf" and return the content and paths of the documents. 
+ /// + /// ```rust + /// let content = PdfFileLoader::with_glob("tests/data/*.pdf")?.read_with_path().into_iter(); + /// for result in content { + /// match result { + /// Ok((path, content)) => println!("{:?} {}", path, content), + /// Err(e) => eprintln!("Error reading pdf: {}", e), + /// } + /// } + /// ``` + pub fn read_with_path(self) -> PdfFileLoader<'a, Result<(PathBuf, String), PdfLoaderError>> { + PdfFileLoader { + iterator: Box::new(self.iterator.map(|res| { + let (path, doc) = res.load_with_path()?; + println!( + "Loaded {:?} PDF: {:?}", + path, + doc.page_iter().collect::>() + ); + let content = doc + .page_iter() + .enumerate() + .map(|(page_no, _)| { + doc.extract_text(&[page_no as u32 + 1]) + .map_err(PdfLoaderError::PdfError) + }) + .collect::, PdfLoaderError>>()? + .into_iter() + .collect::(); + + Ok((path, content)) + })), + } + } +} + +impl<'a> PdfFileLoader<'a, Document> { + /// Chunks the pages of a loaded document by page, flattened as a single vector. + /// + /// # Example + /// Load pdfs in directory "tests/data/*.pdf" and chunk all document into it's pages. + /// + /// ```rust + /// let content = PdfFileLoader::with_glob("tests/data/*.pdf")?.load().by_page().into_iter(); + /// for result in content { + /// match result { + /// Ok(page) => println!("{}", page), + /// Err(e) => eprintln!("Error reading pdf: {}", e), + /// } + /// } + /// ``` + pub fn by_page(self) -> PdfFileLoader<'a, Result> { + PdfFileLoader { + iterator: Box::new(self.iterator.flat_map(|doc| { + doc.page_iter() + .enumerate() + .map(|(page_no, _)| { + doc.extract_text(&[page_no as u32 + 1]) + .map_err(PdfLoaderError::PdfError) + }) + .collect::>() + })), + } + } +} + +type ByPage = (PathBuf, Vec<(usize, Result)>); +impl<'a> PdfFileLoader<'a, (PathBuf, Document)> { + /// Chunks the pages of a loaded document by page, processed as a vector of documents by path + /// which each document container an inner vector of pages by page number. 
+ /// + /// # Example + /// Read pdfs in directory "tests/data/*.pdf" and chunk all documents by path by it's pages. + /// + /// ```rust + /// let content = PdfFileLoader::with_glob("tests/data/*.pdf")? + /// .load_with_path() + /// .by_page() + /// .into_iter(); + /// + /// for result in content { + /// match result { + /// Ok(documents) => { + /// for doc in documents { + /// match doc { + /// Ok((pageno, content)) => println!("Page {}: {}", pageno, content), + /// Err(e) => eprintln!("Error reading page: {}", e), + /// } + /// } + /// }, + /// Err(e) => eprintln!("Error reading pdf: {}", e), + /// } + /// } + /// ``` + pub fn by_page(self) -> PdfFileLoader<'a, ByPage> { + PdfFileLoader { + iterator: Box::new(self.iterator.map(|(path, doc)| { + ( + path, + doc.page_iter() + .enumerate() + .map(|(page_no, _)| { + ( + page_no, + doc.extract_text(&[page_no as u32 + 1]) + .map_err(PdfLoaderError::PdfError), + ) + }) + .collect::>(), + ) + })), + } + } +} + +impl<'a> PdfFileLoader<'a, ByPage> { + /// Ignores errors in the iterator, returning only successful results. This can be used on any + /// [PdfFileLoader] state of iterator whose items are results. + /// + /// # Example + /// Read files in directory "tests/data/*.pdf" and ignore errors from unreadable files. + /// + /// ```rust + /// let content = FileLoader::with_glob("tests/data/*.pdf")?.read().ignore_errors().into_iter(); + /// for result in content { + /// println!("{}", content) + /// } + /// ``` + pub fn ignore_errors(self) -> PdfFileLoader<'a, (PathBuf, Vec<(usize, String)>)> { + PdfFileLoader { + iterator: Box::new(self.iterator.map(|(path, pages)| { + let pages = pages + .into_iter() + .filter_map(|(page_no, res)| res.ok().map(|content| (page_no, content))) + .collect::>(); + (path, pages) + })), + } + } +} + +impl<'a, T: 'a> PdfFileLoader<'a, Result> { + /// Ignores errors in the iterator, returning only successful results. 
This can be used on any + /// [PdfFileLoader] state of iterator whose items are results. + /// + /// # Example + /// Read files in directory "tests/data/*.pdf" and ignore errors from unreadable files. + /// + /// ```rust + /// let content = FileLoader::with_glob("tests/data/*.pdf")?.read().ignore_errors().into_iter(); + /// for result in content { + /// println!("{}", content) + /// } + /// ``` + pub fn ignore_errors(self) -> PdfFileLoader<'a, T> { + PdfFileLoader { + iterator: Box::new(self.iterator.filter_map(|res| res.ok())), + } + } +} + +impl<'a> PdfFileLoader<'a, Result> { + /// Creates a new [PdfFileLoader] using a glob pattern to match files. + /// + /// # Example + /// Create a [PdfFileLoader] for all `.pdf` files that match the glob "tests/data/*.pdf". + /// + /// ```rust + /// let loader = FileLoader::with_glob("tests/data/*.txt")?; + /// ``` + pub fn with_glob( + pattern: &str, + ) -> Result>, PdfLoaderError> { + let paths = glob(pattern).map_err(FileLoaderError::PatternError)?; + Ok(PdfFileLoader { + iterator: Box::new(paths.into_iter().map(|path| { + path.map_err(FileLoaderError::GlobError) + .map_err(PdfLoaderError::FileLoaderError) + })), + }) + } + + /// Creates a new [PdfFileLoader] on all files within a directory. + /// + /// # Example + /// Create a [PdfFileLoader] for all files that are in the directory "files". + /// + /// ```rust + /// let loader = PdfFileLoader::with_dir("files")?; + /// ``` + pub fn with_dir( + directory: &str, + ) -> Result>, PdfLoaderError> { + Ok(PdfFileLoader { + iterator: Box::new( + fs::read_dir(directory) + .map_err(FileLoaderError::IoError)? 
+ .map(|entry| Ok(entry.map_err(FileLoaderError::IoError)?.path())), + ), + }) + } +} + +// ================================================================ +// PDFFileLoader iterator implementations +// ================================================================ + +pub struct IntoIter<'a, T> { + iterator: Box + 'a>, +} + +impl<'a, T> IntoIterator for PdfFileLoader<'a, T> { + type Item = T; + type IntoIter = IntoIter<'a, T>; + + fn into_iter(self) -> Self::IntoIter { + IntoIter { + iterator: self.iterator, + } + } +} + +impl<'a, T> Iterator for IntoIter<'a, T> { + type Item = T; + + fn next(&mut self) -> Option { + self.iterator.next() + } +} + +#[cfg(test)] +mod tests { + use std::path::PathBuf; + + use super::PdfFileLoader; + + #[test] + fn test_pdf_loader() { + let loader = PdfFileLoader::with_glob("tests/data/*.pdf").unwrap(); + let actual = loader + .load_with_path() + .ignore_errors() + .by_page() + .ignore_errors() + .into_iter() + .collect::>(); + + let mut actual = actual + .into_iter() + .map(|result| { + let (path, pages) = result; + pages.iter().for_each(|(page_no, content)| { + println!("{:?} Page {}: {:?}", path, page_no, content); + }); + (path, pages) + }) + .collect::>(); + + let mut expected = vec![ + ( + PathBuf::from("tests/data/dummy.pdf"), + vec![(0, "Test\nPDF\nDocument\n".to_string())], + ), + ( + PathBuf::from("tests/data/pages.pdf"), + vec![ + (0, "Page\n1\n".to_string()), + (1, "Page\n2\n".to_string()), + (2, "Page\n3\n".to_string()), + ], + ), + ]; + + actual.sort(); + expected.sort(); + + assert!(!actual.is_empty()); + assert!(expected == actual) + } +} diff --git a/rig-core/src/providers/anthropic/client.rs b/rig-core/src/providers/anthropic/client.rs index e2458d9d..e949609b 100644 --- a/rig-core/src/providers/anthropic/client.rs +++ b/rig-core/src/providers/anthropic/client.rs @@ -113,6 +113,13 @@ impl Client { } } + /// Create a new Anthropic client from the `ANTHROPIC_API_KEY` environment variable. 
+ /// Panics if the environment variable is not set. + pub fn from_env() -> Self { + let api_key = std::env::var("ANTHROPIC_API_KEY").expect("ANTHROPIC_API_KEY not set"); + ClientBuilder::new(&api_key).build() + } + pub fn post(&self, path: &str) -> reqwest::RequestBuilder { let url = format!("{}/{}", self.base_url, path).replace("//", "/"); self.http_client.post(url) diff --git a/rig-core/src/providers/anthropic/completion.rs b/rig-core/src/providers/anthropic/completion.rs index fab42544..81b8b14c 100644 --- a/rig-core/src/providers/anthropic/completion.rs +++ b/rig-core/src/providers/anthropic/completion.rs @@ -47,16 +47,14 @@ pub struct CompletionResponse { pub enum Content { String(String), Text { + r#type: String, text: String, - #[serde(rename = "type")] - content_type: String, }, ToolUse { + r#type: String, id: String, name: String, - input: String, - #[serde(rename = "type")] - content_type: String, + input: serde_json::Value, }, } @@ -73,7 +71,6 @@ pub struct ToolDefinition { pub name: String, pub description: Option, pub input_schema: serde_json::Value, - pub cache_control: Option, } #[derive(Debug, Deserialize, Serialize)] @@ -94,10 +91,7 @@ impl TryFrom for completion::CompletionResponse Ok(completion::CompletionResponse { - choice: completion::ModelChoice::ToolCall( - name.clone(), - serde_json::from_str(input)?, - ), + choice: completion::ModelChoice::ToolCall(name.clone(), input.clone()), raw_response: response, }), _ => Err(CompletionError::ResponseError( @@ -157,9 +151,20 @@ impl completion::CompletionModel for CompletionModel { &self, completion_request: completion::CompletionRequest, ) -> Result, CompletionError> { + // Note: Ideally we'd introduce provider-specific Request models to handle the + // specific requirements of each provider. For now, we just manually check while + // building the request as a raw JSON document. 
+ let prompt_with_context = completion_request.prompt_with_context(); - let request = json!({ + // Check if max_tokens is set, required for Anthropic + if completion_request.max_tokens.is_none() { + return Err(CompletionError::RequestError( + "max_tokens must be set for Anthropic".into(), + )); + } + + let mut request = json!({ "model": self.model, "messages": completion_request .chat_history @@ -172,38 +177,48 @@ impl completion::CompletionModel for CompletionModel { .collect::>(), "max_tokens": completion_request.max_tokens, "system": completion_request.preamble.unwrap_or("".to_string()), - "temperature": completion_request.temperature, - "tools": completion_request - .tools - .into_iter() - .map(|tool| ToolDefinition { - name: tool.name, - description: Some(tool.description), - input_schema: tool.parameters, - cache_control: None, - }) - .collect::>(), }); - let request = if let Some(ref params) = completion_request.additional_params { - json_utils::merge(request, params.clone()) - } else { - request - }; + if let Some(temperature) = completion_request.temperature { + json_utils::merge_inplace(&mut request, json!({ "temperature": temperature })); + } + + if !completion_request.tools.is_empty() { + json_utils::merge_inplace( + &mut request, + json!({ + "tools": completion_request + .tools + .into_iter() + .map(|tool| ToolDefinition { + name: tool.name, + description: Some(tool.description), + input_schema: tool.parameters, + }) + .collect::>(), + "tool_choice": ToolChoice::Auto, + }), + ); + } + + if let Some(ref params) = completion_request.additional_params { + json_utils::merge_inplace(&mut request, params.clone()) + } let response = self .client .post("/v1/messages") .json(&request) .send() - .await? - .error_for_status()? 
- .json::>() .await?; - match response { - ApiResponse::Message(completion) => completion.try_into(), - ApiResponse::Error(error) => Err(CompletionError::ProviderError(error.message)), + if response.status().is_success() { + match response.json::>().await? { + ApiResponse::Message(completion) => completion.try_into(), + ApiResponse::Error(error) => Err(CompletionError::ProviderError(error.message)), + } + } else { + Err(CompletionError::ProviderError(response.text().await?)) } } } diff --git a/rig-core/src/providers/cohere.rs b/rig-core/src/providers/cohere.rs index ae874b21..2f085dca 100644 --- a/rig-core/src/providers/cohere.rs +++ b/rig-core/src/providers/cohere.rs @@ -57,6 +57,13 @@ impl Client { } } + /// Create a new Cohere client from the `COHERE_API_KEY` environment variable. + /// Panics if the environment variable is not set. + pub fn from_env() -> Self { + let api_key = std::env::var("COHERE_API_KEY").expect("COHERE_API_KEY not set"); + Self::new(&api_key) + } + pub fn post(&self, path: &str) -> reqwest::RequestBuilder { let url = format!("{}/{}", self.base_url, path).replace("//", "/"); self.http_client.post(url) @@ -192,8 +199,10 @@ impl embeddings::EmbeddingModel for EmbeddingModel { async fn embed_documents( &self, - documents: Vec, + documents: impl IntoIterator, ) -> Result, EmbeddingError> { + let documents = documents.into_iter().collect::>(); + let response = self .client .post("/v1/embed") @@ -203,32 +212,33 @@ impl embeddings::EmbeddingModel for EmbeddingModel { "input_type": self.input_type, })) .send() - .await? - .error_for_status()? - .json::>() .await?; - match response { - ApiResponse::Ok(response) => { - if response.embeddings.len() != documents.len() { - return Err(EmbeddingError::DocumentError(format!( - "Expected {} embeddings, got {}", - documents.len(), - response.embeddings.len() - ))); + if response.status().is_success() { + match response.json::>().await? 
{ + ApiResponse::Ok(response) => { + if response.embeddings.len() != documents.len() { + return Err(EmbeddingError::DocumentError(format!( + "Expected {} embeddings, got {}", + documents.len(), + response.embeddings.len() + ))); + } + + Ok(response + .embeddings + .into_iter() + .zip(documents.into_iter()) + .map(|(embedding, document)| embeddings::Embedding { + document, + vec: embedding, + }) + .collect()) } - - Ok(response - .embeddings - .into_iter() - .zip(documents.into_iter()) - .map(|(embedding, document)| embeddings::Embedding { - document, - vec: embedding, - }) - .collect()) + ApiResponse::Err(error) => Err(EmbeddingError::ProviderError(error.message)), } - ApiResponse::Err(error) => Err(EmbeddingError::ProviderError(error.message)), + } else { + Err(EmbeddingError::ProviderError(response.text().await?)) } } } @@ -500,14 +510,15 @@ impl completion::CompletionModel for CompletionModel { }, ) .send() - .await? - .error_for_status()? - .json::>() .await?; - match response { - ApiResponse::Ok(completion) => Ok(completion.into()), - ApiResponse::Err(error) => Err(CompletionError::ProviderError(error.message)), + if response.status().is_success() { + match response.json::>().await? 
{ + ApiResponse::Ok(completion) => Ok(completion.into()), + ApiResponse::Err(error) => Err(CompletionError::ProviderError(error.message)), + } + } else { + Err(CompletionError::ProviderError(response.text().await?)) } } } diff --git a/rig-core/src/providers/openai.rs b/rig-core/src/providers/openai.rs index 8262e6ce..f931a27f 100644 --- a/rig-core/src/providers/openai.rs +++ b/rig-core/src/providers/openai.rs @@ -241,8 +241,10 @@ impl embeddings::EmbeddingModel for EmbeddingModel { async fn embed_documents( &self, - documents: Vec, + documents: impl IntoIterator, ) -> Result, EmbeddingError> { + let documents = documents.into_iter().collect::>(); + let response = self .client .post("/v1/embeddings") @@ -251,30 +253,31 @@ impl embeddings::EmbeddingModel for EmbeddingModel { "input": documents, })) .send() - .await? - .error_for_status()? - .json::>() .await?; - match response { - ApiResponse::Ok(response) => { - if response.data.len() != documents.len() { - return Err(EmbeddingError::ResponseError( - "Response data length does not match input length".into(), - )); + if response.status().is_success() { + match response.json::>().await? 
{ + ApiResponse::Ok(response) => { + if response.data.len() != documents.len() { + return Err(EmbeddingError::ResponseError( + "Response data length does not match input length".into(), + )); + } + + Ok(response + .data + .into_iter() + .zip(documents.into_iter()) + .map(|(embedding, document)| embeddings::Embedding { + document, + vec: embedding.embedding, + }) + .collect()) } - - Ok(response - .data - .into_iter() - .zip(documents.into_iter()) - .map(|(embedding, document)| embeddings::Embedding { - document, - vec: embedding.embedding, - }) - .collect()) + ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)), } - ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)), + } else { + Err(EmbeddingError::ProviderError(response.text().await?)) } } } @@ -510,14 +513,15 @@ impl completion::CompletionModel for CompletionModel { }, ) .send() - .await? - .error_for_status()? - .json::>() .await?; - match response { - ApiResponse::Ok(response) => response.try_into(), - ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)), + if response.status().is_success() { + match response.json::>().await? { + ApiResponse::Ok(response) => response.try_into(), + ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)), + } + } else { + Err(CompletionError::ProviderError(response.text().await?)) } } } diff --git a/rig-core/src/providers/perplexity.rs b/rig-core/src/providers/perplexity.rs index e072c9b6..b0d3125a 100644 --- a/rig-core/src/providers/perplexity.rs +++ b/rig-core/src/providers/perplexity.rs @@ -231,14 +231,15 @@ impl completion::CompletionModel for CompletionModel { }, ) .send() - .await? - .error_for_status()? - .json::>() .await?; - match response { - ApiResponse::Ok(completion) => Ok(completion.try_into()?), - ApiResponse::Err(error) => Err(CompletionError::ProviderError(error.message)), + if response.status().is_success() { + match response.json::>().await? 
{ + ApiResponse::Ok(completion) => Ok(completion.try_into()?), + ApiResponse::Err(error) => Err(CompletionError::ProviderError(error.message)), + } + } else { + Err(CompletionError::ProviderError(response.text().await?)) } } } diff --git a/rig-core/tests/data/dummy.pdf b/rig-core/tests/data/dummy.pdf new file mode 100644 index 0000000000000000000000000000000000000000..b1f3353fe7def45db0381ac4cc80656ad74e9461 GIT binary patch literal 13078 zcmaKz1yq!6*Y8mfq@*OI9Hc>DW*COV_im?LIy;tWuM z+nK_h;7*JJ0vL$q3WmEr!R(SmhxrWDr)%9rFdU@GVV_gr;gPm1sj*MyX9A$~GNPhNC1OhR)hNqPQ&>Ju;b2Qd zetit(-%kG>RPev!rRM1X2e2u@%n@Nyfjir~I+?(o0fxwuALOrsUrp^G}NB>py|E(X+|2<6<<^r>^H-Fs5 z91)JccEMoNu!OtAos^v5W(bkm{T-)&1n|GQFxbQ_ogHjoo(N(8qbd-tzy%P`Kb6h( z?@j-qPY?hMKCT!ES5cEGgix9n$<=sw2Sm`P zXpro7je$gkdH$haE{jLz588+X5`O1>$7wQ}X(dU+9kY#L?CqXqFMN}rVffFB97RyT z<&1W;)!q(~7wP`B$wQZRDqHgLb=w^$681Mu#pXJ@(Y#p=4Vz~`mY#dKv(?^y2RhHY z?D)e%ZcI%BE1f2zH$uzaNO5D^qXu>B)A1PbYXO2!$L+Fpb=mJM)`uA4$$PIvw~PZ> zBvYe2$*BuQ^y_STJ$2O#kxF!`ZLYd{YpqyGIZfV>#Nn~DNI{gb$r1Du^DRA zmdq&g@0h4bP^1vfvQ!urnRXoxu$f-E%3AMFt}rg5t+rBjoS?lW5W;*}CiL=|XVDPG z4eHt2bxx-fI%)^x-4#uPI?MpckMWh`yk4D>wW%YbB>X!pZ>*{7M`cb{p?`v zUwa{VU;P*61tg!AYF*w6(NMbpQfsO&=AMAO01C$F&t^Ixo57wB{R)!1h-m9g#-t+(eyt&$hU(Q3g;Sl`+nBAy)2J(0E?)- zhO)vK_U!ADuuKx^sN_YNG{=jmcH1xf2$pF)+a3K$S}ix*zl2(cTKBs0+q|c9 zlE6@|IGD&0YD#2rCSXsO0T(@3X zY5b+Zn7K3ey3J1IZ~2XT{|%J~^dtBu5AufV&kuDLcQ>J9*!S&+WR1fpSu>)$?N-Hp zZH+=G*Gp5LT3##DZ3`c$z<1UN4X zZ*{(yLmNwGzxl2(hg^PQ?qH;N-NDpEIFnni(4U0eRFOO$6}_gh-qp&7bR=or_WFhK zV)QEM9H!g4T3g6;m4*jFwhktplsIO3z>E_?vyyh4!A^t+#jh~T39mh=V3mrcG8<*e zpm13HiFqLfR+FgNny=EW%w^UYo{lT+5~_PAUX6Goru(#b)AiN!cFC^`Olpy0>G}5| z)F<79T&Q@5SAcm%QHEBcHh@Bn41)+G4iny{xXorLQ{7m3RVo&pNr|EQr+18KopHWp zLZ94)*$?`@VrK>8L@2vTcuH|2vkr7+HjrMP^31cj)K%^>7@-E}>|uIp)sp;1@(=2! 
z5&A7fTpIJ)MqSIb@Qlc7Kp03UdB-4d%YB<=5PWU526wCjV_ZMzG=*A{83ZXte!v23 z3_=!CS&ObQWG9l%PJeKbT~DDl5fu81h?sJ6Op0~|uM@0~?fs0)k9ws^^W1$S{Xx>W zEB#wRmY(b~Rlg-(rtoo5e~4~xe-9^wgKJ-QTy+UWvE#Z&A}lfTr!hFS?9KiF%THFP z1adlhOR7u};_<02Jy_()M=p%xLE~1B%stfU*u7+mO$s)*`@#)xTHWR+bb0(`of9iz zyg7u2KXF^)ukpAf)Ov)5J%)M|i;_xD+QOE^C;^LuCIy7v+;VY=Y=fLXdC%Cz&84Tm zi#|XIkCF6cDsvhqy?3ZuaHiU)UOx*@U)o>b8xmxUD#iO;tWBJquNX6Icj4p06hM zuVX@zA-)*5-*lk7qa~W;7#iX{GqZ$SQ*`aXjFv1LTAcH0LGdhcmR4DXjWqf1L;XwH zFN*SaJX+Xt>s`=)#Ne4Q`-cCfu|Nqn!b=K~gT*M(e#aOJdEa#?BIY4~qs2w{kI{&e(r&XUFNhSp&g5>|8;bXK1sA0S^j z>@7|0n#0&x%vmry;yFdus%1(&P_Rzuqa)Whu+Q zadt%0!{+S3hhvwiy^pp};07dECa_Y9A2-Mh5Iv+-NK~ps3fOWU9kuSDVT6L;I_7Qa zMYw?o`Fk{4j>QD2@}7-!^ZVg>RnO+C^M^KclD|*SnH7$|x9X=!tKfU;&gi@fhGSlc z`d2AvhvGMOn^&PXB-c}Du4!ik@op11aR8r9EDbAH}J}6%@<1t$(DiwK3V? z;dWZ-IXkNN3nL_CK+Z1Q%`YXBG?^t+jx0=uAH-IaGsGv<&lL`vb(s-n8rtTByjL4b zrT)@U(y`QcJKY31>U%Dyvup46wcGn${wixCS@Aox`^2Y;IeA5Mi|?MtCTlb_b=8!b z%eR-&4U($cB+uiU#1tLL*_u;qdpP&KSJ?ah7DkQnF&-5!(R`3A=?rPqk|cNmHqK>D z0_`phT2)P)Z+dh7=1bn5ge4cV`Iw~Iu8NS|{fqIZPm6yq*A~!e3>)=?(pgXh6_aQ$To?TIf;Ov)0J0^zYIdzJfkY7G;A&RZ3sRa$aBcnFv z`CZbiJZ3j#F~Q`Co+v3Qf@m-JP{AJ!?`8lL!m<+LF8&xF*S}R_=JR zp{&gV_cF;) z*5thwJehNEkqOpV{7TG#Ezy<9CJoyAz5HD}NB!JN7+WH^!o9|of#1*4)37~>f(>Xv zQj7;N0`GV;>n_F$N;YiG)`(+k$%QQsLWO2gE&;Lk#pKb~19- z8cq8#<#hDr8m3Kuu_c!8CzVHkn-LB&J<7CUcaa}QL2A#apj`j*{vu|o^df;qx;5E0 zumU|}?7i*I+XSo__7dz>xw=b*2ExLAHb%PkMfkx1&%|(XAt<;%YmRq5lA*OUyX4+3 zwYk>%YpA1H`LiVbtK~t3pbB)lo{~lQ?!jNPGe4SE;1Ls{b=3m4J}>2%i$ZWwKw zeQZ#2_p!Z#(a}lr7+DBUjz4gmcis=-D;wf?n>z<@cg#L^{|q}2=tP`j1ApT|sFA!~ zFEjJ<6g&#Q>PqnQ^`_?ymW{ba{^4|gmNS{B*>2kB`F0TKqoi?cYA%v`1~xwXsy~?7 zK1@Qx&Z?c1%fR;4y#`HW7G*yDrgXPR#4Bb<)?Cj@gy$U*)_&#z)p&dISlr;T+cPTm zK0;G4YuY>XmWf`2JIO&Hqeym2&XoI&b}k#ac%I10 zmQsyXZ%U%Vu(;Q3*(~7oMHk9O4^|;p%+s}V42U>h1_>$&{PXuy8AN7dPLP#C9 z^vGg$Qm<5L;uTfq8V9U%qJzref8@UI?n_RY_P+S+JYjkN`;XIO3aoMh* za$1(OR@-dAQ@j!(Bk0Y$ho_cM?OM!BwcO&V;AX;DFm zcRP{?{87{!mbzjAZ?g zSh4tB&8DMx92{#7aq~%XFoK#655{O(L_;uVtV^w 
zniil-O-s2-S^ZoVKBBM?T8MFZt%8DHj-@P;wELuBHeoq1Pzr_^q~q1}>Z@gbct&u( zPpI23o(Fs17Pklh2d|!- zCy36`*<(&RoLqec4goyQiGGrxRIWsj2d}=jPB6jM^Wb*E3aZ_v?gw!Bq9U3b#enc} ziDO?LmTpd$`=@V+b#@3#3YnE;HFHubkP;0#?QWD!O}>8RW+ⅈ2_`ouxmXhOPWQC zoh;RjT#%ar(Gw+|NeJ_qbU?R!o`}k3DMf`1n>@jK753Qzur@$Dp#{}l zbZpfH38#!Q7!(+8L?!cOI@V;M!{5sn-Pd&EgnK(9E(BTK0$;dO?5uu+=4U02Y+eX# zFw4Dh7mA9|kURrOO@b&;FZl#XT69vjf>H$nE!}F}Nd?R-Uzpuq;j?>NH6BYIqpzec zL!%C)kJ;PYH=&7;yXC$ldbhjfGrN;kaE^R5?HKCfb2@MMvP-;sntlJ1WfyD>8`~@W z?xY`WrL87Dm@O@NP#s5U{;e*em4a256|Hxg2bsH=v-!h6+MXWMw%uElb{?evXba>D zQF=Gh(|L7yrhv?-xB5NT*3^zm*7<9Z9=id%SaSK`_d`(}z^!J6jq?s}ft4p17 zjaYl@mn$y_7T@Ks;oj-{A8e6V8Nk^xUM z-GNV3>7JWi}jYtTsb!JD5*rNQ8b^PHIGVl%W2zX zsN>yPk(Xt0TeHeqTTv4RF02of2~{sa^3FIvGizT)2ssj{rN}Ybf5w_6;*5VD&w&2@ zrDk1BdIE2qq#+s4-M7pWvMxFxthuZbBVTi~ zO28QnO2L$e^$Q!7B=FQd38xK;HFtxJ`RIvnZom+xH+>x?5$KWXm&^y61D_oftC`sE zu!YLLd5%19H)ortEKRDKBPDs33m@7~4X!mcaJ(p{WP7XA@R5-(D)*Pa8%$rCMo(3* z9V-Wt#N7;00XlklCQAetPTR;NuAPE|a|`6>U$Y#uArQ(rp1cQKID&Q*wI5e+=R7Xloppg+weG2>D9=R631U59ZeROC`M zCNIOgBRuwpwKKp1g*p#iF2XF12uLksBwY9AcMz=+WE`zI^h4-vZwQ6ysb zlq+J$5xxNIYBm%+cJ45MC^x3|I9m1V{@)3aTu34j$hoiA0X@ej@M1yZ7g$2Wp@H8T zfoU7KBq3Zcop8g2E6#9ykWIMAQK{or({My_$AWVss9_|=Pku742$6(Q6LmfBQBH{0 z3c2BU+3guY3&eMXyksyp!V1Cb7ol3e1!9~&jVn|N@t}H1oziKs35nU@qU-}G1XM>5 ztodv*r}a33^#Zt_$-Gwu5`M!ic5HdxN0rj;v3acYgaEggIx8p|_U7z8@KrTN5x}6! 
z5miq8H|jJHyBd{MrC#30kXxnL`=@c3 zxu}kK!#!t1Mnvnzajrn!+ovf)EbA(1!CGUEuZGnH!Z(39r%q|{9?#aztDX`8Dni5~ zG`rvtnvQzUqVX!jczUdjBAuRUINJnx^>9>So$6>+RIM}`#(D7f*E@tdVc7)`M(i8W zu6s4%J0j~9`e4LXVVp`dk;+2f_+fB%azt=;TlL6Sp*pg!D>pr>5MO#aOZEFjg=Snw zIk|mseMDn7PYN*$f5wQWAi@qWDAUehNQ1B#J7V|+~4iW^kVc+4d z1U`H-T-A>|?mUk;?#7IG-4Pw}+Ufqefyh0Hfy6z^?*j5LYV!Cl=5JItFJAyCgEjHu zo%x_ZGfUhM^ zw%>jB&Bex2@oT~U;@2Aj4_chreld3!CJ4*^)|4PW!n!H&z;L}Q`3Kwg*WDjcfrs)x zHy4Q4pEvvWF=g?eyI&XnM8$vl|9Tq~|D69p6UaRCzX|7jelV7OuhGsy5%|DfYiJ7l z!v-V7Na%eySQ+~f%~D)?m1nLb=U_9M?KVigJmKl1~FZQ&vES zEgPib?2y zf9Rs4c$joi%IyO(orWL=0(F^yaID^rr}}T9McL4fF3fU;FlW} zEt{(gadb)UL*2nL-;T7n>?HfGKetK`5|18mU99I*P{dE09+QAeehx_r{1HfgXz=0< zXd-g03`XnP{-n;Rb4Z6hVIhQ@{#&@TS+b}mP^D4$yK)gb#cZC!*t^QRl|43_RSXCz zmYz3-BS%(bsq)x&E`?vvmzt)pW`A<++NjT?GZ}!Qyo<4Z=2>K&8k2mB))y{i2#M-jGj^_z6Q1+yP zOR(U9Spy8yD&152RpoSTx&bT?N6^pHB*YJd*Sk zM@58<<&Kj{Oa^L1^%ZGmO@WC_5gC#g;u|M-TVlNsqg_qJLK7=(_1DO?PFT6IjySS|W<=gYWj~_5eaYO$wF7mYgiQ zA8-xJ{|&N%f7I!vt6Yni7HkBgmioWYeN~*tnAb5ZU5@W+Th}H|6gej`FlCiK{^TiM z!isp)*224F81ifuwCIcd`qvXJz3}!P*O`fL>U}g;?km&2qB1Msd~W5-=67tHlGCA4 zd%qH5+#qS%`0_haL}a{tOjB_bf)BwDcd*R@2N`m7yR&$H64oqX!QiP`>*iDEI)`^$ zI^0m@E6u7ky>lzfY7!}u?tC5>pS%(4ajw?9z_etDPHnovG|4;`L;^~MtwXqu1kDJx zZ^e36jG-XSrU6Ijbew=PpB+EDks)^g#0{&-=gD1g`qgDqJjIfkPHzIZ5JGJ3sPt*UZ%{ywQ>(^W`Hae{Q2R%~dwy(ePUQPbG46 zFOt5=#rpU}*KLu|)`zOt@>1H7GB^{o*z0CNY(MdqsU=KTD1+)JHOg};nDg^lf#C*h z^jZzm=bR}(b{G5&hJQuY09^aA*rDCULWighV;auF@pND5vJ; zOyu12n~reNJe>>=^o~B$+9vs}m-X|MlEU4gT*<6QF+*l=w z6=NVShaYO9)?v0gDc^0Q=nXVf~#< zi&7W)vza-fX6qK_@61V~leG%y2R+74Ty%4#hpuj&NA)%-N695Lk(8AN$lkjfayaES zs%y(juc7RPGEJt9w5wekFt4qT8ytHW6fo5fXLu{F13tZyk;8_!SL`6=FEPVQ5s_(2 z0>kwMXwKTdTx0<5Iq^3KO{(4R9AEmtE3XvV zWLtdt#mYN}4m(D}avckQzBu%A3HdB?f4~(8mlesamHom!qijixPEhU z{wyymx%^jB;b~|pfS#-@lC(XuocolWKI1bH>->ky=`1#0h2)<2WjeF7A!*<91#PlL z!5o6v?jl>QOOo_0kE-`K5kfHqiv9Eyf-+~Lt8SRMO*VG< zUN(D2v(IgzEq{7xZsXWTYA!d-UsXTR_X?|H88h{jT=VWa`f13KYQK5Sp6#@DI!~RG 
z(DuaukMQ#Xk7;QGJp*iIv?W$)f38d$g?6^gB3^Pk=U-;b)ltJv21oh|L6sd(*rM{jr8dN(gbBRj-Mr*0z6BmV5%ROmF;WhchWpLxi}C-^d}8{RLMeb zTSF+m2|_BJnec_MW@P+5wt9U$^Cy(!b~kdA4f8dJuSOym+Hg&U*K+0f8O!v2xZ-AOYaKRbTfi$u%tLPhD zX5{FoCP$JBn_jH9AAd@wQGH~m(FEb%gJPV)f70Lc(r~WxZ4|hU-hIT|_AB0bX`zsm zxfq+dn2@O(&!2h1FliOrn%?ttIJ@hGosbV-&UO)XpXgB4dtIzeB5Day+Jo<5`CQ$< zw(iEefZc6j-0ejTN-drG}LbtV$Isr>#$7B`|cR z&2XxBY(c@WN;X-f*d#L443xG$m&PbpoVS0Pe*H#IYU6P3ockr&yGp?;yZ5c8(yp?Y zQ?H@KVaWP$tmdYTL~#mB&y%2OpzVjW0`&ETEdA|F15i?Fh>N(g%JPDBgXB`{5aJr@ z7Uz|v@qrn;5xWt`Ek&xZLB%_=Rz>~O2O9PH&j-mLqt&8aY7501#INQ@nWd5x%E9_? z<&@cCklK^^J_<}6uFj$I;MawqKb@_#bIraX)x2A23@4%_O^~e^T(stS z*bsHuXUdcSNtU)0HTBCn97?mb-7f_gDKFJi!d;WqlU>=>*9oo&>-a7Y!&7 z^V_34L9zy3WB20NdhCiCx9{0=Qtooq7kmD|9$~Sb<(NWAay;Lp$ZmLjzgSNbnr7gA zS>N7rvG)YB{xoLx!#EQv#tsg-a`EUyZ?&RNB6wX6I9kx3UdKfc!T!my@3pyU(eLRP z0#yRl@O{>C!eV*Za(Vs2uiEau3)JX}ru=MaE-xCVKvj@^itulu#VSm2g4S9U-A|?8 zHA$Ib-xu93+9-CyG&Zt>A1X$q zIQPUG$>s2|ACi9@$Wu+f?=9=uW^#yiA8|Nzp>X=d!kTSey&oZa(F|=a+2@Aj(JUqK zea51%*}_LBU`p)1v680|(SA3dWHGimQDoY%g+JL2_Jo1uujnkrpjcP9;Oz76SkM|u z!xJ%k|6m-$vMIea6mvVMw-0bpkk$NUI_T@caGE~Ydy=-cSP%ch`{@4AGNM^OwpJ3z zH_orOver6u7z>*tydR~(+Sv!Ua$8(!N~2ZV?Ak^{Vu15QJ388_5y{nll13hHe<;g< z-9XX&TL%SX^NWtHuTpVL*-q4J?CUe@CMLS90kYYzSDUb?{ax9#*gudy0OMM(|iGnv{G?~DbCI>3qG-F94cD8T0iFw=q`X3v=Zn@FK-l7Y7Hj;v^S*`=(dT?i3;b?Qv+G0tr^t1cIVO*BoXw~*Z zpu_p0jkjpT`*u(CIQhH|DBL@&wKC#s>N7V9xQqdV#8XhF?9X0*wPmJVwlswN8P zpU%+DHLG>%{L`CIC@v@dZuOC(L_n;<}rVo`EE%H`M~ier?g#>VB%nw ze5QosATv7ieE8bgWV)(0++Ql9j|~*o&LAEfv_pjtJ=qXL+-iA}1!ioSI11o`-ATLb z)VXgy-Xc7NcR`-(i#}T|=Zoj%Bu{mmoeDJPbXCeMzXY?0I#h80x%zot;+zL?B$fl- zFA;Bz>7}_oT(W$1qYe30u+%P3s(pu)S)XV@Fio^JgqVuMutO z;1V>LD7%T`T@i}C0GA!crcHp@2K|u2KS1T!DK8NmsP50_@V-B&cbBlfaw!fDRkx$8 z<@TuCssd4PdoYZ>`?)%in9r`@e8Lsmx+SpzLBmUmW zJDu&eF$za!yF_2V&Q6kB71CR7OSx?cK=pu(q!9i7i~YDgF$yxSKW|5V)zm2O-e|da zL0?epPmKPWIVF*oyw#oJB97kv+J?j$Nm;**V&XHRP<6>PFt(x0yE1I24F2$z?sJ@< zQ(u4_#H*vA-YX!~Bt6|aC+bk(k5OC3tEMb&jU_7^<`=LQrK-I{SxR!YS31};Lv#u) 
zo?`p_l&9J7Q;P*A7-XnvIUWD3~|NdaQ>U}<^DHd`VUC;k;N5*JDWIJ zI=I+7{e#U_fY~DWU*W$VW??5wn2o#|qQ?ei?hJrnun9YxAkbR?H;{uB%*hTxSRNO6 zG_im|AXaV=1PFb^bm6ud00;*w7dID#9YL^)!W^XFmgW{N08S9X3I?(x*js4?Q)_7= zZ1;$JBT(MQ(}?`V0JA`$5LONjASV}qg&hoGNog^8m;(ZNHbr=E>Eij9S=G}S;kvY)nf)W03Ts!%|{Aamy-ji4`Am;G{9h15R{Yq z5lV)FfUI2Xh@k2NI6wehC=|-d%?$=UI`R*2`d{6gV1O=D zSOX|bMPYKz^Uiy7)LdxZ?tLi9tP(XKml!Ou0^ivlly$f4p%A4P~dWMrAn_2_^ zn-b|YF?%w4#mNH_S$~}Hf0yYa7XN?I?!UD7C_+TJ{J&KzZio2F;ZbD&(k6@4KkL*K zhA0_?k`d7RUqzyV$gY~bx}D|YhY7k&0f8XMf3JrS`hSh%1S3xSzvgnW z|JNQMC_>o(a~ueSsG0w510(LUf7@J~V2F#w>G2yDRZA~KWDsSqVsDSQzaArmxVs)p b{GZWsc7Zv${1qz@m=g-ZprsX8l)(5u0MD<@ literal 0 HcmV?d00001 diff --git a/rig-core/tests/data/pages.pdf b/rig-core/tests/data/pages.pdf new file mode 100644 index 0000000000000000000000000000000000000000..8f9174ce343096f7d3567913ebbe5d6fbe9cde65 GIT binary patch literal 12414 zcmd6uWmKC@yS57yDDG}4?jaDoxVyVca4Q7&;$EP*7A;ns7I!G_Eyas#fug0z7y3M} zKJWha{;_}Td);d#T)AfEK4;ErX07CaPE}Hx4a5#crJFg}`-TbvZ~&am?NDF50BYJo z9UuTkEhp&X8%uztvxU1O#0koT3RHD}}0}ur2_OATCaTjI*;ftc|3z zg&RN};$#VNg}5>a3!}o`K6c9WkIpU303cMLxDo)U?Ck1j=J1as?tdgn+d4pDg8-!+ z%%BiSh=sEy1Qn($hIAl?d#ynG=6OKU&=Xxpd8ZX~6k8u>m=T8x<5@^xC% zmbGwM%a}dPsdpZ^*<*rSF7IENezc&+V43$vc>7A?hC==oPJ!$X5%Si4zAiHE3Au&G z3HkHU|C}1${|(g*3jZV3Y9^&4c@eS2%0lB~(2ypgw)Xs!E(*?{nE!?97FYb4XKmY@ z$`+<6k!7-^gce%;a$V4FjKoibGZY1C7Zm(io>psyh}|{T|3USy->Mhn7XPZ6|9?aE z|IM}jL-m&2(qC0`{zJ8+c{3Dp1qR+zkNK zfNcQ3V*t=&`^SO)bTJR?&%f;+|2LIA*4bZ_!FHnmP8r|-d1a5a^cQ8E|Gl!`YxtK% zb9;1yJIp*>s6Yi)xT?BBtYAlj)1O}dqX_=;5h_s9*3HGi%o}!`{Bah+>R;hB1zy~K!JemUICm0ySYTClt7?uIl9?y!K zX$ImmY7TFHEkxp5mf7}&M$t_HcXi$hhu#K(@apQ)N}OJ?c2fOZkb@MK53KO~5-bqY zk_3HF$k5O^sOr@6?je zE?kn<1EUm*AnLLGn2wdS%iOJxKs&k!J1Q)jl@@Wl-b2k|H;SkJFXNC!Jrb(;OvFarF(a3Kn&wmTS zMnsx^GMkV@cV-JRQbq%8y>+vt9JUs}#LbVm30A(^etT-ZI#s-Opz7oCyvB3jtFt>0pevN>3JYsp#Xu$dJYc0;D`LpP{NM=o1!*E5T76uu` zHU>3av*U zjF^YPY`yL&+8R+-iipLyeF`athIP4!7y{4;i2_`+=hPH}Jn1XsKc0QaT#VoTq9}Ye z4X{-}I`72hXcjX6?E9RunoLZKI#%cN`GoM^cj^xYGn;M^2xQuwlHRzXK 
z$w#USDn^UOv^)rvZ0k)N?@3>YMrL7m!!?OpRT`XBfTmLcRs9J+Ken0gv}o zXjmeXfm`Ij=L3~Gn=X#Xa)|Y&NjeGFMUAuv#o~nUw_-o$&Il@Pr!OzwdmETU=$SkP zc7){?9cScVm~G65RR$aNJDCTXde$#H00oC*yy^0K6Id{Y#kahbP`;yNN{?MEkTvoi z#5d_QcFx{w|GKhk@}Y?f5ivwNIu-@ngybE*y>g){>Qx+x9euUf+vY&sihIR|YcYmVn9$jLGeyl}662!HhV!|9A&9Y5akbvfR8*c4&LEKa~xMPfpf zZ{(!XUoNWbzN`FEORrWhQ`j^9bwRN4py8lV#m;0lcQWJr3(2RG0U%b+FyDomU}LpS zXMKBPyV4yoNCZxJi<(y%f1{>)AKM^x{qW@3?Jb4Xt0P=JBo@n^UQUg#4o$Dw!swh$ z?Z48fUs=!d=f75AwY z?-nw5v*j;0-8bTQx_5m3aRH41$pLw930InvbBlMzW8WUWcCAu>|3cxz4{{~=z~oMp zz+8E*H)`a$!(ZwNZtXq~Xf!Ce_g-`GWGdU-)?BdoQHTshFcs<|W{KAS?2>hE(H+B+ zMFk7)rCX70U()*+mvhzHU%$+W(m#DRF_;cyK>)*TxeMfzTl$P4oOjS?^3Q|r(>V!^ zPufaU4Wx{>jD4i_KD{ro_U4aA79rE_PLQc*z$RkW;I~p6`9`rz5Dgk{)Pu*`i0o6TH>E(UJ7RmOwo9C2EMaE&2UuH@hGC$XE9qCYyGg$2S^^JeQ-0^Mb zoeOBo=p1A?5u7kK(G^VL-&0tc0BA&V>C|r?@tgs_&`U5)H8BfK z9L5~Br5ZDs{Wc6gdIgBnpwd`o%_N;|28TCI(fF{aOu)l)oqkoeHCw;>!k4Q9}Aq&FXplA7b%}*u11PdZD@39KECe# zkRQg6nqm2|(9Z@j=>B?8wEA|CQON#B^42JB&|<*k?W#b6o8)hA#2t0e4` z4sIP%(X*6cGx3(`*T$F|8FW?>yuSLJp5SWKMH_v`E^sezp_9CpY*_`NMc-2nC%sQE z`*C*mSOq?$X)i;q9FwLEzh`gZeBIv8dhzzxpjM9F0fDYjobY67DQ2}JW&B&6?nnZ> z7sPSE>i4L%Tg1~eR@7Dm87A%3hHqxaFGCCb=sLo;zc%V-_HARwbuKMtrV}q!QWo^N zs<=T?CdioOe`tM4Y1TI{^S<38!+Vt*T5o>?X__~u4tK286(mmnsaVocq3o(|=t8j3 z$~qD%-Qo5P`}qVE3taZYOj+XTEY{oYJqbW~Mdn*A8zgN@ZN~D3rmsAjZ(=T=JC={_ z*xtvl!F5%Wcufoa*XO1bncnbhrzSKhtM?0m9+i{D2IFo~0cF z(=$NuPa9ST105kS+V+)CKl~~=<}}z+Txp}6+8awr!LHa;J71vBUTQ3UJo95zK}lJC zNriV6R_+r0s_42nQs9?CyO?9G6%KIUdSG#ju(}DTTLeIe1Z zFYr;ddYYLo%G2eATwghL|@jV{H}#r=S^ z*+r-gUGmJBz$`ijMs^OBhxR?|h2j`0{yIEyqXq^$o%~xRcv||0{D|oSv8qU*6t$A4$`!f@qmn8oZZZbE2AcxKpr4t^$|3?re$lpX!1mBu0{)iELy3%tf+rbJJuehjsTBd$Fb=fo4*u9 zi^aOFp{b~a4$)mJOGr{oZ&<2~5uYUClPY2_twFQcFoPo|F87dr9P8_!duGygqmz0A zHa_Vx!wxwRa!97Eo@#28LSOpya5-1IjMLD5z>naFD?(CxFtLwU$m|}p^YYc#Ds=|A zd?x1d2qd5O`!6>KC9jN=MBSk0f$g16%-P>3?7){rq6yl&EiGlHuW2W(INj4q-%(ag zy1Q#Umrv>&}mGlpn~DkApT`eDa?jH(gx4E{$jFgwveFZy%TdfZPb0Xhp!6i zxlP#niY!zsIrH?Q`FToG-Kx(f2f&i7Hrn9!jSXM*d>Iy~QG3gf0`Bfh0bd)_<0(`l 
zouy8~K3UvdCE4rht1pi7iT9FOnrRgbi0dt7F59r-bma<^!Utz05F&a{?eE+*x}VFL z2If^qAVoz*b$-rkeeC^fK;R#vk1Q$~Cv|o2?FoTpm6h)s@6{hi;{$WmM#l@c6BmvhsNKg{LyFw_I}<2?2o+n0Wee3Z`Khp$l1M>U*tI$I#>DlA@XC*gDe& z2<^*RDd1wy0}fA)%4?p5jk>D-JAAC%J3U;XuAliB`A5;$j;7^Fv94q;B=1lf;2J0z zh(cw!RA+re&%_QyBP8Rd4PG|-EH$lz>EImkTvBV8F$FNy#Za4r_3}aaW%+aY?tq*( zcywoCBch9943cctNvilMqb@#k#wr+2lx6k?b9$Z&)uCod6=l-l>3{f1iuFG z120>SqQldB2&B0_Ree|0t>C9`<;$$N*vBZi zh$c_hZxQn2IumT$u^N6LtYMxYs2=5fZo54kOG2@9=lDL0IgmKpJ-<=kdwcYmYUZ0- zKqdHvUBz?+DwE9KOs2pjg?4roV*QfY{gAXumw({tla2&&0H)HV*{1&{TpZ%uE5ab? z+GWV1ic&sz{x7lT*lp-G2;SH1lVtgK;3xP`;}A)%SG3Xbwte;Q=)SLe-O~pWyj&jh zJnI}j$GDGHG}!@;oO2fq$$FmZ=qE30I!Fq=jl(FzuqE7YyakeJ96B}fg*!(aEp@S~ z;@F_uFxzm;4t_WjuHstqfYzmWd^_LXt~w7>q$)m^j&;vanVxj0R08@ok)!dq?8>=y0Lyf*gzL7WylP92Dy^N6%z)FAhQB z=^R|nH!HWjS60!?@#`Oc%*~Ra)t1m0E#Fu;d!{eX?XT{$$6EvzSYourZzbmyQ&Wh? zaf(D5A|-Ao#fU_KTj#h@#wx4s6|V@baH@Ek1ECZz1HYlV)yNnbEFb&ykbd63CzH=z z+FaSu-wR-4aS$Ei2~2~l%Sd8LvNNdy;((J`cn zX!^?=cgf2%QzjYqnS%3?!o1BfhD6P1)aoIUSxVpGRG(v;apgO-C}8d@q*5r4qSc3+ zmfS1*rv$_X)Ce2om0j&G->lrw`d4;4)sV0vB|j%URu>2v=sFF+C)374MEJq$mvb8u zq)ZcNliC)Sd|Re}e><8HJmZbOU#xi&jY3yR-C!Z0rFa&qtb7utcm8qcc~y02*dc+Y znjS7G?z?9PbEj|HM5Sa!E_=!e2b`*;Z6m9@7F{7)>172r-bvR1PFx_u`rB06#7m<9 z1$S=#jWm4zn8Ps8&)yNApao3y?2c^f({>3@ZN+z}LQ|~*Z`g<%J>8{t6V}n%#tz9m z3fdgFe(BTBX=SB|Un>3dQhSe|s->e7`jwVRhVqhoa}8voz>h#VDjxqLd^2)sA<~d} zL4sY1_I0nXY55z4xacHTQusg_S?^L5d8e+?N$l;{QvNNrlN2yy*+a^|GCZ-rsvuw$t3bJiPe@e%4A5O;-qdzgI zzDW`dawOLl@{hj$QU6)T7^@s(xvi+1#4dCDv_Y}@`?dD0iJ+rTi0r|>#9YTs-6Yf` zqZBg#hO)dga^I|@jxZaIr9!{t>A{G81Y$(Wul?b2-gx4yEYHunEUg2>6pIqp7E;#D z=9M_!LYg~K`#9#6FF@M4#P*Udq2w_|2?!Q~8*)aqOh=r2I^mDkkk<6WZA8n+MkPHw8XpwQyMAH0@l ztht^ZP#Dq^u3-IP+ZR^QpN1MD!!db%%s!hE@yZ<%7xQu-3!?b$89`i@;atI&!UkG%fC!&mej?lfrh_5HKDRiQ~N z%>3eFkU^a&1D90Xgopt<#hOG!T?=a}3Oi(Y{GsG@HwLzjhLD4r{`Gk9G_DeGM;%sAz-&!AGMwcLT z*V!ZwX>umgN*lyVpWqf^SaEqp$xo6e3(C~TR&RTffmLjE@%I`Gqoz6WtXm6E2GFE^ z{hJXYqdvfgF7jl>sXRS{a17WX?r>%V5!UOa=02m>t!jx{H(&|2;^LSp!pN*@)}Km~ 
zIxmJ$I;xH|q71NA2C*?%n>HNDj`fIkdDQh{e!`Cx&=QrBiuWT-7IAq}@iy~GjHZu2 zNFXO*a;@8&|8}Qo-IMtN%4c7@J={~VG!Z+`O`uPRZF^oL>ny-|o|(q}zq99cY6$ z^Cd!!nplrIHYmxN7qvu+p<^6fx_^Fj+_)WzH}=jZCMt1Mr<*Yg!_c*oheYTpg*bp; zZ_wB8 z$G3vI8ST-J$#+yIZI1gH=m(=Yr#x9&;uV8nGcu9i;HOfE2%mV`XX9G4wq~|q)>5@_ z)KYb$gHhdlzLQz5l#p^9pmsvy?SZ3ETuN+CSaiL?5CBrp}=Ko5#jpa;^*Xz>uQkca`RD3HgHme|v zu9zl#&bjEK`xd(ar9y^RU?unr=^Lp7OK0i}OACotCF#~N*P6n4zzx!lu+lZZ`j#y# z2EU^1MGKcV@QWXMVl_2g%JX;U%Q9&|9^Y7kvRLmBQY##R&k-hYs>dbLpOl+%#|Thj zm%TI&qvL@CXU_zB6?0B-4eR%y**Nae?^bY8~lz20_uyJtQ@uS7}>$b5-+r{`$ zVnX}%+MCVNJa+cF}D@H)M3fLh{F`VleV!#qoq&f(9dH2ANt`qr=JGB?Jyw#mja zf25!Sk?6KmQh*<`E0zTG+%~7|V562`-48`+>30KD)onw6i z_ylhrY~;;ucUX+qVPF~OxCK-$&l7jNysXr!bzA}F>9yF8_jA2BmOY#T zYiw2#**@uERyklJ-=xpF=|b-|F!ZFN$@9Y*y{(mq`NEKC>>9>Y&j)m#LNx~JlBF3j!lOT;OyhaphJB^6Xe*`Z%5z9QHR7P?zAMdFL-I}D)!1CLd_$xr zqAz`u8EU4HZ7q5FN~7!5R8tstosFg?U0-~^z!RF6LJR|0N3z>5YnKk<;DCy`+gIMl z8+#w(GJOp$bdIVUt4(#BTC2BjgLxaOXPs*24|vBG7XtQYHBb#29fmJnTAB5p%@&fT zmAt-Dr!=7&jpdYq(5B;YO^=a~{M_z-=`AP0H9qzZ^u=*73S-*QEm8qN@I3T}5-8*AME#o^0ZX)n1r-n==Cicq#@lCafQ& zq{(9Xi}`wPqw>FIFbTM<1Ouu{+2_h$un8;{6VDttxVt)THH*GL*)n@xc~70oxCV8d z(r9P$`t+$55f2Xw-=zY9jK%nr5XC%rfj*%d`xSC9Py-Mnf zZp)!FrPT9{$~Q+XLIU4mKZ1N@*M75WsW>&nBM?(lJWgAZ?-tFpt0*3rb8ej@MXbqy z#oiJu(UPs3D{V;&H;=6)DK%E>EXY!vfGm8cDVa{=HJ^)09`(Y%b%9-yGH*O6Bm|wk z2^)n?ac&P=4rwb|Ch;TkR*6)P(sQD312|MfL}bS&KM8VHUQtlPvfo#$3G!@xoTxqnkL`HXR<2%YarWjOL(dvzY)?a)?XH3fw| zh@qrN+GP4Tdp~SEj92IB!T3SvdEdmy1WAwZ;X&O6MGAM_We;e_klsTB(L;mNLxcRn z>I#v<$+iD9NBtWZUtgw(YiX|$fK0R zqOW3bjn}uv*z>T-HBD{UH9*{2n5JM6s+0XX;9wU$EPh;d_t`q@MR2U)`U!m_IL49o z^B2~ZX!(&(7#SJ-Q*lcJnt|_S^YD~+tiF5I;*>y&0346OjY()Z*-7a2|+gV@6hmi<7_`*;k3NaGp zZU&-U9M+e0Vp54-B{bRWtk@Bme!e?Vy4O?G+t|vIAfDl&l62j69pRqzFF+%?W*dcV zovcn9k#Td8YqI(plX&S5?0Fsgvv9)WV9V*@3QFr_a;!!WuSR;*h2z7!JDDu&-c%%y z!MI|K1XLA9TdGn&6cyPL^h?*8&o;9IR-_sh6TGQ|`vpD>=i}fT3oiek1KqVdBoTwk z(>f$%u_97|MAm!9BZ@5N1SL6nP4s8?ig%{RYv%={-v+*|{5Y3)hTO2;L66VR7e-4) z5o~4JO@79g_VcZ~zcgE!{`_T5VL}qS-rMHM+%&vh0qom|Y7w&l_sJ6X=E|upf5Vd9 
z>yO#9*<1;4R-u%ffzL{Y>LKMb&0o45;`0kXPi;8d)K-?=le}ly6Trlo>0=h4o)Fn{6+t8`nqX_H)7;-k&0T-IJH9D@by~I% zveBv{{m}SHV4z$V(yWymST${38P#K5>EE}jokskXS)z>*nf@s!k+0jJ07CFo>9WWhV$B}ej^rV(=H?La zK7sGelG`bgg~#4{g$^H^gWT>NURs;`ND~WXRp7%tQAgXasfud-hjRjjFFJJb3S+q- zejt>K<@$tuKzn)aeB_YolGR^*$jg54lGmf7PU99({JN*FSnH&zizNs{IsPm;Z3-kAQ4qm?_gb%e5z}q z;zX?ZH1(jpO=K%sbk$6sB+fEmOexj4g3ZC~8n~S9MwJ zI9#g6a~D2>ZxE-ah`}^v#9r(vPbQLe#frPm%Fy|u!z%`D8~Q+Cc)nC=_`#{>eI2mz z`qS+8jsp1?T=7$qq=?zHr|_g0l?9-Ek+RGIcm{-osW33+Eq>;c%D3`sQn;nlk{49- z-;;yfp>QswCkmO$b6q505?vJ{cwh^_0}UTBcIt+Kz0tf_s-Y7-K7K z61wIWg=Q0e!_QXavqqhX;_&8UI_d1hu6ib-p?H088k|URF&Np4e+}^3O0xg;Z4^hVOZDQ-LEs8>lywmDyN>F|mL?bvuIh}GU z4K5?Zwv6}EL(FS~P&6Ziwv<#jotefi?k8n&o`TqD8tDv5{cUX{9ca zDb(2!jGxBTtcG@BZzPnay?G32c)=^heRbB24q+TLbaKcY+UsD!|JfSLB48oeB z3Qm|j+DATTTz=c3k}#Fib#czHwt<_M-x{Pcz1&s0OSv|KoA=f?m~3~V3ws&a0gf_X zTOWBz_6lKfSL{N_CzVqDK{<<7SOyEB_xza)mk{a_^-Gp2WT(QxInOHAy?lZT4lk z^R!x&$)rISaL>`5_by}w?p|XZw19jYeOci=`O=5TOMLyyx#t@1Ips;<6x`Ye+|EL* z*)Myv{?Bf6h2ig?EJ=w>i>;|ZpIvkLqu)xcGu=q~$1UtVQ!Q2reu&z8IJUU?(s`DC zpRh_k(|OGcx{GlF-R*Wp?d5sK?B(5srw0_VlmG=-N&j|Pe&v=%qdNoKI_FfDF;AM(s^Iv)DYVfU($;zN7hOkgMIZZ&7- z0q^nljo%BryBj=zz#x^R|L|9C5A;v`>L1L~BVH#7akFr>b%8p&{(;*mn>oUeKC#~y zi=OqwYc~KG6)5Ir0i)>v{2biuoV;9MSmd#T#}FGQ2*l100&~FdB@HOV zQ5yi}X6NIFeOt~4gR#t9WFfZJHc$XB2o~k!;DSMGaxliq)@&VYmIKk|^9K0M{08S1L4j73A8<~%rpNkI${5`gSkx~#&z;Bf5pGZE? zBQE$aRM=zgA1(f&81Q>aVAPqonH%IW;jh9q+|8kX5MgqXj~R~>2cyy?5CSK0T%J~soREJpG!bn+f0HYWTA%?)%V-;5zR&48J4Pg9R zGgwf=-NnTL@<=BGU}zUAPzT1VaR4|u0o>d?>>zGVE--)_1Tf&?;bG_J=LGQ^0eJWT zys+TkK88((qz-_W`?2+(FeB)HNFsCc{0mKHl&ECi2fz+LzK`a2!issBynKlbKawhr zd6#b+7&Ilv3}zl6|8aaElPQjv5|MG`ZdLT*vUYA5!QV|9x@)GYFOWU9{!~1na5^+c zl=&O4d};beI&fsG`z2qa4@c2^xLx1K@s%o!{eXcis3h%937vC9uYt$6-! 
ztJ5n*>kiQIG3G0Tq%LON5S&=fJQW7<`2-z@jzRPMc#$s;K8Q@zA+{D(SQ>c5558Jo zyL9SbH?#pO@wyZHYLH%G)eJ~$Er~=og#L%h`7g}$cW*qJ0%qX<+C@@MF!=j%(f{rh zHrc;bj-?sQ#4u;U7{%YNRfnymrn8ol?ISM^_`Qn%WNa@Odx#1v10GaZHT-!2c(}Q_ zxB*sxzsDY1JzfB(zsLCaVCDaBF%Sogg#4En2*k$=V@&@g#=!vs|D`_=hyw;q{#zaw z|6gJtE)b0D{I@(FuD^@{;{FdY9*)0`%gYH{jeqG6_9(F8`1cqm-@lIm0&#<3{O8}= zLS4;ZTZil8zY1vB`oLxywr{CBJ3|4#X;9dv@Vka#9seeT-JoW!(BG2};^cs(($Pt& HNTdEgDhN>H literal 0 HcmV?d00001 diff --git a/rig-lancedb/CHANGELOG.md b/rig-lancedb/CHANGELOG.md new file mode 100644 index 00000000..50fd907a --- /dev/null +++ b/rig-lancedb/CHANGELOG.md @@ -0,0 +1,55 @@ +# Changelog + +All notable changes to this project will be documented in this file. + +The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), +and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). + +## [Unreleased] + +## [0.1.0](https://github.com/0xPlaygrounds/rig/releases/tag/rig-lancedb-v0.1.0) - 2024-10-24 + +### Added + +- update examples to use new version of VectorStoreIndex trait +- replace document embeddings with serde json value +- merge all arrow columns into JSON document in deserializer +- finish implementing deserialiser for record batch +- implement deserialization for any recordbatch returned from lanceDB +- add indexes and tables for simple search +- create enum for embedding models +- add vector_search_s3_ann example +- implement ANN search example +- start implementing top_n_from_query for trait VectorStoreIndex +- implement get_document method of VectorStore trait +- implement search by id for VectorStore trait +- implement add_documents on VectorStore trait +- start implementing VectorStore trait for lancedb + +### Fixed + +- update lancedb examples test data +- make PR changes Pt II +- make PR changes pt I +- *(lancedb)* replace VectorStoreIndexDyn with VectorStoreIndex in examples +- mongodb vector search - use num_candidates from search params +- fix bug in deserializing type run 
end +- make PR requested changes +- reduce opanai generated content in ANN examples + +### Other + +- cargo fmt +- lance db examples +- add example docstring +- add doc strings +- update rig core version on lancedb crate, remove implementation of VectorStore trait +- remove print statement +- use constants instead of enum for model names +- remove associated type on VectorStoreIndex trait +- cargo fmt +- conversions from arrow types to primitive types +- Add doc strings to utility methods +- add doc string to mongodb search params struct +- Merge branch 'main' into feat(vector-store)/lancedb +- create wrapper for vec for from/tryfrom traits diff --git a/rig-lancedb/Cargo.toml b/rig-lancedb/Cargo.toml index 031df2d3..1d68943c 100644 --- a/rig-lancedb/Cargo.toml +++ b/rig-lancedb/Cargo.toml @@ -2,10 +2,14 @@ name = "rig-lancedb" version = "0.1.0" edition = "2021" +license = "MIT" +readme = "README.md" +description = "Rig vector store index integration for LanceDB." +repository = "https://github.com/0xPlaygrounds/rig" [dependencies] lancedb = "0.10.0" -rig-core = { path = "../rig-core", version = "0.2.1" } +rig-core = { path = "../rig-core", version = "0.3.0" } arrow-array = "52.2.0" serde_json = "1.0.128" serde = "1.0.210" diff --git a/rig-lancedb/LICENSE b/rig-lancedb/LICENSE new file mode 100644 index 00000000..878b5fbc --- /dev/null +++ b/rig-lancedb/LICENSE @@ -0,0 +1,7 @@ +Copyright (c) 2024, Playgrounds Analytics Inc. 
+ +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
diff --git a/rig-lancedb/README.md b/rig-lancedb/README.md new file mode 100644 index 00000000..d9b647b6 --- /dev/null +++ b/rig-lancedb/README.md @@ -0,0 +1,2 @@ +# Rig-lancedb +Vector store index integration for LanceDB diff --git a/rig-mongodb/CHANGELOG.md b/rig-mongodb/CHANGELOG.md index 062f0dc1..8760c2f6 100644 --- a/rig-mongodb/CHANGELOG.md +++ b/rig-mongodb/CHANGELOG.md @@ -7,6 +7,17 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +## [0.1.3](https://github.com/0xPlaygrounds/rig/compare/rig-mongodb-v0.1.2...rig-mongodb-v0.1.3) - 2024-10-24 + +### Fixed + +- make PR changes pt I +- mongodb vector search - use num_candidates from search params + +### Other + +- Merge branch 'main' into feat(vector-store)/lancedb + ## [0.1.2](https://github.com/0xPlaygrounds/rig/compare/rig-mongodb-v0.1.1...rig-mongodb-v0.1.2) - 2024-10-01 ### Other diff --git a/rig-mongodb/Cargo.toml b/rig-mongodb/Cargo.toml index 6f313838..76943d14 100644 --- a/rig-mongodb/Cargo.toml +++ b/rig-mongodb/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "rig-mongodb" -version = "0.1.2" +version = "0.1.3" edition = "2021" license = "MIT" readme = "README.md" @@ -12,7 +12,7 @@ repository = "https://github.com/0xPlaygrounds/rig" [dependencies] futures = "0.3.30" mongodb = "2.8.2" -rig-core = { path = "../rig-core", version = "0.2.1" } +rig-core = { path = "../rig-core", version = "0.3.0" } serde = { version = "1.0.203", features = ["derive"] } serde_json = "1.0.117" tracing = "0.1.40" diff --git a/rig-mongodb/README.md b/rig-mongodb/README.md index 018d146f..6bce6bcb 100644 --- a/rig-mongodb/README.md +++ b/rig-mongodb/README.md @@ -1,2 +1,34 @@ -# Rig-mongodb -This project implements a Rig vector store based on MongoDB. \ No newline at end of file + + +
\ + + + + Rig logo + + + + + + + MongoDB logo + +
+ +

+ +## Rig-MongoDB +This companion crate implements a Rig vector store based on MongoDB. + +## Usage + +Add the companion crate to your `Cargo.toml`, along with the rig-core crate: + +```toml +[dependencies] +rig-mongodb = "0.1.2" +rig-core = "0.2.1" +``` + +You can also run `cargo add rig-mongodb rig-core` to add the most recent versions of the dependencies to your project. + +See the [examples](./examples) folder for usage examples. diff --git a/rig-mongodb/examples/vector_search_mongodb.rs b/rig-mongodb/examples/vector_search_mongodb.rs index 3d062de3..282628ee 100644 --- a/rig-mongodb/examples/vector_search_mongodb.rs +++ b/rig-mongodb/examples/vector_search_mongodb.rs @@ -1,12 +1,12 @@ use mongodb::{options::ClientOptions, Client as MongoClient, Collection}; -use std::env; - +use rig::vector_store::VectorStore; use rig::{ embeddings::{DocumentEmbeddings, EmbeddingsBuilder}, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, - vector_store::{VectorStore, VectorStoreIndex}, + vector_store::VectorStoreIndex, }; use rig_mongodb::{MongoDbVectorStore, SearchParams}; +use std::env; #[tokio::main] async fn main() -> Result<(), anyhow::Error> { From 6619f7fcfcd5bbad541181272ddd30d695271714 Mon Sep 17 00:00:00 2001 From: Mathieu Date: Fri, 1 Nov 2024 19:09:12 +0100 Subject: [PATCH 14/18] fix(gemini): missing param to be marked as optional in completion res --- rig-core/examples/gemini_agent.rs | 23 ++++- rig-core/src/providers/gemini/completion.rs | 109 +++++++++++--------- rig-core/src/providers/gemini/embedding.rs | 3 +- 3 files changed, 81 insertions(+), 54 deletions(-) diff --git a/rig-core/examples/gemini_agent.rs b/rig-core/examples/gemini_agent.rs index 3e8e66e9..061a2d47 100644 --- a/rig-core/examples/gemini_agent.rs +++ b/rig-core/examples/gemini_agent.rs @@ -3,8 +3,17 @@ use rig::{ providers::gemini::{self, completion::gemini_api_types::GenerationConfig}, }; +use std::panic; +use tracing::debug; + +#[tracing::instrument(ret)] #[tokio::main] async fn main() -> 
Result<(), anyhow::Error> { + tracing_subscriber::fmt() + .with_max_level(tracing::Level::DEBUG) + .with_target(false) + .init(); + // Initialize the Google Gemini client let client = gemini::Client::from_env(); @@ -13,7 +22,6 @@ async fn main() -> Result<(), anyhow::Error> { .agent(gemini::completion::GEMINI_1_5_PRO) .preamble("Be creative and concise. Answer directly and clearly.") .temperature(0.5) - .max_tokens(8192) .additional_params(serde_json::to_value(GenerationConfig { top_k: Some(1), top_p: Some(0.95), @@ -27,8 +35,17 @@ async fn main() -> Result<(), anyhow::Error> { // Prompt the agent and print the response let response = agent .prompt("How much wood would a woodchuck chuck if a woodchuck could chuck wood? Infer an answer.") - .await?; - println!("{}", response); + .await; + + tracing::info!("Response: {:?}", response); + + match response { + Ok(response) => println!("{}", response), + Err(e) => { + tracing::error!("Error: {:?}", e); + return Err(e.into()); + } + } Ok(()) } diff --git a/rig-core/src/providers/gemini/completion.rs b/rig-core/src/providers/gemini/completion.rs index f8bee2ec..aca5a661 100644 --- a/rig-core/src/providers/gemini/completion.rs +++ b/rig-core/src/providers/gemini/completion.rs @@ -18,10 +18,7 @@ use gemini_api_types::{ }; use std::convert::TryFrom; -use crate::{ - completion::{self, CompletionError, CompletionRequest}, - providers::gemini::client::ApiResponse, -}; +use crate::completion::{self, CompletionError, CompletionRequest}; use super::Client; @@ -75,19 +72,14 @@ impl completion::CompletionModel for CompletionModel { generation_config.max_output_tokens = Some(max_tokens); } - /* - serde_json::to_value(GenerationConfig { - top_k: Some(1), - top_p: Some(0.95), - candidate_count: Some(1), - ..Default::default() - })? 
*/ - let request = GenerateContentRequest { contents: full_history .into_iter() .map(|msg| Content { - parts: vec![Part::Text(msg.content)], + parts: vec![Part { + text: Some(msg.content), + ..Default::default() + }], role: match msg.role.as_str() { "system" => Some(Role::Model), "user" => Some(Role::User), @@ -107,12 +99,15 @@ impl completion::CompletionModel for CompletionModel { ), tool_config: None, system_instruction: Some(Content { - parts: vec![Part::Text("system".to_string())], + parts: vec![Part { + text: Some("system".to_string()), + ..Default::default() + }], role: Some(Role::Model), }), }; - tracing::info!("Request: {:?}", request); + tracing::debug!("Sending completion request to Gemini API"); let response = self .client @@ -121,13 +116,12 @@ impl completion::CompletionModel for CompletionModel { .send() .await? .error_for_status()? - .json::>() + .json::() .await?; - match response { - ApiResponse::Ok(response) => Ok(response.try_into()?), - ApiResponse::Err(err) => Err(CompletionError::ResponseError(err.message)), - } + tracing::debug!("Received response"); + + completion::CompletionResponse::try_from(response) } } @@ -151,8 +145,13 @@ impl TryFrom for completion::CompletionResponse Ok(completion::CompletionResponse { choice: match content.parts.first().unwrap() { - Part::Text(text) => completion::ModelChoice::Message(text.clone()), - Part::FunctionCall(function_call) => { + Part { + text: Some(text), .. + } => completion::ModelChoice::Message(text.clone()), + Part { + function_call: Some(function_call), + .. 
+ } => { let args_value = serde_json::Value::Object( function_call.args.clone().unwrap_or_default(), ); @@ -200,43 +199,41 @@ impl TryFrom for GenerationConfig { for (key, value) in obj.iter().filter(|(_, v)| !v.is_null()) { match key.as_str() { "temperature" => { - if !value.is_null() { - if let Some(v) = value.as_f64() { - config.temperature = Some(v); - } else { - return Err(unexpected_type_error("temperature")); - } + if let Some(v) = value.as_f64() { + config.temperature = Some(v); + } else { + return Err(unexpected_type_error("temperature")); } } - "max_output_tokens" => { + "maxOutputTokens" => { if let Some(v) = value.as_u64() { config.max_output_tokens = Some(v); } else { return Err(unexpected_type_error("max_output_tokens")); } } - "top_p" => { + "topP" => { if let Some(v) = value.as_f64() { config.top_p = Some(v); } else { return Err(unexpected_type_error("top_p")); } } - "top_k" => { + "topK" => { if let Some(v) = value.as_i64() { config.top_k = Some(v as i32); } else { return Err(unexpected_type_error("top_k")); } } - "candidate_count" => { + "candidateCount" => { if let Some(v) = value.as_i64() { config.candidate_count = Some(v as i32); } else { return Err(unexpected_type_error("candidate_count")); } } - "stop_sequences" => { + "stopSequences" => { if let Some(v) = value.as_array() { config.stop_sequences = Some( v.iter() @@ -247,31 +244,31 @@ impl TryFrom for GenerationConfig { return Err(unexpected_type_error("stop_sequences")); } } - "response_mime_type" => { + "responseMimeType" => { if let Some(v) = value.as_str() { config.response_mime_type = Some(v.to_string()); } else { return Err(unexpected_type_error("response_mime_type")); } } - "response_schema" => { + "responseSchema" => { config.response_schema = Some(value.clone().try_into()?); } - "presence_penalty" => { + "presencePenalty" => { if let Some(v) = value.as_f64() { config.presence_penalty = Some(v); } else { return Err(unexpected_type_error("presence_penalty")); } } - "frequency_penalty" 
=> { + "frequencyPenalty" => { if let Some(v) = value.as_f64() { config.frequency_penalty = Some(v); } else { return Err(unexpected_type_error("frequency_penalty")); } } - "response_logprobs" => { + "responseLogprobs" => { if let Some(v) = value.as_bool() { config.response_logprobs = Some(v); } else { @@ -304,7 +301,6 @@ impl TryFrom for GenerationConfig { } pub mod gemini_api_types { - use std::collections::HashMap; // ================================================================= @@ -334,6 +330,7 @@ pub mod gemini_api_types { pub prompt_feedback: Option, /// Output only. Metadata on the generation requests' token usage. pub usage_metadata: Option, + pub model_version: Option, } /// A response candidate generated from the model. @@ -361,7 +358,6 @@ pub mod gemini_api_types { /// Output only. Index of the candidate in the list of response candidates. pub index: Option, } - #[derive(Debug, Deserialize, Serialize)] pub struct Content { /// Ordered Parts that constitute a single message. Parts may have different MIME types. @@ -372,6 +368,7 @@ pub mod gemini_api_types { } #[derive(Debug, Deserialize, Serialize)] + #[serde(rename_all = "lowercase")] pub enum Role { User, Model, @@ -380,16 +377,23 @@ pub mod gemini_api_types { /// A datatype containing media that is part of a multi-part [Content](Content) message. /// A Part consists of data which has an associated datatype. A Part can only contain one of the accepted types in Part.data. /// A Part must have a fixed IANA MIME type identifying the type and subtype of the media if the inlineData field is filled with raw bytes. 
- #[derive(Debug, Deserialize, Serialize)] + #[derive(Debug, Default, Deserialize, Serialize)] #[serde(rename_all = "camelCase")] - pub enum Part { - Text(String), - InlineData(Blob), - FunctionCall(FunctionCall), - FunctionResponse(FunctionResponse), - FileData(FileData), - ExecutableCode(ExecutableCode), - CodeExecutionResult(CodeExecutionResult), + pub struct Part { + #[serde(skip_serializing_if = "Option::is_none")] + pub text: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub inline_data: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub function_call: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub function_response: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub file_data: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub executable_code: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub code_execution_result: Option, } /// Raw media bytes. @@ -430,6 +434,7 @@ pub mod gemini_api_types { /// URI based data. #[derive(Debug, Deserialize, Serialize)] + #[serde(rename_all = "camelCase")] pub struct FileData { /// Optional. The IANA standard MIME type of the source data. 
pub mime_type: Option, @@ -474,7 +479,7 @@ pub mod gemini_api_types { #[serde(rename_all = "camelCase")] pub struct UsageMetadata { pub prompt_token_count: i32, - pub cached_content_token_count: i32, + pub cached_content_token_count: Option, pub candidates_token_count: i32, pub total_token_count: i32, } @@ -533,11 +538,13 @@ pub mod gemini_api_types { } #[derive(Debug, Deserialize)] + #[serde(rename_all = "camelCase")] pub struct CitationMetadata { pub citation_sources: Vec, } #[derive(Debug, Deserialize)] + #[serde(rename_all = "camelCase")] pub struct CitationSource { pub uri: Option, pub start_index: Option, @@ -546,6 +553,7 @@ pub mod gemini_api_types { } #[derive(Debug, Deserialize)] + #[serde(rename_all = "camelCase")] pub struct LogprobsResult { pub top_candidate: Vec, pub chosen_candidate: Vec, @@ -567,6 +575,7 @@ pub mod gemini_api_types { /// Gemini API Configuration options for model generation and outputs. Not all parameters are /// configurable for every model. https://ai.google.dev/api/generate-content#generationconfig #[derive(Debug, Deserialize, Serialize)] + #[serde(rename_all = "camelCase")] pub struct GenerationConfig { /// The set of character sequences (up to 5) that will stop output generation. If specified, the API will stop /// at the first appearance of a stop_sequence. The stop sequence will not be included as part of the response. 
diff --git a/rig-core/src/providers/gemini/embedding.rs b/rig-core/src/providers/gemini/embedding.rs index 0a866bf0..ae249b69 100644 --- a/rig-core/src/providers/gemini/embedding.rs +++ b/rig-core/src/providers/gemini/embedding.rs @@ -43,8 +43,9 @@ impl embeddings::EmbeddingModel for EmbeddingModel { async fn embed_documents( &self, - documents: Vec, + documents: impl IntoIterator + Send, ) -> Result, EmbeddingError> { + let documents: Vec<_> = documents.into_iter().collect(); let mut request_body = json!({ "model": format!("models/{}", self.model), "content": { From 3fb10d607f41f41549f8a1810d308165e82c8989 Mon Sep 17 00:00:00 2001 From: Mathieu Date: Sun, 3 Nov 2024 11:41:24 +0100 Subject: [PATCH 15/18] fix: docs imports and refs --- rig-core/examples/gemini_agent.rs | 5 +---- rig-core/src/providers/gemini/client.rs | 8 ++++---- rig-core/src/providers/gemini/completion.rs | 10 +++++----- rig-core/src/providers/gemini/embedding.rs | 2 +- 4 files changed, 11 insertions(+), 14 deletions(-) diff --git a/rig-core/examples/gemini_agent.rs b/rig-core/examples/gemini_agent.rs index 061a2d47..882921ab 100644 --- a/rig-core/examples/gemini_agent.rs +++ b/rig-core/examples/gemini_agent.rs @@ -2,12 +2,9 @@ use rig::{ completion::Prompt, providers::gemini::{self, completion::gemini_api_types::GenerationConfig}, }; - -use std::panic; -use tracing::debug; - #[tracing::instrument(ret)] #[tokio::main] + async fn main() -> Result<(), anyhow::Error> { tracing_subscriber::fmt() .with_max_level(tracing::Level::DEBUG) diff --git a/rig-core/src/providers/gemini/client.rs b/rig-core/src/providers/gemini/client.rs index 0055af33..c22e6996 100644 --- a/rig-core/src/providers/gemini/client.rs +++ b/rig-core/src/providers/gemini/client.rs @@ -109,15 +109,15 @@ impl Client { } /// Create a completion model with the given name. - /// Gemini-specific parameters can be set using the [GenerationConfig](crate::providers::gemini::completion::GenerationConfig) struct. 
- /// https://ai.google.dev/api/generate-content#generationconfig + /// Gemini-specific parameters can be set using the [GenerationConfig](crate::providers::gemini::completion::gemini_api_types::GenerationConfig) struct. + /// [Gemini API Reference](https://ai.google.dev/api/generate-content#generationconfig) pub fn completion_model(&self, model: &str) -> CompletionModel { CompletionModel::new(self.clone(), model) } /// Create an agent builder with the given completion model. - /// Gemini-specific parameters can be set using the [GenerationConfig](crate::providers::gemini::completion::GenerationConfig) struct. - /// https://ai.google.dev/api/generate-content#generationconfig + /// Gemini-specific parameters can be set using the [GenerationConfig](crate::providers::gemini::completion::gemini_api_types::GenerationConfig) struct. + /// [Gemini API Reference](https://ai.google.dev/api/generate-content#generationconfig) /// # Example /// ``` /// use rig::providers::gemini::{Client, self}; diff --git a/rig-core/src/providers/gemini/completion.rs b/rig-core/src/providers/gemini/completion.rs index aca5a661..06efb2b0 100644 --- a/rig-core/src/providers/gemini/completion.rs +++ b/rig-core/src/providers/gemini/completion.rs @@ -1,6 +1,6 @@ // ================================================================ //! Google Gemini Completion Integration -//! https://ai.google.dev/api/generate-content +//! From [Gemini API Reference](https://ai.google.dev/api/generate-content) // ================================================================ /// `gemini-1.5-flash` completion model @@ -374,7 +374,7 @@ pub mod gemini_api_types { Model, } - /// A datatype containing media that is part of a multi-part [Content](Content) message. + /// A datatype containing media that is part of a multi-part [Content] message. /// A Part consists of data which has an associated datatype. A Part can only contain one of the accepted types in Part.data. 
/// A Part must have a fixed IANA MIME type identifying the type and subtype of the media if the inlineData field is filled with raw bytes. #[derive(Debug, Default, Deserialize, Serialize)] @@ -573,7 +573,7 @@ pub mod gemini_api_types { } /// Gemini API Configuration options for model generation and outputs. Not all parameters are - /// configurable for every model. https://ai.google.dev/api/generate-content#generationconfig + /// configurable for every model. From [[Gemini API Reference]](https://ai.google.dev/api/generate-content#generationconfig) #[derive(Debug, Deserialize, Serialize)] #[serde(rename_all = "camelCase")] pub struct GenerationConfig { @@ -653,7 +653,7 @@ pub mod gemini_api_types { } /// The Schema object allows the definition of input and output data types. These types can be objects, but also /// primitives and arrays. Represents a select subset of an OpenAPI 3.0 schema object. - /// https://ai.google.dev/api/caching#Schema + /// From [Gemini API Reference](https://ai.google.dev/api/caching#Schema) #[derive(Debug, Deserialize, Serialize)] pub struct Schema { pub r#type: String, @@ -748,7 +748,7 @@ pub mod gemini_api_types { /// to learn how to incorporate safety considerations in your AI applications. pub safety_settings: Option>, /// Optional. Developer set system instruction(s). Currently, text only. - /// https://ai.google.dev/gemini-api/docs/system-instructions?lang=rest + /// From [Gemini API Reference](https://ai.google.dev/gemini-api/docs/system-instructions?lang=rest) pub system_instruction: Option, // cachedContent: Optional } diff --git a/rig-core/src/providers/gemini/embedding.rs b/rig-core/src/providers/gemini/embedding.rs index ae249b69..1249387a 100644 --- a/rig-core/src/providers/gemini/embedding.rs +++ b/rig-core/src/providers/gemini/embedding.rs @@ -1,6 +1,6 @@ // ================================================================ //! Google Gemini Embeddings Integration -//! https://ai.google.dev/api/embeddings +//! 
From [Gemini API Reference](https://ai.google.dev/api/embeddings) // ================================================================ use serde_json::json; From 43530e5cdd471d8d17cf1b1dd8bfcc888283594a Mon Sep 17 00:00:00 2001 From: Mathieu Date: Mon, 4 Nov 2024 20:12:54 +0100 Subject: [PATCH 16/18] fix(gemini): issue when additional param is empty --- rig-core/src/providers/gemini/completion.rs | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/rig-core/src/providers/gemini/completion.rs b/rig-core/src/providers/gemini/completion.rs index 06efb2b0..44449f9d 100644 --- a/rig-core/src/providers/gemini/completion.rs +++ b/rig-core/src/providers/gemini/completion.rs @@ -16,6 +16,7 @@ use gemini_api_types::{ Content, ContentCandidate, FunctionDeclaration, GenerateContentRequest, GenerateContentResponse, GenerationConfig, Part, Role, Tool, }; +use serde_json::{Map, Value}; use std::convert::TryFrom; use crate::completion::{self, CompletionError, CompletionRequest}; @@ -59,8 +60,10 @@ impl completion::CompletionModel for CompletionModel { }); // Handle Gemini specific parameters - let mut generation_config = - GenerationConfig::try_from(completion_request.additional_params.unwrap_or_default())?; + let additional_params = completion_request + .additional_params + .unwrap_or_else(|| Value::Object(Map::new())); + let mut generation_config = GenerationConfig::try_from(additional_params)?; // Set temperature from completion_request or additional_params if let Some(temp) = completion_request.temperature { From 1a444ad8d127db7fdf148b32ec59a26d11d01354 Mon Sep 17 00:00:00 2001 From: Mathieu Date: Mon, 4 Nov 2024 22:39:14 +0100 Subject: [PATCH 17/18] refactor(gemini): remove try_from and use serde deserialization --- rig-core/src/providers/gemini/completion.rs | 130 +------------------- 1 file changed, 1 insertion(+), 129 deletions(-) diff --git a/rig-core/src/providers/gemini/completion.rs b/rig-core/src/providers/gemini/completion.rs index 
44449f9d..c2a58ceb 100644 --- a/rig-core/src/providers/gemini/completion.rs +++ b/rig-core/src/providers/gemini/completion.rs @@ -63,7 +63,7 @@ impl completion::CompletionModel for CompletionModel { let additional_params = completion_request .additional_params .unwrap_or_else(|| Value::Object(Map::new())); - let mut generation_config = GenerationConfig::try_from(additional_params)?; + let mut generation_config = serde_json::from_value::(additional_params)?; // Set temperature from completion_request or additional_params if let Some(temp) = completion_request.temperature { @@ -175,134 +175,6 @@ impl TryFrom for completion::CompletionResponse for GenerationConfig { - type Error = CompletionError; - - fn try_from(value: serde_json::Value) -> Result { - let mut config = GenerationConfig { - temperature: None, - max_output_tokens: None, - stop_sequences: None, - response_mime_type: None, - response_schema: None, - candidate_count: None, - top_p: None, - top_k: None, - presence_penalty: None, - frequency_penalty: None, - response_logprobs: None, - logprobs: None, - }; - - fn unexpected_type_error(field: &str) -> CompletionError { - CompletionError::ResponseError(format!("Unexpected type for field '{}'", field)) - } - - if let Some(obj) = value.as_object() { - for (key, value) in obj.iter().filter(|(_, v)| !v.is_null()) { - match key.as_str() { - "temperature" => { - if let Some(v) = value.as_f64() { - config.temperature = Some(v); - } else { - return Err(unexpected_type_error("temperature")); - } - } - "maxOutputTokens" => { - if let Some(v) = value.as_u64() { - config.max_output_tokens = Some(v); - } else { - return Err(unexpected_type_error("max_output_tokens")); - } - } - "topP" => { - if let Some(v) = value.as_f64() { - config.top_p = Some(v); - } else { - return Err(unexpected_type_error("top_p")); - } - } - "topK" => { - if let Some(v) = value.as_i64() { - config.top_k = Some(v as i32); - } else { - return Err(unexpected_type_error("top_k")); - } - } - 
"candidateCount" => { - if let Some(v) = value.as_i64() { - config.candidate_count = Some(v as i32); - } else { - return Err(unexpected_type_error("candidate_count")); - } - } - "stopSequences" => { - if let Some(v) = value.as_array() { - config.stop_sequences = Some( - v.iter() - .filter_map(|s| s.as_str().map(String::from)) - .collect(), - ); - } else { - return Err(unexpected_type_error("stop_sequences")); - } - } - "responseMimeType" => { - if let Some(v) = value.as_str() { - config.response_mime_type = Some(v.to_string()); - } else { - return Err(unexpected_type_error("response_mime_type")); - } - } - "responseSchema" => { - config.response_schema = Some(value.clone().try_into()?); - } - "presencePenalty" => { - if let Some(v) = value.as_f64() { - config.presence_penalty = Some(v); - } else { - return Err(unexpected_type_error("presence_penalty")); - } - } - "frequencyPenalty" => { - if let Some(v) = value.as_f64() { - config.frequency_penalty = Some(v); - } else { - return Err(unexpected_type_error("frequency_penalty")); - } - } - "responseLogprobs" => { - if let Some(v) = value.as_bool() { - config.response_logprobs = Some(v); - } else { - return Err(unexpected_type_error("response_logprobs")); - } - } - "logprobs" => { - if let Some(v) = value.as_i64() { - config.logprobs = Some(v as i32); - } else { - return Err(unexpected_type_error("logprobs")); - } - } - _ => { - tracing::warn!( - "Unknown GenerationConfig parameter, will be ignored: {}", - key - ); - } - } - } - } else { - return Err(CompletionError::ResponseError( - "Expected a JSON object for GenerationConfig".into(), - )); - } - - Ok(config) - } -} - pub mod gemini_api_types { use std::collections::HashMap; From 7ef411279f614aaa425b36351e58dd1a8c87b175 Mon Sep 17 00:00:00 2001 From: Mathieu Date: Mon, 4 Nov 2024 22:50:23 +0100 Subject: [PATCH 18/18] docs(gemini): add utility config docstring --- rig-core/examples/gemini_agent.rs | 1 + rig-core/src/providers/gemini/completion.rs | 4 +++- 2 files 
changed, 4 insertions(+), 1 deletion(-) diff --git a/rig-core/examples/gemini_agent.rs b/rig-core/examples/gemini_agent.rs index 882921ab..785309a3 100644 --- a/rig-core/examples/gemini_agent.rs +++ b/rig-core/examples/gemini_agent.rs @@ -19,6 +19,7 @@ async fn main() -> Result<(), anyhow::Error> { .agent(gemini::completion::GEMINI_1_5_PRO) .preamble("Be creative and concise. Answer directly and clearly.") .temperature(0.5) + // The `GenerationConfig` utility struct helps construct a typesafe `additional_params` .additional_params(serde_json::to_value(GenerationConfig { top_k: Some(1), top_p: Some(0.95), diff --git a/rig-core/src/providers/gemini/completion.rs b/rig-core/src/providers/gemini/completion.rs index c2a58ceb..e54e03b6 100644 --- a/rig-core/src/providers/gemini/completion.rs +++ b/rig-core/src/providers/gemini/completion.rs @@ -448,7 +448,9 @@ pub mod gemini_api_types { } /// Gemini API Configuration options for model generation and outputs. Not all parameters are - /// configurable for every model. From [[Gemini API Reference]](https://ai.google.dev/api/generate-content#generationconfig) + /// configurable for every model. From [Gemini API Reference](https://ai.google.dev/api/generate-content#generationconfig) + /// ### Rig Note: + /// Can be used to construct a typesafe `additional_params` in rig::[AgentBuilder](crate::agent::AgentBuilder). #[derive(Debug, Deserialize, Serialize)] #[serde(rename_all = "camelCase")] pub struct GenerationConfig {