From 45132c23a4624c2cca932ad0e48c0a3159d33ab9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mathieu=20B=C3=A9langer?= Date: Fri, 11 Oct 2024 19:55:38 +0200 Subject: [PATCH] feat(provider-gemini): add gemini embedding support --- rig-core/src/providers/gemini/embedding.rs | 95 ++++++++++++++++++++++ rig-core/src/providers/gemini/mod.rs | 16 ++++ 2 files changed, 111 insertions(+) create mode 100644 rig-core/src/providers/gemini/embedding.rs create mode 100644 rig-core/src/providers/gemini/mod.rs diff --git a/rig-core/src/providers/gemini/embedding.rs b/rig-core/src/providers/gemini/embedding.rs new file mode 100644 index 00000000..4b4ea2dc --- /dev/null +++ b/rig-core/src/providers/gemini/embedding.rs @@ -0,0 +1,95 @@ + +// ================================================================ +// Google Gemini Embeddings +// ================================================================ + +use serde::Deserialize; +use serde_json::json; + +use crate::embeddings::{self, EmbeddingError}; + +use super::{client::ApiResponse, Client}; + +/// `embedding-gecko-001` embedding model +pub const EMBEDDING_GECKO_001: &str = "embedding-gecko-001"; +/// `embedding-001` embedding model +pub const EMBEDDING_001: &str = "embedding-001"; +/// `text-embedding-004` embedding model +pub const EMBEDDING_004: &str = "text-embedding-004"; + +#[derive(Debug, Deserialize)] +pub struct EmbeddingResponse { + pub embedding: EmbeddingValues, +} + +#[derive(Debug, Deserialize)] +pub struct EmbeddingValues { + pub values: Vec, +} + +#[derive(Clone)] +pub struct EmbeddingModel { + client: Client, + model: String, + ndims: Option, +} + +impl EmbeddingModel { + pub fn new(client: Client, model: &str, ndims: Option) -> Self { + Self { client, model: model.to_string(), ndims } + } + + +} + +impl embeddings::EmbeddingModel for EmbeddingModel { + const MAX_DOCUMENTS: usize = 1024; + + fn ndims(&self) -> usize { + match self.model.as_str() { + EMBEDDING_GECKO_001 | EMBEDDING_001 => 768, + EMBEDDING_004 => 1024, + _ => 0, // Default to 0 for unknown models + } + } + + + async fn embed_documents( + &self, + documents: Vec, + ) -> Result, EmbeddingError> { + let mut request_body = json!({ + "model": format!("models/{}", self.model), + "content": { + "parts": documents.iter().map(|doc| json!({ "text": doc })).collect::>(), + }, + }); + + if let Some(ndims) = self.ndims { + request_body["output_dimensionality"] = json!(ndims); + } + + let response = self + .client + .post(&format!("/v1beta/models/{}:embedContent", self.model)) + .json(&request_body) + .send() + .await? + .error_for_status()? + .json::>() + .await?; + + match response { + ApiResponse::Ok(response) => { + let chunk_size = self.ndims.unwrap_or_else(|| self.ndims()); + Ok(documents.into_iter().zip(response.embedding.values.chunks(chunk_size)).map(|(document, embedding)| { + embeddings::Embedding { + document, + vec: embedding.to_vec(), + } + }).collect()) + } + ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)), + } + } +} diff --git a/rig-core/src/providers/gemini/mod.rs b/rig-core/src/providers/gemini/mod.rs new file mode 100644 index 00000000..994fca36 --- /dev/null +++ b/rig-core/src/providers/gemini/mod.rs @@ -0,0 +1,16 @@ +//! Google API client and Rig integration +//! +//! # Example +//! ``` +//! use rig::providers::google; +//! +//! let client = google::Client::new("YOUR_API_KEY"); +//! +//! let gemini_embedding_model = client.embedding_model(google::EMBEDDING_001); +//! ``` + +pub mod client; +pub mod completion; +pub mod embedding; + +pub use client::Client;