Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(model-providers): 16 Add Gemini Completion and Embedding Models #56

Merged
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
3605437
feat(provider-gemini): add gemini API client
Oct 11, 2024
3aed980
feat(provider-gemini): add gemini support for basic completion
Oct 11, 2024
45132c2
feat(provider-gemini): add gemini embedding support
Oct 11, 2024
8b69e90
feat(provider-gemini): add agent support in client
Oct 14, 2024
700a196
docs(provider-gemini): Update readme entries, add gemini agent example
Oct 14, 2024
cc74aec
style(provider-gemini): test pre-commits
Oct 14, 2024
d0e0ade
feat(provider-gemini): add support for gemini specific completion par…
Oct 14, 2024
e5e763e
docs(gemini): add addtionnal types from the official documentation, a…
Oct 14, 2024
1dca1da
feat(gemini): move system prompt to correct request field
mateobelanger Oct 15, 2024
05b5df1
docs(readme): remove gemini mention in non-exhaustive list
mateobelanger Oct 15, 2024
5b45c5c
chore: add debug trait to embedding struct
mateobelanger Oct 15, 2024
9ae5e33
refactor(gemini): separate gemini api types module, fix pr comments
mateobelanger Oct 31, 2024
caee495
Merge branch 'main' into feat/model-provider/16-add-gemini-completion…
mateobelanger Nov 1, 2024
6619f7f
fix(gemini): missing param to be marked as optional in completion res
mateobelanger Nov 1, 2024
a1be7c4
Merge remote-tracking branch 'origin/main' into feat/model-provider/1…
mateobelanger Nov 2, 2024
3fb10d6
fix: docs imports and refs
mateobelanger Nov 3, 2024
43530e5
fix(gemini): issue when additionnal param is empty
mateobelanger Nov 4, 2024
1a444ad
refactor(gemini): remove try_from and use serde deserialization
mateobelanger Nov 4, 2024
7ef4112
docs(gemini): add utility config docstring
mateobelanger Nov 4, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ We'd love your feedback. Please take a moment to let us know what you think usin

## High-level features
- Full support for LLM completion and embedding workflows
- Simple but powerful common abstractions over LLM providers (e.g. OpenAI, Cohere) and vector stores (e.g. MongoDB, in-memory)
- Simple but powerful common abstractions over LLM providers (e.g. OpenAI, Cohere, Google Gemini) and vector stores (e.g. MongoDB, in-memory)
mateobelanger marked this conversation as resolved.
Show resolved Hide resolved
- Integrate LLMs in your app with minimal boilerplate

## Installation
Expand Down Expand Up @@ -70,6 +70,6 @@ or just `full` to enable all features (`cargo add tokio --features macros,rt-mul
Rig supports the following LLM providers natively:
- OpenAI
- Cohere

- Google Gemini
Additionally, Rig currently has the following integration sub-libraries:
- MongoDB vector store: `rig-mongodb`
4 changes: 2 additions & 2 deletions rig-core/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ More information about this crate can be found in the [crate documentation](http

## High-level features
- Full support for LLM completion and embedding workflows
- Simple but powerful common abstractions over LLM providers (e.g. OpenAI, Cohere) and vector stores (e.g. MongoDB, in-memory)
- Simple but powerful common abstractions over LLM providers (e.g. OpenAI, Cohere, Google Gemini) and vector stores (e.g. MongoDB, in-memory)
mateobelanger marked this conversation as resolved.
Show resolved Hide resolved
- Integrate LLMs in your app with minimal boilerplate

## Installation
Expand Down Expand Up @@ -47,6 +47,6 @@ or just `full` to enable all features (`cargo add tokio --features macros,rt-mul
Rig supports the following LLM providers natively:
- OpenAI
- Cohere

- Google Gemini
Additionally, Rig currently has the following integration sub-libraries:
- MongoDB vector store: `rig-mongodb`
36 changes: 36 additions & 0 deletions rig-core/examples/gemini_agent.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
use rig::{
completion::Prompt,
providers::gemini::{self, completion::GenerationConfig},
};

#[tokio::main]
async fn main() -> Result<(), anyhow::Error> {
// Initialize the Google Gemini client
// Create OpenAI client
mateobelanger marked this conversation as resolved.
Show resolved Hide resolved
let client = gemini::Client::from_env();

// Create agent with a single context prompt
let agent = client
.agent(gemini::completion::GEMINI_1_5_PRO)
.preamble("Be precise and concise.")
.temperature(0.5)
.max_tokens(8192)
.additional_params(
serde_json::to_value(GenerationConfig {
top_k: Some(1),
top_p: Some(0.95),
candidate_count: Some(1),
..Default::default()
})
.unwrap(),
) // Unwrap the Result to get the Value
.build();

// Prompt the agent and print the response
let response = agent
.prompt("How much wood would a woodchuck chuck if a woodchuck could chuck wood?")
.await?;
println!("{}", response);

Ok(())
}
156 changes: 156 additions & 0 deletions rig-core/src/providers/gemini/client.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
use crate::{
agent::AgentBuilder,
embeddings::{self},
extractor::ExtractorBuilder,
};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};

