Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(model-providers): 16 Add Gemini Completion and Embedding Models #56

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
3605437
feat(provider-gemini): add gemini API client
Oct 11, 2024
3aed980
feat(provider-gemini): add gemini support for basic completion
Oct 11, 2024
45132c2
feat(provider-gemini): add gemini embedding support
Oct 11, 2024
8b69e90
feat(provider-gemini): add agent support in client
Oct 14, 2024
700a196
docs(provider-gemini): Update readme entries, add gemini agent example
Oct 14, 2024
cc74aec
style(provider-gemini): test pre-commits
Oct 14, 2024
d0e0ade
feat(provider-gemini): add support for gemini specific completion par…
Oct 14, 2024
e5e763e
docs(gemini): add addtionnal types from the official documentation, a…
Oct 14, 2024
1dca1da
feat(gemini): move system prompt to correct request field
mateobelanger Oct 15, 2024
05b5df1
docs(readme): remove gemini mention in non-exhaustive list
mateobelanger Oct 15, 2024
5b45c5c
chore: add debug trait to embedding struct
mateobelanger Oct 15, 2024
9ae5e33
refactor(gemini): separate gemini api types module, fix pr comments
mateobelanger Oct 31, 2024
caee495
Merge branch 'main' into feat/model-provider/16-add-gemini-completion…
mateobelanger Nov 1, 2024
6619f7f
fix(gemini): missing param to be marked as optional in completion res
mateobelanger Nov 1, 2024
a1be7c4
Merge remote-tracking branch 'origin/main' into feat/model-provider/1…
mateobelanger Nov 2, 2024
3fb10d6
fix: docs imports and refs
mateobelanger Nov 3, 2024
43530e5
fix(gemini): issue when additionnal param is empty
mateobelanger Nov 4, 2024
1a444ad
refactor(gemini): remove try_from and use serde deserialization
mateobelanger Nov 4, 2024
7ef4112
docs(gemini): add utility config docstring
mateobelanger Nov 4, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ jobs:
with:
command: nextest
args: run --all-features
env:
env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
GEMINI_API_KEY: ${{ secrets.GEMINI_API_KEY }}
Expand All @@ -105,4 +105,4 @@ jobs:
- name: Run cargo doc
run: cargo doc --no-deps --all-features
env:
RUSTDOCFLAGS: -D warnings
RUSTDOCFLAGS: -D warnings
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ Rig supports the following LLM providers natively:
- Cohere
- Anthropic
- Perplexity
- Google Gemini

