Skip to content

Commit

Permalink
Merge pull request #56 from 0xPlaygrounds/feat/model-provider/16-add-…
Browse files Browse the repository at this point in the history
…gemini-completion-embedding-models

feat(model-providers): 16 Add Gemini Completion and Embedding Models
  • Loading branch information
cvauclair authored Nov 4, 2024
2 parents bdfe1ba + 7ef4112 commit e549b37
Show file tree
Hide file tree
Showing 12 changed files with 1,175 additions and 8 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ jobs:
with:
command: nextest
args: run --all-features
env:
env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
GEMINI_API_KEY: ${{ secrets.GEMINI_API_KEY }}
Expand All @@ -105,4 +105,4 @@ jobs:
- name: Run cargo doc
run: cargo doc --no-deps --all-features
env:
RUSTDOCFLAGS: -D warnings
RUSTDOCFLAGS: -D warnings
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ Rig supports the following LLM providers natively:
- Cohere
- Anthropic
- Perplexity
- Google Gemini

Additionally, Rig currently has the following integration sub-libraries:
- MongoDB vector store: `rig-mongodb`
Expand Down
2 changes: 1 addition & 1 deletion rig-core/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,6 @@ or just `full` to enable all features (`cargo add tokio --features macros,rt-mul
Rig supports the following LLM providers natively:
- OpenAI
- Cohere

- Google Gemini
Additionally, Rig currently has the following integration sub-libraries:
- MongoDB vector store: `rig-mongodb`
49 changes: 49 additions & 0 deletions rig-core/examples/gemini_agent.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
use rig::{
completion::Prompt,
providers::gemini::{self, completion::gemini_api_types::GenerationConfig},
};
#[tracing::instrument(ret)]
#[tokio::main]

async fn main() -> Result<(), anyhow::Error> {
tracing_subscriber::fmt()
.with_max_level(tracing::Level::DEBUG)
.with_target(false)
.init();

// Initialize the Google Gemini client
let client = gemini::Client::from_env();

// Create agent with a single context prompt
let agent = client
.agent(gemini::completion::GEMINI_1_5_PRO)
.preamble("Be creative and concise. Answer directly and clearly.")
.temperature(0.5)
// The `GenerationConfig` utility struct helps construct a typesafe `additional_params`
.additional_params(serde_json::to_value(GenerationConfig {
top_k: Some(1),
top_p: Some(0.95),
candidate_count: Some(1),
..Default::default()
})?) // Unwrap the Result to get the Value
.build();

tracing::info!("Prompting the agent...");

// Prompt the agent and print the response
let response = agent
.prompt("How much wood would a woodchuck chuck if a woodchuck could chuck wood? Infer an answer.")
.await;

tracing::info!("Response: {:?}", response);

match response {
Ok(response) => println!("{}", response),
Err(e) => {
tracing::error!("Error: {:?}", e);
return Err(e.into());
}
}

Ok(())
}
20 changes: 20 additions & 0 deletions rig-core/examples/gemini_embeddings.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
use rig::providers::gemini;

#[tokio::main]
async fn main() -> Result<(), anyhow::Error> {
// Initialize the Google Gemini client
// Create OpenAI client
let client = gemini::Client::from_env();

let embeddings = client
.embeddings(gemini::embedding::EMBEDDING_001)
.simple_document("doc0", "Hello, world!")
.simple_document("doc1", "Goodbye, world!")
.build()
.await
.expect("Failed to embed documents");

println!("{:?}", embeddings);

Ok(())
}
8 changes: 4 additions & 4 deletions rig-core/src/embeddings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,12 @@
//!
//! // Create an embeddings builder and add documents
//! let embeddings = EmbeddingsBuilder::new(embedding_model)
//! .simple_document("doc1", "This is the first document.")
//! .simple_document("doc1", "This is the first document.")
//! .simple_document("doc2", "This is the second document.")
//! .build()
//! .await
//! .expect("Failed to build embeddings.");
//!
//!
//! // Use the generated embeddings
//! // ...
//! ```
Expand Down Expand Up @@ -102,7 +102,7 @@ pub trait EmbeddingModel: Clone + Sync + Send {
}

/// Struct that holds a single document and its embedding.
#[derive(Clone, Default, Deserialize, Serialize)]
#[derive(Clone, Default, Deserialize, Serialize, Debug)]
pub struct Embedding {
/// The document that was embedded
pub document: String,
Expand Down Expand Up @@ -142,7 +142,7 @@ impl Embedding {
/// large document to be retrieved from a query that matches multiple smaller and
/// distinct text documents. For example, if the document is a textbook, a summary of
/// each chapter could serve as the book's embeddings.
#[derive(Clone, Eq, PartialEq, Serialize, Deserialize)]
#[derive(Clone, Eq, PartialEq, Serialize, Deserialize, Debug)]
pub struct DocumentEmbeddings {
#[serde(rename = "_id")]
pub id: String,
Expand Down
156 changes: 156 additions & 0 deletions rig-core/src/providers/gemini/client.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
use crate::{
agent::AgentBuilder,
embeddings::{self},
extractor::ExtractorBuilder,
};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};

use super::{completion::CompletionModel, embedding::EmbeddingModel};

// ================================================================
// Google Gemini Client
// ================================================================
const GEMINI_API_BASE_URL: &str = "https://generativelanguage.googleapis.com";

#[derive(Clone)]
pub struct Client {
base_url: String,
api_key: String,
http_client: reqwest::Client,
}

impl Client {
pub fn new(api_key: &str) -> Self {
Self::from_url(api_key, GEMINI_API_BASE_URL)
}
fn from_url(api_key: &str, base_url: &str) -> Self {
Self {
base_url: base_url.to_string(),
api_key: api_key.to_string(),
http_client: reqwest::Client::builder()
.default_headers({
let mut headers = reqwest::header::HeaderMap::new();
headers.insert(
reqwest::header::CONTENT_TYPE,
"application/json".parse().unwrap(),
);
headers
})
.build()
.expect("Gemini reqwest client should build"),
}
}

/// Create a new Google Gemini client from the `GEMINI_API_KEY` environment variable.
/// Panics if the environment variable is not set.
pub fn from_env() -> Self {
let api_key = std::env::var("GEMINI_API_KEY").expect("GEMINI_API_KEY not set");
Self::new(&api_key)
}

pub fn post(&self, path: &str) -> reqwest::RequestBuilder {
let url = format!("{}/{}?key={}", self.base_url, path, self.api_key).replace("//", "/");

tracing::debug!("POST {}", url);
self.http_client.post(url)
}

/// Create an embedding model with the given name.
/// Note: default embedding dimension of 0 will be used if model is not known.
/// If this is the case, it's better to use function `embedding_model_with_ndims`
///
/// # Example
/// ```
/// use rig::providers::gemini::{Client, self};
///
/// // Initialize the Google Gemini client
/// let gemini = Client::new("your-google-gemini-api-key");
///
/// let embedding_model = gemini.embedding_model(gemini::embedding::EMBEDDING_GECKO_001);
/// ```
pub fn embedding_model(&self, model: &str) -> EmbeddingModel {
EmbeddingModel::new(self.clone(), model, None)
}

/// Create an embedding model with the given name and the number of dimensions in the embedding generated by the model.
///
/// # Example
/// ```
/// use rig::providers::gemini::{Client, self};
///
/// // Initialize the Google Gemini client
/// let gemini = Client::new("your-google-gemini-api-key");
///
/// let embedding_model = gemini.embedding_model_with_ndims("model-unknown-to-rig", 1024);
/// ```
pub fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> EmbeddingModel {
EmbeddingModel::new(self.clone(), model, Some(ndims))
}

/// Create an embedding builder with the given embedding model.
///
/// # Example
/// ```
/// use rig::providers::gemini::{Client, self};
///
/// // Initialize the Google Gemini client
/// let gemini = Client::new("your-google-gemini-api-key");
///
/// let embeddings = gemini.embeddings(gemini::embedding::EMBEDDING_GECKO_001)
/// .simple_document("doc0", "Hello, world!")
/// .simple_document("doc1", "Goodbye, world!")
/// .build()
/// .await
/// .expect("Failed to embed documents");
/// ```
pub fn embeddings(&self, model: &str) -> embeddings::EmbeddingsBuilder<EmbeddingModel> {
embeddings::EmbeddingsBuilder::new(self.embedding_model(model))
}

/// Create a completion model with the given name.
/// Gemini-specific parameters can be set using the [GenerationConfig](crate::providers::gemini::completion::gemini_api_types::GenerationConfig) struct.
/// [Gemini API Reference](https://ai.google.dev/api/generate-content#generationconfig)
pub fn completion_model(&self, model: &str) -> CompletionModel {
CompletionModel::new(self.clone(), model)
}

/// Create an agent builder with the given completion model.
/// Gemini-specific parameters can be set using the [GenerationConfig](crate::providers::gemini::completion::gemini_api_types::GenerationConfig) struct.
/// [Gemini API Reference](https://ai.google.dev/api/generate-content#generationconfig)
/// # Example
/// ```
/// use rig::providers::gemini::{Client, self};
///
/// // Initialize the Google Gemini client
/// let gemini = Client::new("your-google-gemini-api-key");
///
/// let agent = gemini.agent(gemini::completion::GEMINI_1_5_PRO)
/// .preamble("You are comedian AI with a mission to make people laugh.")
/// .temperature(0.0)
/// .build();
/// ```
pub fn agent(&self, model: &str) -> AgentBuilder<CompletionModel> {
AgentBuilder::new(self.completion_model(model))
}

/// Create an extractor builder with the given completion model.
pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
&self,
model: &str,
) -> ExtractorBuilder<T, CompletionModel> {
ExtractorBuilder::new(self.completion_model(model))
}
}

#[derive(Debug, Deserialize)]
pub struct ApiErrorResponse {
pub message: String,
}

#[derive(Debug, Deserialize)]
#[serde(untagged)]
pub enum ApiResponse<T> {
Ok(T),
Err(ApiErrorResponse),
}
Loading

0 comments on commit e549b37

Please sign in to comment.