Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(providers): Add Perplexity model provider #18

Merged
merged 3 commits into from
Sep 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions rig-core/examples/perplexity_agent.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
use std::env;

use rig::{
completion::Prompt,
providers::{self, perplexity::LLAMA_3_1_70B_INSTRUCT},
};
use serde_json::json;

#[tokio::main]
async fn main() -> Result<(), anyhow::Error> {
// Create Perplexity client from the PERPLEXITY_API_KEY environment variable
let client = providers::perplexity::Client::new(
&env::var("PERPLEXITY_API_KEY").expect("PERPLEXITY_API_KEY not set"),
);

// Create agent with a single context prompt.
// `additional_params` carries Perplexity-specific options that are merged
// verbatim into the request body.
let agent = client
.agent(LLAMA_3_1_70B_INSTRUCT)
.preamble("Be precise and concise.")
.temperature(0.5)
.additional_params(json!({
"return_related_questions": true,
"return_images": true
}))
.build();

// Prompt the agent and print the response
let response = agent
.prompt("When and where and what type is the next solar eclipse?")
.await?;
println!("{}", response);

Ok(())
}
2 changes: 2 additions & 0 deletions rig-core/src/providers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
//! Currently, the following providers are supported:
//! - Cohere
//! - OpenAI
//! - Perplexity
//!
//! Each provider has its own module, which contains a `Client` implementation that can
//! be used to initialize completion and embedding models and execute requests to those models.
Expand Down Expand Up @@ -39,3 +40,4 @@
//! be used with the Cohere provider client.
pub mod cohere;
pub mod openai;
pub mod perplexity;
255 changes: 255 additions & 0 deletions rig-core/src/providers/perplexity.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,255 @@
//! Perplexity API client and Rig integration
//!
//! # Example
//! ```
//! use rig::providers::perplexity;
//!
//! let client = perplexity::Client::new("YOUR_API_KEY");
//!
//! let llama_3_1_sonar_small_online = client.completion_model(perplexity::LLAMA_3_1_SONAR_SMALL_ONLINE);
//! ```

use crate::{
agent::AgentBuilder,
completion::{self, CompletionError},
extractor::ExtractorBuilder,
json_utils,
model::ModelBuilder,
rag::RagAgentBuilder,
vector_store::{NoIndex, VectorStoreIndex},
};

use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use serde_json::json;

// ================================================================
// Main Perplexity Client
// ================================================================
const PERPLEXITY_API_BASE_URL: &str = "https://api.perplexity.ai";

/// Perplexity API client: the API base URL plus a `reqwest` client
/// pre-configured with the `Authorization: Bearer <key>` header.
#[derive(Clone)]
pub struct Client {
base_url: String,
http_client: reqwest::Client,
}

impl Client {
    /// Create a new Perplexity client with the given API key, targeting the
    /// default Perplexity API base URL.
    pub fn new(api_key: &str) -> Self {
        Self::from_url(api_key, PERPLEXITY_API_BASE_URL)
    }

    /// Create a new Perplexity client from the `PERPLEXITY_API_KEY` environment variable.
    /// Panics if the environment variable is not set.
    pub fn from_env() -> Self {
        let api_key = std::env::var("PERPLEXITY_API_KEY").expect("PERPLEXITY_API_KEY not set");
        Self::new(&api_key)
    }

    /// Create a new Perplexity client with the given API key and a custom base
    /// URL (useful for testing or routing through a proxy).
    pub fn from_url(api_key: &str, base_url: &str) -> Self {
        Self {
            base_url: base_url.to_string(),
            http_client: reqwest::Client::builder()
                .default_headers({
                    let mut headers = reqwest::header::HeaderMap::new();
                    headers.insert(
                        "Authorization",
                        format!("Bearer {}", api_key)
                            .parse()
                            .expect("Bearer token should parse"),
                    );
                    headers
                })
                .build()
                .expect("Perplexity reqwest client should build"),
        }
    }