Additionally, Rig currently has the following integration sub-libraries:
- MongoDB vector store: `rig-mongodb`
Expand Down
2 changes: 1 addition & 1 deletion rig-core/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,6 @@ or just `full` to enable all features (`cargo add tokio --features macros,rt-mul
Rig supports the following LLM providers natively:
- OpenAI
- Cohere

- Google Gemini
Additionally, Rig currently has the following integration sub-libraries:
- MongoDB vector store: `rig-mongodb`
49 changes: 49 additions & 0 deletions rig-core/examples/gemini_agent.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
use rig::{
completion::Prompt,
providers::gemini::{self, completion::gemini_api_types::GenerationConfig},
};
#[tracing::instrument(ret)]
#[tokio::main]

async fn main() -> Result<(), anyhow::Error> {
tracing_subscriber::fmt()
.with_max_level(tracing::Level::DEBUG)
.with_target(false)
.init();

// Initialize the Google Gemini client
let client = gemini::Client::from_env();

// Create agent with a single context prompt
let agent = client
.agent(gemini::completion::GEMINI_1_5_PRO)
.preamble("Be creative and concise. Answer directly and clearly.")
.temperature(0.5)
// The `GenerationConfig` utility struct helps construct a typesafe `additional_params`
.additional_params(serde_json::to_value(GenerationConfig {
top_k: Some(1),
top_p: Some(0.95),
candidate_count: Some(1),
..Default::default()
})?) // Unwrap the Result to get the Value
mateobelanger marked this conversation as resolved.
Show resolved Hide resolved
.build();

tracing::info!("Prompting the agent...");

// Prompt the agent and print the response
let response = agent
.prompt("How much wood would a woodchuck chuck if a woodchuck could chuck wood? Infer an answer.")
.await;

tracing::info!("Response: {:?}", response);

match response {
Ok(response) => println!("{}", response),
Err(e) => {
tracing::error!("Error: {:?}", e);
return Err(e.into());
}
}

Ok(())
}
20 changes: 20 additions & 0 deletions rig-core/examples/gemini_embeddings.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
use rig::providers::gemini;

#[tokio::main]
async fn main() -> Result<(), anyhow::Error> {
// Initialize the Google Gemini client
// Create OpenAI client
let client = gemini::Client::from_env();

let embeddings = client
.embeddings(gemini::embedding::EMBEDDING_001)
.simple_document("doc0", "Hello, world!")
.simple_document("doc1", "Goodbye, world!")
.build()
.await
.expect("Failed to embed documents");

println!("{:?}", embeddings);

Ok(())
}
8 changes: 4 additions & 4 deletions rig-core/src/embeddings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,12 @@
//!
//! // Create an embeddings builder and add documents
//! let embeddings = EmbeddingsBuilder::new(embedding_model)
//! .simple_document("doc1", "This is the first document.")
//! .simple_document("doc1", "This is the first document.")
//! .simple_document("doc2", "This is the second document.")
//! .build()
//! .await
//! .expect("Failed to build embeddings.");
//!
//!
//! // Use the generated embeddings
//! // ...
//! ```
Expand Down Expand Up @@ -102,7 +102,7 @@ pub trait EmbeddingModel: Clone + Sync + Send {
}

/// Struct that holds a single document and its embedding.
#[derive(Clone, Default, Deserialize, Serialize)]
#[derive(Clone, Default, Deserialize, Serialize, Debug)]
pub struct Embedding {
/// The document that was embedded
pub document: String,
Expand Down Expand Up @@ -142,7 +142,7 @@ impl Embedding {
/// large document to be retrieved from a query that matches multiple smaller and
/// distinct text documents. For example, if the document is a textbook, a summary of
/// each chapter could serve as the book's embeddings.
#[derive(Clone, Eq, PartialEq, Serialize, Deserialize)]
#[derive(Clone, Eq, PartialEq, Serialize, Deserialize, Debug)]
pub struct DocumentEmbeddings {
#[serde(rename = "_id")]
pub id: String,
Expand Down
156 changes: 156 additions & 0 deletions rig-core/src/providers/gemini/client.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
use crate::{
agent::AgentBuilder,
embeddings::{self},
extractor::ExtractorBuilder,
};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};

use super::{completion::CompletionModel, embedding::EmbeddingModel};

// ================================================================
// Google Gemini Client
// ================================================================
const GEMINI_API_BASE_URL: &str = "https://generativelanguage.googleapis.com";

#[derive(Clone)]
pub struct Client {
base_url: String,
api_key: String,
http_client: reqwest::Client,
}

impl Client {
pub fn new(api_key: &str) -> Self {
Self::from_url(api_key, GEMINI_API_BASE_URL)
}
fn from_url(api_key: &str, base_url: &str) -> Self {
Self {
base_url: base_url.to_string(),
api_key: api_key.to_string(),
http_client: reqwest::Client::builder()
.default_headers({
let mut headers = reqwest::header::HeaderMap::new();
headers.insert(
reqwest::header::CONTENT_TYPE,
"application/json".parse().unwrap(),
);
headers
})
.build()
.expect("Gemini reqwest client should build"),
}
}

/// Create a new Google Gemini client from the `GEMINI_API_KEY` environment variable.
/// Panics if the environment variable is not set.
pub fn from_env() -> Self {
let api_key = std::env::var("GEMINI_API_KEY").expect("GEMINI_API_KEY not set");
Self::new(&api_key)
}

pub fn post(&self, path: &str) -> reqwest::RequestBuilder {
let url = format!("{}/{}?key={}", self.base_url, path, self.api_key).replace("//", "/");
mateobelanger marked this conversation as resolved.
Show resolved Hide resolved

tracing::debug!("POST {}", url);
self.http_client.post(url)
}

/// Create an embedding model with the given name.
/// Note: default embedding dimension of 0 will be used if model is not known.
/// If this is the case, it's better to use function `embedding_model_with_ndims`
///
/// # Example
/// ```
/// use rig::providers::gemini::{Client, self};
///
/// // Initialize the Google Gemini client
/// let gemini = Client::new("your-google-gemini-api-key");
///
/// let embedding_model = gemini.embedding_model(gemini::embedding::EMBEDDING_GECKO_001);
/// ```
pub fn embedding_model(&self, model: &str) -> EmbeddingModel {
EmbeddingModel::new(self.clone(), model, None)
}

/// Create an embedding model with the given name and the number of dimensions in the embedding generated by the model.
///
/// # Example
/// ```
/// use rig::providers::gemini::{Client, self};
///
/// // Initialize the Google Gemini client
/// let gemini = Client::new("your-google-gemini-api-key");
///
/// let embedding_model = gemini.embedding_model_with_ndims("model-unknown-to-rig", 1024);
/// ```
pub fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> EmbeddingModel {
EmbeddingModel::new(self.clone(), model, Some(ndims))
}

/// Create an embedding builder with the given embedding model.
///
/// # Example
/// ```
/// use rig::providers::gemini::{Client, self};
///
/// // Initialize the Google Gemini client
/// let gemini = Client::new("your-google-gemini-api-key");
///
/// let embeddings = gemini.embeddings(gemini::embedding::EMBEDDING_GECKO_001)
/// .simple_document("doc0", "Hello, world!")
/// .simple_document("doc1", "Goodbye, world!")
/// .build()
/// .await
/// .expect("Failed to embed documents");
/// ```
pub fn embeddings(&self, model: &str) -> embeddings::EmbeddingsBuilder<EmbeddingModel> {
embeddings::EmbeddingsBuilder::new(self.embedding_model(model))
}

/// Create a completion model with the given name.
/// Gemini-specific parameters can be set using the [GenerationConfig](crate::providers::gemini::completion::gemini_api_types::GenerationConfig) struct.
/// [Gemini API Reference](https://ai.google.dev/api/generate-content#generationconfig)
pub fn completion_model(&self, model: &str) -> CompletionModel {
CompletionModel::new(self.clone(), model)
}

/// Create an agent builder with the given completion model.
/// Gemini-specific parameters can be set using the [GenerationConfig](crate::providers::gemini::completion::gemini_api_types::GenerationConfig) struct.
/// [Gemini API Reference](https://ai.google.dev/api/generate-content#generationconfig)
/// # Example
/// ```
/// use rig::providers::gemini::{Client, self};
///
/// // Initialize the Google Gemini client
/// let gemini = Client::new("your-google-gemini-api-key");
///
/// let agent = gemini.agent(gemini::completion::GEMINI_1_5_PRO)
/// .preamble("You are comedian AI with a mission to make people laugh.")
/// .temperature(0.0)
/// .build();
/// ```
pub fn agent(&self, model: &str) -> AgentBuilder<CompletionModel> {
AgentBuilder::new(self.completion_model(model))
}

/// Create an extractor builder with the given completion model.
pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
&self,
model: &str,
) -> ExtractorBuilder<T, CompletionModel> {
ExtractorBuilder::new(self.completion_model(model))
}
}

#[derive(Debug, Deserialize)]
pub struct ApiErrorResponse {
pub message: String,
}

#[derive(Debug, Deserialize)]
#[serde(untagged)]
pub enum ApiResponse<T> {
Ok(T),
Err(ApiErrorResponse),
}
Loading
Loading