diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml
index 903faf48..b65da281 100644
--- a/.github/workflows/ci.yaml
+++ b/.github/workflows/ci.yaml
@@ -79,7 +79,7 @@ jobs:
       with:
         command: nextest
         args: run --all-features
-      env: 
+      env:
        OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
        ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
        GEMINI_API_KEY: ${{ secrets.GEMINI_API_KEY }}
@@ -105,4 +105,4 @@ jobs:
     - name: Run cargo doc
       run: cargo doc --no-deps --all-features
       env:
-        RUSTDOCFLAGS: -D warnings
\ No newline at end of file
+        RUSTDOCFLAGS: -D warnings
diff --git a/README.md b/README.md
index 2a9a264e..7a65fa35 100644
--- a/README.md
+++ b/README.md
@@ -77,6 +77,7 @@ Rig supports the following LLM providers natively:
 - Cohere
 - Anthropic
 - Perplexity
+- Google Gemini
 
 Additionally, Rig currently has the following integration sub-libraries:
 - MongoDB vector store: `rig-mongodb`
diff --git a/rig-core/README.md b/rig-core/README.md
index 1f0d9d69..567bb9e4 100644
--- a/rig-core/README.md
+++ b/rig-core/README.md
@@ -47,6 +47,6 @@ or just `full` to enable all features (`cargo add tokio --features macros,rt-mul
 Rig supports the following LLM providers natively:
 - OpenAI
 - Cohere
-
+- Google Gemini
 
 Additionally, Rig currently has the following integration sub-libraries:
 - MongoDB vector store: `rig-mongodb`
diff --git a/rig-core/examples/gemini_agent.rs b/rig-core/examples/gemini_agent.rs
new file mode 100644
index 00000000..785309a3
--- /dev/null
+++ b/rig-core/examples/gemini_agent.rs
@@ -0,0 +1,49 @@
+use rig::{
+    completion::Prompt,
+    providers::gemini::{self, completion::gemini_api_types::GenerationConfig},
+};
+
+#[tracing::instrument(ret)]
+#[tokio::main]
+async fn main() -> Result<(), anyhow::Error> {
+    tracing_subscriber::fmt()
+        .with_max_level(tracing::Level::DEBUG)
+        .with_target(false)
+        .init();
+
+    // Initialize the Google Gemini client
+    let client = gemini::Client::from_env();
+
+    // Create an agent with a single context prompt
+    let agent = client
+        .agent(gemini::completion::GEMINI_1_5_PRO)
+        .preamble("Be creative and concise. Answer directly and clearly.")
+        .temperature(0.5)
+        // The `GenerationConfig` utility struct helps construct a typesafe `additional_params`
+        .additional_params(serde_json::to_value(GenerationConfig {
+            top_k: Some(1),
+            top_p: Some(0.95),
+            candidate_count: Some(1),
+            ..Default::default()
+        })?) // `?` propagates any serialization error
+        .build();
+
+    tracing::info!("Prompting the agent...");
+
+    // Prompt the agent and print the response
+    let response = agent
+        .prompt("How much wood would a woodchuck chuck if a woodchuck could chuck wood? Infer an answer.")
+        .await;
+
+    tracing::info!("Response: {:?}", response);
+
+    match response {
+        Ok(response) => println!("{}", response),
+        Err(e) => {
+            tracing::error!("Error: {:?}", e);
+            return Err(e.into());
+        }
+    }
+
+    Ok(())
+}
diff --git a/rig-core/examples/gemini_embeddings.rs b/rig-core/examples/gemini_embeddings.rs
new file mode 100644
index 00000000..4ce24636
--- /dev/null
+++ b/rig-core/examples/gemini_embeddings.rs
@@ -0,0 +1,20 @@
+use rig::providers::gemini;
+
+#[tokio::main]
+async fn main() -> Result<(), anyhow::Error> {
+    // Initialize the Google Gemini client
+    let client = gemini::Client::from_env();
+
+    let embeddings = client
+        .embeddings(gemini::embedding::EMBEDDING_001)
+        .simple_document("doc0", "Hello, world!")
+        .simple_document("doc1", "Goodbye, world!")
+        .build()
+        .await
+        .expect("Failed to embed documents");
+
+    println!("{:?}", embeddings);
+
+    Ok(())
+}
diff --git a/rig-core/src/embeddings.rs b/rig-core/src/embeddings.rs
index 262d98ca..eaced08b 100644
--- a/rig-core/src/embeddings.rs
+++ b/rig-core/src/embeddings.rs
@@ -28,12 +28,12 @@
 //!
 //!     // Create an embeddings builder and add documents
 //!     let embeddings = EmbeddingsBuilder::new(embedding_model)
-//!         .simple_document("doc1", "This is the first document.") 
+//!         .simple_document("doc1", "This is the first document.")
 //!         .simple_document("doc2", "This is the second document.")
 //!         .build()
 //!         .await
 //!         .expect("Failed to build embeddings.");
-//! 
+//!
 //!     // Use the generated embeddings
 //!     // ...
 //! ```
@@ -102,7 +102,7 @@ pub trait EmbeddingModel: Clone + Sync + Send {
 }
 
 /// Struct that holds a single document and its embedding.
-#[derive(Clone, Default, Deserialize, Serialize)]
+#[derive(Clone, Default, Deserialize, Serialize, Debug)]
 pub struct Embedding {
     /// The document that was embedded
     pub document: String,
@@ -142,7 +142,7 @@ impl Embedding {
 /// large document to be retrieved from a query that matches multiple smaller and
 /// distinct text documents. For example, if the document is a textbook, a summary of
 /// each chapter could serve as the book's embeddings.
-#[derive(Clone, Eq, PartialEq, Serialize, Deserialize)]
+#[derive(Clone, Eq, PartialEq, Serialize, Deserialize, Debug)]
 pub struct DocumentEmbeddings {
     #[serde(rename = "_id")]
     pub id: String,
diff --git a/rig-core/src/providers/gemini/client.rs b/rig-core/src/providers/gemini/client.rs
new file mode 100644
index 00000000..c22e6996
--- /dev/null
+++ b/rig-core/src/providers/gemini/client.rs
@@ -0,0 +1,156 @@
+use crate::{
+    agent::AgentBuilder,
+    embeddings::{self},
+    extractor::ExtractorBuilder,
+};
+use schemars::JsonSchema;
+use serde::{Deserialize, Serialize};
+
+use super::{completion::CompletionModel, embedding::EmbeddingModel};
+
+// ================================================================
+// Google Gemini Client
+// ================================================================
+const GEMINI_API_BASE_URL: &str = "https://generativelanguage.googleapis.com";
+
+#[derive(Clone)]
+pub struct Client {
+    base_url: String,
+    api_key: String,
+    http_client: reqwest::Client,
+}
+
+impl Client {
+    pub fn new(api_key: &str) -> Self {
+        Self::from_url(api_key, GEMINI_API_BASE_URL)
+    }
+
+    fn from_url(api_key: &str, base_url: &str) -> Self {
+        Self {
+            base_url: base_url.to_string(),
+            api_key: api_key.to_string(),
+            http_client: reqwest::Client::builder()
+                .default_headers({
+                    let mut headers = reqwest::header::HeaderMap::new();
+                    headers.insert(
+                        reqwest::header::CONTENT_TYPE,
+                        "application/json".parse().unwrap(),
+                    );
+                    headers
+                })
+                .build()
+                .expect("Gemini reqwest client should build"),
+        }
+    }
+
+    /// Create a new Google Gemini client from the `GEMINI_API_KEY` environment variable.
+    /// Panics if the environment variable is not set.
+    pub fn from_env() -> Self {
+        let api_key = std::env::var("GEMINI_API_KEY").expect("GEMINI_API_KEY not set");
+        Self::new(&api_key)
+    }
+
+    pub fn post(&self, path: &str) -> reqwest::RequestBuilder {
+        let url = format!("{}/{}?key={}", self.base_url, path, self.api_key).replace("//", "/");
+
+        tracing::debug!("POST {}", url);
+        self.http_client.post(url)
+    }
+
+    /// Create an embedding model with the given name.
+    /// Note: a default embedding dimension of 0 will be used if the model is not known.
+    /// In that case, use `embedding_model_with_ndims` instead.
+    ///
+    /// # Example
+    /// ```
+    /// use rig::providers::gemini::{self, Client};
+    ///
+    /// // Initialize the Google Gemini client
+    /// let gemini = Client::new("your-google-gemini-api-key");
+    ///
+    /// let embedding_model = gemini.embedding_model(gemini::embedding::EMBEDDING_001);
+    /// ```
+    pub fn embedding_model(&self, model: &str) -> EmbeddingModel {
+        EmbeddingModel::new(self.clone(), model, None)
+    }
+
+    /// Create an embedding model with the given name and the number of dimensions in the embedding generated by the model.
+    ///
+    /// # Example
+    /// ```
+    /// use rig::providers::gemini::{self, Client};
+    ///
+    /// // Initialize the Google Gemini client
+    /// let gemini = Client::new("your-google-gemini-api-key");
+    ///
+    /// let embedding_model = gemini.embedding_model_with_ndims("model-unknown-to-rig", 1024);
+    /// ```
+    pub fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> EmbeddingModel {
+        EmbeddingModel::new(self.clone(), model, Some(ndims))
+    }
+
+    /// Create an embedding builder with the given embedding model.
+    ///
+    /// # Example
+    /// ```
+    /// use rig::providers::gemini::{self, Client};
+    ///
+    /// // Initialize the Google Gemini client
+    /// let gemini = Client::new("your-google-gemini-api-key");
+    ///
+    /// let embeddings = gemini.embeddings(gemini::embedding::EMBEDDING_001)
+    ///     .simple_document("doc0", "Hello, world!")
+    ///     .simple_document("doc1", "Goodbye, world!")
+    ///     .build()
+    ///     .await
+    ///     .expect("Failed to embed documents");
+    /// ```
+    pub fn embeddings(&self, model: &str) -> embeddings::EmbeddingsBuilder<EmbeddingModel> {
+        embeddings::EmbeddingsBuilder::new(self.embedding_model(model))
+    }
+
+    /// Create a completion model with the given name.
+    /// Gemini-specific parameters can be set using the [GenerationConfig](crate::providers::gemini::completion::gemini_api_types::GenerationConfig) struct.
+    /// [Gemini API Reference](https://ai.google.dev/api/generate-content#generationconfig)
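+    ///
+    /// # Example
+    /// ```
+    /// use rig::providers::gemini::{self, Client};
+    ///
+    /// // Initialize the Google Gemini client
+    /// let gemini = Client::new("your-google-gemini-api-key");
+    ///
+    /// let completion_model = gemini.completion_model(gemini::completion::GEMINI_1_5_PRO);
+    /// ```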
+    pub fn completion_model(&self, model: &str) -> CompletionModel {
+        CompletionModel::new(self.clone(), model)
+    }
+
+    /// Create an agent builder with the given completion model.
+    /// Gemini-specific parameters can be set using the [GenerationConfig](crate::providers::gemini::completion::gemini_api_types::GenerationConfig) struct.
+    /// [Gemini API Reference](https://ai.google.dev/api/generate-content#generationconfig)
+    ///
+    /// # Example
+    /// ```
+    /// use rig::providers::gemini::{self, Client};
+    ///
+    /// // Initialize the Google Gemini client
+    /// let gemini = Client::new("your-google-gemini-api-key");
+    ///
+    /// let agent = gemini.agent(gemini::completion::GEMINI_1_5_PRO)
+    ///     .preamble("You are comedian AI with a mission to make people laugh.")
+    ///     .temperature(0.0)
+    ///     .build();
+    /// ```
+    pub fn agent(&self, model: &str) -> AgentBuilder<CompletionModel> {
+        AgentBuilder::new(self.completion_model(model))
+    }
+
+    /// Create an extractor builder with the given completion model.
+    pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
+        &self,
+        model: &str,
+    ) -> ExtractorBuilder<T, CompletionModel> {
+        ExtractorBuilder::new(self.completion_model(model))
+    }
+}
+
+#[derive(Debug, Deserialize)]
+pub struct ApiErrorResponse {
+    pub message: String,
+}
+
+#[derive(Debug, Deserialize)]
+#[serde(untagged)]
+pub enum ApiResponse<T> {
+    Ok(T),
+    Err(ApiErrorResponse),
+}
diff --git a/rig-core/src/providers/gemini/completion.rs b/rig-core/src/providers/gemini/completion.rs
new file mode 100644
index 00000000..e54e03b6
--- /dev/null
+++ b/rig-core/src/providers/gemini/completion.rs
@@ -0,0 +1,675 @@
+// ================================================================
+//! Google Gemini Completion Integration
+//! From [Gemini API Reference](https://ai.google.dev/api/generate-content)
+// ================================================================
+
+/// `gemini-1.5-flash` completion model
+pub const GEMINI_1_5_FLASH: &str = "gemini-1.5-flash";
+/// `gemini-1.5-pro` completion model
+pub const GEMINI_1_5_PRO: &str = "gemini-1.5-pro";
+/// `gemini-1.5-pro-8b` completion model
+pub const GEMINI_1_5_PRO_8B: &str = "gemini-1.5-pro-8b";
+/// `gemini-1.0-pro` completion model
+pub const GEMINI_1_0_PRO: &str = "gemini-1.0-pro";
+
+use gemini_api_types::{
+    Content, ContentCandidate, FunctionDeclaration, GenerateContentRequest,
+    GenerateContentResponse, GenerationConfig, Part, Role, Tool,
+};
+use serde_json::{Map, Value};
+use std::convert::TryFrom;
+
+use crate::completion::{self, CompletionError, CompletionRequest};
+
+use super::Client;
+
+// =================================================================
+// Rig Implementation Types
+// =================================================================
+
+#[derive(Clone)]
+pub struct CompletionModel {
+    client: Client,
+    pub model: String,
+}
+
+impl CompletionModel {
+    pub fn new(client: Client, model: &str) -> Self {
+        Self {
+            client,
+            model: model.to_string(),
+        }
+    }
+}
+
+impl completion::CompletionModel for CompletionModel {
+    type Response = GenerateContentResponse;
+
+    async fn completion(
+        &self,
+        mut completion_request: CompletionRequest,
+    ) -> Result<completion::CompletionResponse<GenerateContentResponse>, CompletionError> {
+        let mut full_history = Vec::new();
+        full_history.append(&mut completion_request.chat_history);
+
+        let prompt_with_context = completion_request.prompt_with_context();
+
+        full_history.push(completion::Message {
+            role: "user".into(),
+            content: prompt_with_context,
+        });
+
+        // Handle Gemini-specific parameters
+        let additional_params = completion_request
+            .additional_params
+            .unwrap_or_else(|| Value::Object(Map::new()));
+        let mut generation_config =
+            serde_json::from_value::<GenerationConfig>(additional_params)?;
+
+        // Set temperature from completion_request or additional_params
+        if let Some(temp) = completion_request.temperature {
+            generation_config.temperature = Some(temp);
+        }
+
+        // Set max_tokens from completion_request or additional_params
+        if let Some(max_tokens) = completion_request.max_tokens {
+            generation_config.max_output_tokens = Some(max_tokens);
+        }
+
+        let request = GenerateContentRequest {
+            contents: full_history
+                .into_iter()
+                .map(|msg| Content {
+                    parts: vec![Part {
+                        text: Some(msg.content),
+                        ..Default::default()
+                    }],
+                    role: match msg.role.as_str() {
+                        "system" => Some(Role::Model),
+                        "user" => Some(Role::User),
+                        "assistant" => Some(Role::Model),
+                        _ => None,
+                    },
+                })
+                .collect(),
+            generation_config: Some(generation_config),
+            safety_settings: None,
+            tools: Some(
+                completion_request
+                    .tools
+                    .into_iter()
+                    .map(Tool::from)
+                    .collect(),
+            ),
+            tool_config: None,
+            system_instruction: Some(Content {
+                parts: vec![Part {
+                    text: Some("system".to_string()),
+                    ..Default::default()
+                }],
+                role: Some(Role::Model),
+            }),
+        };
+
+        tracing::debug!("Sending completion request to Gemini API");
+
+        let response = self
+            .client
+            .post(&format!("/v1beta/models/{}:generateContent", self.model))
+            .json(&request)
+            .send()
+            .await?
+            .error_for_status()?
+            .json::<GenerateContentResponse>()
+            .await?;
+
+        tracing::debug!("Received response");
+
+        completion::CompletionResponse::try_from(response)
+    }
+}
+
+impl From<completion::ToolDefinition> for Tool {
+    fn from(tool: completion::ToolDefinition) -> Self {
+        Self {
+            function_declaration: FunctionDeclaration {
+                name: tool.name,
+                description: tool.description,
+                parameters: None, // tool.parameters, TODO: Map to Gemini schema
+            },
+            code_execution: None,
+        }
+    }
+}
+
+impl TryFrom<GenerateContentResponse> for completion::CompletionResponse<GenerateContentResponse> {
+    type Error = CompletionError;
+
+    fn try_from(response: GenerateContentResponse) -> Result<Self, Self::Error> {
+        match response.candidates.as_slice() {
+            [ContentCandidate { content, .. }, ..] => Ok(completion::CompletionResponse {
+                choice: match content.parts.first().unwrap() {
+                    Part {
+                        text: Some(text), ..
+                    } => completion::ModelChoice::Message(text.clone()),
+                    Part {
+                        function_call: Some(function_call),
+                        ..
+                    } => {
+                        let args_value = serde_json::Value::Object(
+                            function_call.args.clone().unwrap_or_default(),
+                        );
+                        completion::ModelChoice::ToolCall(function_call.name.clone(), args_value)
+                    }
+                    _ => {
+                        return Err(CompletionError::ResponseError(
+                            "Unsupported response content returned by the model".into(),
+                        ))
+                    }
+                },
+                raw_response: response,
+            }),
+            _ => Err(CompletionError::ResponseError(
+                "No candidates found in response".into(),
+            )),
+        }
+    }
+}
+
+pub mod gemini_api_types {
+    use std::collections::HashMap;
+
+    // =================================================================
+    // Gemini API Types
+    // =================================================================
+    use serde::{Deserialize, Serialize};
+    use serde_json::{Map, Value};
+
+    use crate::{
+        completion::CompletionError,
+        providers::gemini::gemini_api_types::{CodeExecutionResult, ExecutableCode},
+    };
+
+    /// Response from the model supporting multiple candidate responses.
+    /// Safety ratings and content filtering are reported both for the prompt in GenerateContentResponse.prompt_feedback
+    /// and for each candidate in finishReason and in safetyRatings.
+    /// The API:
+    /// - Returns either all requested candidates or none of them
+    /// - Returns no candidates at all only if there was something wrong with the prompt (check promptFeedback)
+    /// - Reports feedback on each candidate in finishReason and safetyRatings.
+    #[derive(Debug, Deserialize)]
+    #[serde(rename_all = "camelCase")]
+    pub struct GenerateContentResponse {
+        /// Candidate responses from the model.
+        pub candidates: Vec<ContentCandidate>,
+        /// Returns the prompt's feedback related to the content filters.
+        pub prompt_feedback: Option<PromptFeedback>,
+        /// Output only. Metadata on the generation requests' token usage.
+        pub usage_metadata: Option<UsageMetadata>,
+        pub model_version: Option<String>,
+    }
+
+    /// A response candidate generated from the model.
+    #[derive(Debug, Deserialize)]
+    #[serde(rename_all = "camelCase")]
+    pub struct ContentCandidate {
+        /// Output only. Generated content returned from the model.
+        pub content: Content,
+        /// Optional. Output only. The reason why the model stopped generating tokens.
+        /// If empty, the model has not stopped generating tokens.
+        pub finish_reason: Option<FinishReason>,
+        /// List of ratings for the safety of a response candidate.
+        /// There is at most one rating per category.
+        pub safety_ratings: Option<Vec<SafetyRating>>,
+        /// Output only. Citation information for the model-generated candidate.
+        /// This field may be populated with recitation information for any text included in the content.
+        /// These are passages that are "recited" from copyrighted material in the foundational LLM's training data.
+        pub citation_metadata: Option<CitationMetadata>,
+        /// Output only. Token count for this candidate.
+        pub token_count: Option<i32>,
+        /// Output only.
+        pub avg_logprobs: Option<f64>,
+        /// Output only. Log-likelihood scores for the response tokens and top tokens.
+        pub logprobs_result: Option<LogprobsResult>,
+        /// Output only. Index of the candidate in the list of response candidates.
+        pub index: Option<i32>,
+    }
+
+    #[derive(Debug, Deserialize, Serialize)]
+    pub struct Content {
+        /// Ordered Parts that constitute a single message. Parts may have different MIME types.
+        pub parts: Vec<Part>,
+        /// The producer of the content. Must be either 'user' or 'model'.
+        /// Useful to set for multi-turn conversations, otherwise can be left blank or unset.
+        pub role: Option<Role>,
+    }
+
+    #[derive(Debug, Deserialize, Serialize)]
+    #[serde(rename_all = "lowercase")]
+    pub enum Role {
+        User,
+        Model,
+    }
+
+    /// A datatype containing media that is part of a multi-part [Content] message.
+    /// A Part consists of data which has an associated datatype. A Part can only contain one of the accepted types in Part.data.
+    /// A Part must have a fixed IANA MIME type identifying the type and subtype of the media if the inlineData field is filled with raw bytes.
+    #[derive(Debug, Default, Deserialize, Serialize)]
+    #[serde(rename_all = "camelCase")]
+    pub struct Part {
+        #[serde(skip_serializing_if = "Option::is_none")]
+        pub text: Option<String>,
+        #[serde(skip_serializing_if = "Option::is_none")]
+        pub inline_data: Option<Blob>,
+        #[serde(skip_serializing_if = "Option::is_none")]
+        pub function_call: Option<FunctionCall>,
+        #[serde(skip_serializing_if = "Option::is_none")]
+        pub function_response: Option<FunctionResponse>,
+        #[serde(skip_serializing_if = "Option::is_none")]
+        pub file_data: Option<FileData>,
+        #[serde(skip_serializing_if = "Option::is_none")]
+        pub executable_code: Option<ExecutableCode>,
+        #[serde(skip_serializing_if = "Option::is_none")]
+        pub code_execution_result: Option<CodeExecutionResult>,
+    }
+
+    /// Raw media bytes.
+    /// Text should not be sent as raw bytes, use the 'text' field.
+    #[derive(Debug, Deserialize, Serialize)]
+    #[serde(rename_all = "camelCase")]
+    pub struct Blob {
+        /// The IANA standard MIME type of the source data. Examples: - image/png - image/jpeg
+        /// If an unsupported MIME type is provided, an error will be returned.
+        pub mime_type: String,
+        /// Raw bytes for media formats. A base64-encoded string.
+        pub data: String,
+    }
+
+    /// A predicted FunctionCall returned from the model that contains a string representing the
+    /// FunctionDeclaration.name with the arguments and their values.
+    #[derive(Debug, Deserialize, Serialize)]
+    pub struct FunctionCall {
+        /// Required. The name of the function to call. Must be a-z, A-Z, 0-9, or contain underscores
+        /// and dashes, with a maximum length of 63.
+        pub name: String,
+        /// Optional. The function parameters and values in JSON object format.
+        pub args: Option<Map<String, Value>>,
+    }
+
+    /// The result output from a FunctionCall that contains a string representing the FunctionDeclaration.name
+    /// and a structured JSON object containing any output from the function is used as context to the model.
+    /// This should contain the result of a FunctionCall made based on model prediction.
+    #[derive(Debug, Deserialize, Serialize)]
+    pub struct FunctionResponse {
+        /// The name of the function to call. Must be a-z, A-Z, 0-9, or contain underscores and dashes,
+        /// with a maximum length of 63.
+        pub name: String,
+        /// The function response in JSON object format.
+        pub response: Option<HashMap<String, Value>>,
+    }
+
+    /// URI based data.
+    #[derive(Debug, Deserialize, Serialize)]
+    #[serde(rename_all = "camelCase")]
+    pub struct FileData {
+        /// Optional. The IANA standard MIME type of the source data.
+        pub mime_type: Option<String>,
+        /// Required. URI.
+        pub file_uri: String,
+    }
+
+    #[derive(Debug, Deserialize, Serialize)]
+    pub struct SafetyRating {
+        pub category: HarmCategory,
+        pub probability: HarmProbability,
+    }
+
+    #[derive(Debug, Deserialize, Serialize)]
+    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
+    pub enum HarmProbability {
+        HarmProbabilityUnspecified,
+        Negligible,
+        Low,
+        Medium,
+        High,
+    }
+
+    #[derive(Debug, Deserialize, Serialize)]
+    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
+    pub enum HarmCategory {
+        HarmCategoryUnspecified,
+        HarmCategoryDerogatory,
+        HarmCategoryToxicity,
+        HarmCategoryViolence,
+        HarmCategorySexual,
+        HarmCategoryMedical,
+        HarmCategoryDangerous,
+        HarmCategoryHarassment,
+        HarmCategoryHateSpeech,
+        HarmCategorySexuallyExplicit,
+        HarmCategoryDangerousContent,
+        HarmCategoryCivicIntegrity,
+    }
+
+    #[derive(Debug, Deserialize)]
+    #[serde(rename_all = "camelCase")]
+    pub struct UsageMetadata {
+        pub prompt_token_count: i32,
+        pub cached_content_token_count: Option<i32>,
+        pub candidates_token_count: i32,
+        pub total_token_count: i32,
+    }
+
+    /// A set of the feedback metadata for the prompt specified in [GenerateContentRequest.contents](GenerateContentRequest).
+    #[derive(Debug, Deserialize)]
+    #[serde(rename_all = "camelCase")]
+    pub struct PromptFeedback {
+        /// Optional. If set, the prompt was blocked and no candidates are returned. Rephrase the prompt.
+        pub block_reason: Option<BlockReason>,
+        /// Ratings for safety of the prompt. There is at most one rating per category.
+        pub safety_ratings: Option<Vec<SafetyRating>>,
+    }
+
+    /// Reason why a prompt was blocked by the model.
+    #[derive(Debug, Deserialize)]
+    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
+    pub enum BlockReason {
+        /// Default value. This value is unused.
+        BlockReasonUnspecified,
+        /// Prompt was blocked due to safety reasons. Inspect safetyRatings to understand which safety category blocked it.
+        Safety,
+        /// Prompt was blocked due to unknown reasons.
+        Other,
+        /// Prompt was blocked due to terms included in the terminology blocklist.
+        Blocklist,
+        /// Prompt was blocked due to prohibited content.
+        ProhibitedContent,
+    }
+
+    #[derive(Debug, Deserialize)]
+    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
+    pub enum FinishReason {
+        /// Default value. This value is unused.
+        FinishReasonUnspecified,
+        /// Natural stop point of the model or provided stop sequence.
+        Stop,
+        /// The maximum number of tokens as specified in the request was reached.
+        MaxTokens,
+        /// The response candidate content was flagged for safety reasons.
+        Safety,
+        /// The response candidate content was flagged for recitation reasons.
+        Recitation,
+        /// The response candidate content was flagged for using an unsupported language.
+        Language,
+        /// Unknown reason.
+        Other,
+        /// Token generation stopped because the content contains forbidden terms.
+        Blocklist,
+        /// Token generation stopped for potentially containing prohibited content.
+        ProhibitedContent,
+        /// Token generation stopped because the content potentially contains Sensitive Personally Identifiable Information (SPII).
+        Spii,
+        /// The function call generated by the model is invalid.
+        MalformedFunctionCall,
+    }
+
+    #[derive(Debug, Deserialize)]
+    #[serde(rename_all = "camelCase")]
+    pub struct CitationMetadata {
+        pub citation_sources: Vec<CitationSource>,
+    }
+
+    #[derive(Debug, Deserialize)]
+    #[serde(rename_all = "camelCase")]
+    pub struct CitationSource {
+        pub uri: Option<String>,
+        pub start_index: Option<i32>,
+        pub end_index: Option<i32>,
+        pub license: Option<String>,
+    }
+
+    #[derive(Debug, Deserialize)]
+    #[serde(rename_all = "camelCase")]
+    pub struct LogprobsResult {
+        pub top_candidate: Vec<TopCandidate>,
+        pub chosen_candidate: Vec<LogProbCandidate>,
+    }
+
+    #[derive(Debug, Deserialize)]
+    pub struct TopCandidate {
+        pub candidates: Vec<LogProbCandidate>,
+    }
+
+    #[derive(Debug, Deserialize)]
+    #[serde(rename_all = "camelCase")]
+    pub struct LogProbCandidate {
+        pub token: String,
+        pub token_id: String,
+        pub log_probability: f64,
+    }
+
+    /// Gemini API configuration options for model generation and outputs. Not all parameters are
+    /// configurable for every model. From [Gemini API Reference](https://ai.google.dev/api/generate-content#generationconfig)
+    /// ### Rig Note:
+    /// Can be used to construct a typesafe `additional_params` in rig::[AgentBuilder](crate::agent::AgentBuilder).
+    #[derive(Debug, Deserialize, Serialize)]
+    #[serde(rename_all = "camelCase")]
+    pub struct GenerationConfig {
+        /// The set of character sequences (up to 5) that will stop output generation. If specified, the API will stop
+        /// at the first appearance of a stop_sequence. The stop sequence will not be included as part of the response.
+        pub stop_sequences: Option<Vec<String>>,
+        /// MIME type of the generated candidate text. Supported MIME types are:
+        /// - text/plain: (default) Text output
+        /// - application/json: JSON response in the response candidates.
+        /// - text/x.enum: ENUM as a string response in the response candidates.
+        /// Refer to the docs for a list of all supported text MIME types.
+        pub response_mime_type: Option<String>,
+        /// Output schema of the generated candidate text. Schemas must be a subset of the OpenAPI schema and can be
+        /// objects, primitives or arrays. If set, a compatible responseMimeType must also be set. Compatible MIME
+        /// types: application/json: Schema for JSON response. Refer to the JSON text generation guide for more details.
+        pub response_schema: Option<Schema>,
+        /// Number of generated responses to return. Currently, this value can only be set to 1. If
+        /// unset, this will default to 1.
+        pub candidate_count: Option<i32>,
+        /// The maximum number of tokens to include in a response candidate. Note: The default value varies by model, see
+        /// the Model.output_token_limit attribute of the Model returned from the getModel function.
+        pub max_output_tokens: Option<u64>,
+        /// Controls the randomness of the output. Note: The default value varies by model, see the Model.temperature
+        /// attribute of the Model returned from the getModel function. Values can range from [0.0, 2.0].
+        pub temperature: Option<f64>,
+        /// The maximum cumulative probability of tokens to consider when sampling. The model uses combined Top-k and
+        /// Top-p (nucleus) sampling. Tokens are sorted based on their assigned probabilities so that only the most
+        /// likely tokens are considered. Top-k sampling directly limits the maximum number of tokens to consider, while
+        /// nucleus sampling limits the number of tokens based on the cumulative probability. Note: The default value
+        /// varies by Model and is specified by the Model.top_p attribute returned from the getModel function. An empty
+        /// topK attribute indicates that the model doesn't apply top-k sampling and doesn't allow setting topK on requests.
+        pub top_p: Option<f64>,
+        /// The maximum number of tokens to consider when sampling. Gemini models use Top-p (nucleus) sampling or a
+        /// combination of Top-k and nucleus sampling. Top-k sampling considers the set of topK most probable tokens.
+        /// Models running with nucleus sampling don't allow a topK setting. Note: The default value varies by Model and is
+        /// specified by the Model.top_k attribute returned from the getModel function. An empty topK attribute indicates
+        /// that the model doesn't apply top-k sampling and doesn't allow setting topK on requests.
+        pub top_k: Option<i32>,
+        /// Presence penalty applied to the next token's logprobs if the token has already been seen in the response.
+        /// This penalty is binary on/off and not dependent on the number of times the token is used (after the first).
+        /// Use frequencyPenalty for a penalty that increases with each use. A positive penalty will discourage the use
+        /// of tokens that have already been used in the response, increasing the vocabulary. A negative penalty will
+        /// encourage the use of tokens that have already been used in the response, decreasing the vocabulary.
+        pub presence_penalty: Option<f64>,
+        /// Frequency penalty applied to the next token's logprobs, multiplied by the number of times each token has been
+        /// seen in the response so far. A positive penalty will discourage the use of tokens that have already been
+        /// used, proportional to the number of times the token has been used: the more a token is used, the more
+        /// difficult it is for the model to use that token again, increasing the vocabulary of responses. Caution: A
+        /// negative penalty will encourage the model to reuse tokens proportional to the number of times the token has
+        /// been used. Small negative values will reduce the vocabulary of a response. Larger negative values will cause
+        /// the model to repeat a common token until it hits the maxOutputTokens limit: "...the the the the the...".
+        pub frequency_penalty: Option<f64>,
+        /// If true, export the logprobs results in the response.
+        pub response_logprobs: Option<bool>,
+        /// Only valid if responseLogprobs=True. This sets the number of top logprobs to return at each decoding step in
+        /// [Candidate.logprobs_result].
+        pub logprobs: Option<i32>,
+    }
+
+    impl Default for GenerationConfig {
+        fn default() -> Self {
+            Self {
+                temperature: Some(1.0),
+                max_output_tokens: Some(4096),
+                stop_sequences: None,
+                response_mime_type: None,
+                response_schema: None,
+                candidate_count: None,
+                top_p: None,
+                top_k: None,
+                presence_penalty: None,
+                frequency_penalty: None,
+                response_logprobs: None,
+                logprobs: None,
+            }
+        }
+    }
+
+    /// The Schema object allows the definition of input and output data types. These types can be objects, but also
+    /// primitives and arrays. Represents a select subset of an OpenAPI 3.0 schema object.
+    /// From [Gemini API Reference](https://ai.google.dev/api/caching#Schema)
+    #[derive(Debug, Deserialize, Serialize)]
+    pub struct Schema {
+        pub r#type: String,
+        pub format: Option<String>,
+        pub description: Option<String>,
+        pub nullable: Option<bool>,
+        pub r#enum: Option<Vec<String>>,
+        pub max_items: Option<i32>,
+        pub min_items: Option<i32>,
+        pub properties: Option<HashMap<String, Schema>>,
+        pub required: Option<Vec<String>>,
+        pub items: Option<Box<Schema>>,
+    }
+
+    impl TryFrom<Value> for Schema {
+        type Error = CompletionError;
+
+        fn try_from(value: Value) -> Result<Self, Self::Error> {
+            if let Some(obj) = value.as_object() {
+                Ok(Schema {
+                    r#type: obj
+                        .get("type")
+                        .and_then(|v| v.as_str())
+                        .unwrap_or_default()
+                        .to_string(),
+                    format: obj.get("format").and_then(|v| v.as_str()).map(String::from),
+                    description: obj
+                        .get("description")
+                        .and_then(|v| v.as_str())
+                        .map(String::from),
+                    nullable: obj.get("nullable").and_then(|v| v.as_bool()),
+                    r#enum: obj.get("enum").and_then(|v| v.as_array()).map(|arr| {
+                        arr.iter()
+                            .filter_map(|v| v.as_str().map(String::from))
+                            .collect()
+                    }),
+                    max_items: obj
+                        .get("maxItems")
+                        .and_then(|v| v.as_i64())
+                        .map(|v| v as i32),
+                    min_items: obj
+                        .get("minItems")
+                        .and_then(|v| v.as_i64())
+                        .map(|v| v as i32),
+                    properties: obj
+                        .get("properties")
+                        .and_then(|v| v.as_object())
+                        .map(|map| {
+                            map.iter()
+                                .filter_map(|(k, v)| {
+                                    v.clone().try_into().ok().map(|schema| (k.clone(), schema))
+                                })
+                                .collect()
+                        }),
+                    required: obj.get("required").and_then(|v| v.as_array()).map(|arr| {
+                        arr.iter()
+                            .filter_map(|v| v.as_str().map(String::from))
+                            .collect()
+                    }),
+                    items: obj
+                        .get("items")
+                        .map(|v| Box::new(v.clone().try_into().unwrap())),
+                })
+            } else {
+                Err(CompletionError::ResponseError(
+                    "Expected a JSON object for Schema".into(),
+                ))
+            }
+        }
+    }
+
+    #[derive(Debug, Serialize)]
+    #[serde(rename_all = "camelCase")]
+    pub struct GenerateContentRequest {
+        pub contents: Vec<Content>,
+        pub tools: Option<Vec<Tool>>,
+        pub tool_config: Option<ToolConfig>,
+        /// Optional. Configuration options for model generation and outputs.
+        pub generation_config: Option<GenerationConfig>,
+        /// Optional. A list of unique SafetySetting instances for blocking unsafe content. This will be enforced on the
+        /// [GenerateContentRequest.contents] and [GenerateContentResponse.candidates]. There should not be more than one
+        /// setting for each SafetyCategory type. The API will block any contents and responses that fail to meet the
+        /// thresholds set by these settings. This list overrides the default settings for each SafetyCategory specified
+        /// in the safetySettings. If there is no SafetySetting for a given SafetyCategory provided in the list, the API
+        /// will use the default safety setting for that category. Harm categories:
+        /// - HARM_CATEGORY_HATE_SPEECH
+        /// - HARM_CATEGORY_SEXUALLY_EXPLICIT
+        /// - HARM_CATEGORY_DANGEROUS_CONTENT
+        /// - HARM_CATEGORY_HARASSMENT
+        /// are supported.
+        /// Refer to the guide for detailed information on available safety settings. Also refer to the Safety guidance
+        /// to learn how to incorporate safety considerations in your AI applications.
+        pub safety_settings: Option<Vec<SafetySetting>>,
+        /// Optional. Developer set system instruction(s). Currently, text only.
+        /// From [Gemini API Reference](https://ai.google.dev/gemini-api/docs/system-instructions?lang=rest)
+        pub system_instruction: Option<Content>,
+        // cachedContent: Optional
+    }
+
+    #[derive(Debug, Serialize)]
+    #[serde(rename_all = "camelCase")]
+    pub struct Tool {
+        pub function_declaration: FunctionDeclaration,
+        pub code_execution: Option<CodeExecution>,
+    }
+
+    #[derive(Debug, Serialize)]
+    #[serde(rename_all = "camelCase")]
+    pub struct FunctionDeclaration {
+        pub name: String,
+        pub description: String,
+        pub parameters: Option<Vec<Schema>>,
+    }
+
+    #[derive(Debug, Serialize)]
+    #[serde(rename_all = "camelCase")]
+    pub struct ToolConfig {
+        pub schema: Option<Schema>,
+    }
+
+    #[derive(Debug, Serialize)]
+    #[serde(rename_all = "camelCase")]
+    pub struct CodeExecution {}
+
+    #[derive(Debug, Serialize)]
+    #[serde(rename_all = "camelCase")]
+    pub struct SafetySetting {
+        pub category: HarmCategory,
+        pub threshold: HarmBlockThreshold,
+    }
+
+    #[derive(Debug, Serialize)]
+    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
+    pub enum HarmBlockThreshold {
+        HarmBlockThresholdUnspecified,
+        BlockLowAndAbove,
+        BlockMediumAndAbove,
+        BlockOnlyHigh,
+        BlockNone,
+        Off,
+    }
+}
diff --git a/rig-core/src/providers/gemini/embedding.rs b/rig-core/src/providers/gemini/embedding.rs
new file mode 100644
index 00000000..1249387a
--- /dev/null
+++ b/rig-core/src/providers/gemini/embedding.rs
@@ -0,0 +1,205 @@
+// ================================================================
+//! Google Gemini Embeddings Integration
+//! From [Gemini API Reference](https://ai.google.dev/api/embeddings)
+// ================================================================
+
+use serde_json::json;
+
+use crate::embeddings::{self, EmbeddingError};
+
+use super::{client::ApiResponse, Client};
+
+/// `embedding-001` embedding model
+pub const EMBEDDING_001: &str = "embedding-001";
+/// `text-embedding-004` embedding model
+pub const EMBEDDING_004: &str = "text-embedding-004";
+
+#[derive(Clone)]
+pub struct EmbeddingModel {
+    client: Client,
+    model: String,
+    ndims: Option<usize>,
+}
+
+impl EmbeddingModel {
+    pub fn new(client: Client, model: &str, ndims: Option<usize>) -> Self {
+        Self {
+            client,
+            model: model.to_string(),
+            ndims,
+        }
+    }
+}
+
+impl embeddings::EmbeddingModel for EmbeddingModel {
+    const MAX_DOCUMENTS: usize = 1024;
+
+    fn ndims(&self) -> usize {
+        match self.model.as_str() {
+            EMBEDDING_001 => 768,
+            EMBEDDING_004 => 768,
+            _ => 0, // Default to 0 for unknown models
+        }
+    }
+
+    async fn embed_documents(
+        &self,
+        documents: impl IntoIterator<Item = String> + Send,
+    ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
+        let documents: Vec<_> = documents.into_iter().collect();
+        let mut request_body = json!({
+            "model": format!("models/{}", self.model),
+            "content": {
+                "parts": documents.iter().map(|doc| json!({ "text": doc })).collect::<Vec<_>>(),
+            },
+        });
+
+        if let Some(ndims) = self.ndims {
+            request_body["output_dimensionality"] = json!(ndims);
+        }
+
+        let response = self
+            .client
+            .post(&format!("/v1beta/models/{}:embedContent", self.model))
+            .json(&request_body)
+            .send()
+            .await?
+            .error_for_status()?
+            .json::<ApiResponse<gemini_api_types::EmbeddingResponse>>()
+            .await?;
+
+        match response {
+            ApiResponse::Ok(response) => {
+                let chunk_size = self.ndims.unwrap_or_else(|| self.ndims());
+                Ok(documents
+                    .into_iter()
+                    .zip(response.embedding.values.chunks(chunk_size))
+                    .map(|(document, embedding)| embeddings::Embedding {
+                        document,
+                        vec: embedding.to_vec(),
+                    })
+                    .collect())
+            }
+            ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
+        }
+    }
+}
+
+// =================================================================
+// Gemini API Types
+// =================================================================
+/// Rust implementation of the Gemini types from [Gemini API Reference](https://ai.google.dev/api/embeddings)
+#[allow(dead_code)]
+mod gemini_api_types {
+    use serde::{Deserialize, Serialize};
+    use serde_json::Value;
+
+    use crate::providers::gemini::gemini_api_types::{CodeExecutionResult, ExecutableCode};
+
+    #[derive(Serialize)]
+    #[serde(rename_all = "camelCase")]
+    pub struct EmbedContentRequest {
+        model: String,
+        content: EmbeddingContent,
+        task_type: TaskType,
+        title: String,
+        output_dimensionality: i32,
+    }
+
+    #[derive(Serialize)]
+    pub struct EmbeddingContent {
+        parts: Vec<EmbeddingContentPart>,
+        /// Optional. The producer of the content. Must be either 'user' or 'model'. Useful to set for multi-turn
+        /// conversations, otherwise can be left blank or unset.
+        role: Option<String>,
+    }
+
+    /// A datatype containing media that is part of a multi-part Content message.
+    /// - A Part consists of data which has an associated datatype. A Part can only contain one of the accepted types in Part.data.
+    /// - A Part must have a fixed IANA MIME type identifying the type and subtype of the media if the inlineData field is filled with raw bytes.
+    #[derive(Serialize)]
+    pub struct EmbeddingContentPart {
+        /// Inline text.
+        text: String,
+        /// Inline media bytes.
+        inline_data: Option<Blob>,
+        /// A predicted FunctionCall returned from the model that contains a string representing the [FunctionDeclaration.name]
+        /// with the arguments and their values.
+        function_call: Option<FunctionCall>,
+        /// The result output of a FunctionCall that contains a string representing the [FunctionDeclaration.name] and a structured
+        /// JSON object containing any output from the function is used as context to the model.
+        function_response: Option<FunctionResponse>,
+        /// URI based data.
+        file_data: Option<FileData>,
+        /// Code generated by the model that is meant to be executed.
+        executable_code: Option<ExecutableCode>,
+        /// Result of executing the ExecutableCode.
+        code_execution_result: Option<CodeExecutionResult>,
+    }
+
+    /// Raw media bytes.
+    /// Text should not be sent as raw bytes, use the 'text' field.
+    #[derive(Serialize)]
+    pub struct Blob {
+        /// Raw bytes for media formats. A base64-encoded string.
+        data: String,
+        /// The IANA standard MIME type of the source data. Examples: - image/png - image/jpeg. If an unsupported MIME type is
+        /// provided, an error will be returned. For a complete list of supported types, see Supported file formats.
+        mime_type: String,
+    }
+
+    #[derive(Serialize)]
+    pub struct FunctionCall {
+        /// The name of the function to call. Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum length of 63.
+        name: String,
+        /// The function parameters and values in JSON object format.
+        args: Option<Value>,
+    }
+
+    #[derive(Serialize)]
+    pub struct FunctionResponse {
+        /// The name of the function to call. Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum length of 63.
+        name: String,
+        /// The result of the function call in JSON object format.
+        result: Value,
+    }
+
+    #[derive(Serialize)]
+    #[serde(rename_all = "camelCase")]
+    pub struct FileData {
+        /// The URI of the file.
+        file_uri: String,
+        /// The IANA standard MIME type of the source data.
+        mime_type: String,
+    }
+
+    #[derive(Serialize)]
+    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
+    pub enum TaskType {
+        /// Unset value, which will default to one of the other enum values.
+        Unspecified,
+        /// Specifies the given text is a query in a search/retrieval setting.
+        RetrievalQuery,
+        /// Specifies the given text is a document from the corpus being searched.
+        RetrievalDocument,
+        /// Specifies the given text will be used for STS.
+        SemanticSimilarity,
+        /// Specifies that the given text will be classified.
+        Classification,
+        /// Specifies that the embeddings will be used for clustering.
+        Clustering,
+        /// Specifies that the given text will be used for question answering.
+        QuestionAnswering,
+        /// Specifies that the given text will be used for fact verification.
+        FactVerification,
+    }
+
+    #[derive(Debug, Deserialize)]
+    pub struct EmbeddingResponse {
+        pub embedding: EmbeddingValues,
+    }
+
+    #[derive(Debug, Deserialize)]
+    pub struct EmbeddingValues {
+        pub values: Vec<f64>,
+    }
+}
diff --git a/rig-core/src/providers/gemini/mod.rs b/rig-core/src/providers/gemini/mod.rs
new file mode 100644
index 00000000..108d2c6f
--- /dev/null
+++ b/rig-core/src/providers/gemini/mod.rs
@@ -0,0 +1,59 @@
+//! Google Gemini API client and Rig integration
+//!
+//! # Example
+//! ```
+//! use rig::providers::gemini;
+//!
+//! let client = gemini::Client::new("YOUR_API_KEY");
+//!
+//! let gemini_embedding_model = client.embedding_model(gemini::embedding::EMBEDDING_001);
+//! ```
+
+pub mod client;
+pub mod completion;
+pub mod embedding;
+pub use client::Client;
+
+pub mod gemini_api_types {
+    use serde::{Deserialize, Serialize};
+
+    #[derive(Serialize, Deserialize, Debug)]
+    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
+    pub enum ExecutionLanguage {
+        /// Unspecified language. This value should not be used.
+        LanguageUnspecified,
+        /// Python >= 3.10, with numpy and simpy available.
+        Python,
+    }
+
+    /// Code generated by the model that is meant to be executed, and the result returned to the model.
+    /// Only generated when using the CodeExecution tool, in which case the code will be automatically executed,
+    /// and a corresponding CodeExecutionResult will also be generated.
+    #[derive(Debug, Deserialize, Serialize)]
+    pub struct ExecutableCode {
+        /// Programming language of the code.
+        pub language: ExecutionLanguage,
+        /// The code to be executed.
+        pub code: String,
+    }
+
+    #[derive(Serialize, Deserialize, Debug)]
+    pub struct CodeExecutionResult {
+        /// Outcome of the code execution.
+        pub outcome: CodeExecutionOutcome,
+        /// Contains stdout when code execution is successful, stderr or other description otherwise.
+        pub output: Option<String>,
+    }
+
+    #[derive(Serialize, Deserialize, Debug)]
+    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
+    pub enum CodeExecutionOutcome {
+        /// Unspecified status. This value should not be used.
+        Unspecified,
+        /// Code execution completed successfully.
+        Ok,
+        /// Code execution finished but with a failure. stderr should contain the reason.
+        Failed,
+        /// Code execution ran for too long, and was cancelled. There may or may not be a partial output present.
+        DeadlineExceeded,
+    }
+}
diff --git a/rig-core/src/providers/mod.rs b/rig-core/src/providers/mod.rs
index 84bee958..9d774459 100644
--- a/rig-core/src/providers/mod.rs
+++ b/rig-core/src/providers/mod.rs
@@ -5,6 +5,7 @@
 //! - OpenAI
 //! - Perplexity
 //! - Anthropic
+//! - Google Gemini
 //!
 //! Each provider has its own module, which contains a `Client` implementation that can
 //! be used to initialize completion and embedding models and execute requests to those models.
@@ -41,5 +42,6 @@
 //! be used with the Cohere provider client.
 pub mod anthropic;
 pub mod cohere;
+pub mod gemini;
 pub mod openai;
 pub mod perplexity;
diff --git a/rig-lancedb/README.md b/rig-lancedb/README.md
index b63c9740..d9b647b6 100644
--- a/rig-lancedb/README.md
+++ b/rig-lancedb/README.md
@@ -1,2 +1,2 @@
 # Rig-lancedb
-Vector store index integration for LanceDB
\ No newline at end of file
+Vector store index integration for LanceDB