    /// Build a POST request for the given API path (relative to the base URL).
    pub fn post(&self, path: &str) -> reqwest::RequestBuilder {
        // Join base URL and path with exactly one slash. The previous
        // `.replace("//", "/")` approach also collapsed the scheme separator,
        // turning `https://api...` into the invalid `https:/api...`.
        let url = format!(
            "{}/{}",
            self.base_url.trim_end_matches('/'),
            path.trim_start_matches('/')
        );
        self.http_client.post(url)
    }

    /// Create a completion model bound to this client.
    pub fn completion_model(&self, model: &str) -> CompletionModel {
        CompletionModel::new(self.clone(), model)
    }

    /// Create a `ModelBuilder` for the given model name.
    pub fn model(&self, model: &str) -> ModelBuilder<CompletionModel> {
        ModelBuilder::new(self.completion_model(model))
    }

    /// Create an `AgentBuilder` for the given model name.
    pub fn agent(&self, model: &str) -> AgentBuilder<CompletionModel> {
        AgentBuilder::new(self.completion_model(model))
    }

    /// Create an `ExtractorBuilder` that extracts values of type `T` using the
    /// given model.
    pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
        &self,
        model: &str,
    ) -> ExtractorBuilder<T, CompletionModel> {
        ExtractorBuilder::new(self.completion_model(model))
    }

    /// Create a `RagAgentBuilder` with both context and tool vector indexes.
    pub fn rag_agent<C: VectorStoreIndex, T: VectorStoreIndex>(
        &self,
        model: &str,
    ) -> RagAgentBuilder<CompletionModel, C, T> {
        RagAgentBuilder::new(self.completion_model(model))
    }

    /// Create a `RagAgentBuilder` with a context index only (no tool index).
    pub fn context_rag_agent<C: VectorStoreIndex>(
        &self,
        model: &str,
    ) -> RagAgentBuilder<CompletionModel, C, NoIndex> {
        RagAgentBuilder::new(self.completion_model(model))
    }
}

/// Error body returned by the Perplexity API on failure.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
message: String,
}

/// Untagged response wrapper: deserialization first tries the success
/// payload `T`, then falls back to the API error shape.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
Ok(T),
Err(ApiErrorResponse),
}

// ================================================================
// Perplexity Completion API
// ================================================================
// Model identifiers accepted by Perplexity's chat completions endpoint.
// NOTE(review): model availability changes over time — verify this list
// against Perplexity's current model documentation.
/// `llama-3.1-sonar-small-128k-online` completion model
pub const LLAMA_3_1_SONAR_SMALL_ONLINE: &str = "llama-3.1-sonar-small-128k-online";
/// `llama-3.1-sonar-large-128k-online` completion model
pub const LLAMA_3_1_SONAR_LARGE_ONLINE: &str = "llama-3.1-sonar-large-128k-online";
/// `llama-3.1-sonar-huge-128k-online` completion model
pub const LLAMA_3_1_SONAR_HUGE_ONLINE: &str = "llama-3.1-sonar-huge-128k-online";
/// `llama-3.1-sonar-small-128k-chat` completion model
pub const LLAMA_3_1_SONAR_SMALL_CHAT: &str = "llama-3.1-sonar-small-128k-chat";
/// `llama-3.1-sonar-large-128k-chat` completion model
pub const LLAMA_3_1_SONAR_LARGE_CHAT: &str = "llama-3.1-sonar-large-128k-chat";
/// `llama-3.1-8b-instruct` completion model
pub const LLAMA_3_1_8B_INSTRUCT: &str = "llama-3.1-8b-instruct";
/// `llama-3.1-70b-instruct` completion model
pub const LLAMA_3_1_70B_INSTRUCT: &str = "llama-3.1-70b-instruct";