use super::{completion::CompletionModel, embedding::EmbeddingModel};

// ================================================================
// Google Gemini Client
// ================================================================
const GEMINI_API_BASE_URL: &str = "https://generativelanguage.googleapis.com";

#[derive(Clone)]
pub struct Client {
base_url: String,
api_key: String,
http_client: reqwest::Client,
}

impl Client {
pub fn new(api_key: &str) -> Self {
Self::from_url(api_key, GEMINI_API_BASE_URL)
}
fn from_url(api_key: &str, base_url: &str) -> Self {
Self {
base_url: base_url.to_string(),
api_key: api_key.to_string(),
http_client: reqwest::Client::builder()
.default_headers({
let mut headers = reqwest::header::HeaderMap::new();
headers.insert(
reqwest::header::CONTENT_TYPE,
"application/json".parse().unwrap(),
);
headers
})
.build()
.expect("Gemini reqwest client should build"),
}
}

/// Create a new Google Gemini client from the `GEMINI_API_KEY` environment variable.
/// Panics if the environment variable is not set.
pub fn from_env() -> Self {
let api_key = std::env::var("GEMINI_API_KEY").expect("GEMINI_API_KEY not set");
Self::new(&api_key)
}

pub fn post(&self, path: &str) -> reqwest::RequestBuilder {
let url = format!("{}/{}?key={}", self.base_url, path, self.api_key).replace("//", "/");
mateobelanger marked this conversation as resolved.
Show resolved Hide resolved

tracing::info!("POST {}", url);
mateobelanger marked this conversation as resolved.
Show resolved Hide resolved
self.http_client.post(url)
}

/// Create an embedding model with the given name.
/// Note: default embedding dimension of 0 will be used if model is not known.
/// If this is the case, it's better to use function `embedding_model_with_ndims`
///
/// # Example
/// ```
/// use rig::providers::gemini::{Client, self};
///
/// // Initialize the Google Gemini client
/// let gemini = Client::new("your-google-gemini-api-key");
///
/// let embedding_model = gemini.embedding_model(gemini::embedding::EMBEDDING_GECKO_001);
/// ```
pub fn embedding_model(&self, model: &str) -> EmbeddingModel {
EmbeddingModel::new(self.clone(), model, None)
}

/// Create an embedding model with the given name and the number of dimensions in the embedding generated by the model.
///
/// # Example
/// ```
/// use rig::providers::gemini::{Client, self};
///
/// // Initialize the Google Gemini client
/// let gemini = Client::new("your-google-gemini-api-key");
///
/// let embedding_model = gemini.embedding_model_with_ndims("model-unknown-to-rig", 1024);
/// ```
pub fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> EmbeddingModel {
EmbeddingModel::new(self.clone(), model, Some(ndims))
}

/// Create an embedding builder with the given embedding model.
///
/// # Example
/// ```
/// use rig::providers::gemini::{Client, self};
///
/// // Initialize the Google Gemini client
/// let gemini = Client::new("your-google-gemini-api-key");
///
/// let embeddings = gemini.embeddings(gemini::embedding::EMBEDDING_GECKO_001)
/// .simple_document("doc0", "Hello, world!")
/// .simple_document("doc1", "Goodbye, world!")
/// .build()
/// .await
/// .expect("Failed to embed documents");
/// ```
pub fn embeddings(&self, model: &str) -> embeddings::EmbeddingsBuilder<EmbeddingModel> {
embeddings::EmbeddingsBuilder::new(self.embedding_model(model))
}

/// Create a completion model with the given name.
/// Gemini-specific parameters can be set using the [GenerationConfig](crate::providers::gemini::completion::GenerationConfig) struct.
/// https://ai.google.dev/api/generate-content#generationconfig
pub fn completion_model(&self, model: &str) -> CompletionModel {
CompletionModel::new(self.clone(), model)
}

/// Create an agent builder with the given completion model.
/// Gemini-specific parameters can be set using the [GenerationConfig](crate::providers::gemini::completion::GenerationConfig) struct.
/// https://ai.google.dev/api/generate-content#generationconfig
/// # Example
/// ```
/// use rig::providers::gemini::{Client, self};
///
/// // Initialize the Google Gemini client
/// let gemini = Client::new("your-google-gemini-api-key");
///
/// let agent = gemini.agent(gemini::completion::GEMINI_1_5_PRO)
/// .preamble("You are comedian AI with a mission to make people laugh.")
/// .temperature(0.0)
/// .build();
/// ```
pub fn agent(&self, model: &str) -> AgentBuilder<CompletionModel> {
AgentBuilder::new(self.completion_model(model))
}

/// Create an extractor builder with the given completion model.
pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
&self,
model: &str,
) -> ExtractorBuilder<T, CompletionModel> {
ExtractorBuilder::new(self.completion_model(model))
}
}

#[derive(Debug, Deserialize)]
pub struct ApiErrorResponse {
pub message: String,
}

#[derive(Debug, Deserialize)]
#[serde(untagged)]
pub enum ApiResponse<T> {
Ok(T),
Err(ApiErrorResponse),
}
Loading
Loading