/// Raw response payload from Perplexity's chat completions endpoint.
#[derive(Debug, Deserialize)]
pub struct CompletionResponse {
pub id: String,
pub model: String,
pub object: String,
pub created: u64,
// Defaults to an empty vec if `choices` is absent from the payload.
#[serde(default)]
pub choices: Vec<Choice>,
pub usage: Usage,
}

/// A chat message in a completion choice (role + text content).
#[derive(Deserialize, Debug)]
pub struct Message {
pub role: String,
pub content: String,
}

/// Incremental message payload in a completion choice.
/// NOTE(review): only deserialized, never used by this module's logic —
/// presumably mirrors the streaming delta shape; confirm against the API.
#[derive(Deserialize, Debug)]
pub struct Delta {
pub role: String,
pub content: String,
}

/// One completion choice returned by the API.
#[derive(Deserialize, Debug)]
pub struct Choice {
pub index: usize,
pub finish_reason: String,
pub message: Message,
// NOTE(review): `delta` has no #[serde(default)], so deserialization fails
// if the API omits it — confirm non-streaming responses always include it.
pub delta: Delta,
}

/// Token accounting reported by the API for a single request.
#[derive(Deserialize, Debug)]
pub struct Usage {
pub prompt_tokens: u32,
pub completion_tokens: u32,
pub total_tokens: u32,
}

impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
type Error = CompletionError;

fn try_from(value: CompletionResponse) -> std::prelude::v1::Result<Self, Self::Error> {
match value.choices.as_slice() {
[Choice {
message: Message { content, .. },
..
}, ..] => Ok(completion::CompletionResponse {
choice: completion::ModelChoice::Message(content.to_string()),
raw_response: value,
}),
_ => Err(CompletionError::ResponseError(
"Response did not contain a message or tool call".into(),
)),
}
}
}

/// A Perplexity completion model: a `Client` paired with a model identifier.
#[derive(Clone)]
pub struct CompletionModel {
client: Client,
pub model: String,
}

impl CompletionModel {
    /// Bind a Perplexity `Client` to a specific model identifier.
    pub fn new(client: Client, model: &str) -> Self {
        let model = model.to_owned();
        Self { client, model }
    }
}

impl completion::CompletionModel for CompletionModel {
    type Response = CompletionResponse;

    /// Execute a completion request against Perplexity's `/chat/completions`
    /// endpoint.
    ///
    /// The message list is assembled as: optional system preamble first, then
    /// the chat history, then the current user prompt. The original code
    /// pushed the system message AFTER the chat history, placing it
    /// mid-conversation instead of at the start.
    ///
    /// # Errors
    /// Returns `CompletionError` on transport failure, non-2xx status, a
    /// malformed response body, an API-level error payload, or a response
    /// with no choices.
    async fn completion(
        &self,
        completion_request: completion::CompletionRequest,
    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
        // System preamble goes first, ahead of the conversation turns.
        let mut messages = Vec::with_capacity(completion_request.chat_history.len() + 2);
        if let Some(preamble) = completion_request.preamble {
            messages.push(completion::Message {
                role: "system".to_string(),
                content: preamble,
            });
        }
        // The request is owned, so move the history instead of cloning it.
        messages.extend(completion_request.chat_history);
        messages.push(completion::Message {
            role: "user".to_string(),
            content: completion_request.prompt,
        });

        let request = json!({
            "model": self.model,
            "messages": messages,
            "temperature": completion_request.temperature,
        });

        // Merge any caller-supplied extra parameters (e.g. `return_images`)
        // into the request body; no clones needed since both are owned.
        let body = match completion_request.additional_params {
            Some(params) => json_utils::merge(request, params),
            None => request,
        };

        let response = self
            .client
            .post("/chat/completions")
            .json(&body)
            .send()
            .await?
            .error_for_status()?
            .json::<ApiResponse<CompletionResponse>>()
            .await?;

        match response {
            ApiResponse::Ok(completion) => Ok(completion.try_into()?),
            ApiResponse::Err(error) => Err(CompletionError::ProviderError(error.message)),
        }
    }
}