forked from 0xPlaygrounds/rig
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: support moonshot language model (0xPlaygrounds#223)
* feat: support moonshot language model * feat: support moonshot language model * chore: update mootshot to moonshot * chore: add config attribute macro * chore: remove unsupported method
- Loading branch information
1 parent
3b6692b
commit 5c5fab8
Showing
4 changed files
with
320 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
use rig::agent::AgentBuilder; | ||
use rig::providers::moonshot::{CompletionModel, MOONSHOT_CHAT}; | ||
use rig::{completion::Prompt, providers}; | ||
|
||
#[tokio::main] | ||
async fn main() -> Result<(), anyhow::Error> { | ||
println!("Running basic agent with moonshot"); | ||
basic_moonshot().await?; | ||
|
||
println!("\nRunning moonshot agent with context"); | ||
context_moonshot().await?; | ||
|
||
println!("\n\nAll agents ran successfully"); | ||
Ok(()) | ||
} | ||
|
||
fn client() -> providers::moonshot::Client { | ||
providers::moonshot::Client::from_env() | ||
} | ||
|
||
fn partial_agent_moonshot() -> AgentBuilder<CompletionModel> { | ||
let client = client(); | ||
client.agent(MOONSHOT_CHAT) | ||
} | ||
|
||
async fn basic_moonshot() -> Result<(), anyhow::Error> { | ||
let comedian_agent = partial_agent_moonshot() | ||
.preamble("You are a comedian here to entertain the user using humour and jokes.") | ||
.build(); | ||
|
||
// Prompt the agent and print the response | ||
let response = comedian_agent.prompt("Entertain me!").await?; | ||
println!("{}", response); | ||
|
||
Ok(()) | ||
} | ||
|
||
async fn context_moonshot() -> Result<(), anyhow::Error> { | ||
let model = client().completion_model(MOONSHOT_CHAT); | ||
|
||
// Create an agent with multiple context documents | ||
let agent = AgentBuilder::new(model) | ||
.preamble("Definition of a *glarb-glarb*: A glarb-glarb is an ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.") | ||
.build(); | ||
|
||
// Prompt the agent and print the response | ||
let response = agent.prompt("What does \"glarb-glarb\" mean?").await?; | ||
|
||
println!("{}", response); | ||
|
||
Ok(()) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,263 @@ | ||
//! Moonshot API client and Rig integration | ||
//! | ||
//! # Example | ||
//! ``` | ||
//! use rig::providers::moonshot; | ||
//! | ||
//! let client = moonshot::Client::new("YOUR_API_KEY"); | ||
//! | ||
//! let moonshot_model = client.completion_model(moonshot::MOONSHOT_CHAT); | ||
//! ``` | ||
use crate::{ | ||
agent::AgentBuilder, | ||
completion::{self, CompletionError, CompletionRequest}, | ||
extractor::ExtractorBuilder, | ||
json_utils, | ||
}; | ||
use schemars::JsonSchema; | ||
use serde::{Deserialize, Serialize}; | ||
use serde_json::json; | ||
|
||
// ================================================================ | ||
// Main Moonshot Client | ||
// ================================================================ | ||
const MOONSHOT_API_BASE_URL: &str = "https://api.moonshot.cn/v1"; | ||
|
||
#[derive(Clone)] | ||
pub struct Client { | ||
base_url: String, | ||
http_client: reqwest::Client, | ||
} | ||
|
||
impl Client { | ||
/// Create a new Moonshot client with the given API key. | ||
pub fn new(api_key: &str) -> Self { | ||
Self::from_url(api_key, MOONSHOT_API_BASE_URL) | ||
} | ||
|
||
/// Create a new Moonshot client with the given API key and base API URL. | ||
pub fn from_url(api_key: &str, base_url: &str) -> Self { | ||
Self { | ||
base_url: base_url.to_string(), | ||
http_client: reqwest::Client::builder() | ||
.default_headers({ | ||
let mut headers = reqwest::header::HeaderMap::new(); | ||
headers.insert( | ||
"Authorization", | ||
format!("Bearer {}", api_key) | ||
.parse() | ||
.expect("Bearer token should parse"), | ||
); | ||
headers | ||
}) | ||
.build() | ||
.expect("Moonshot reqwest client should build"), | ||
} | ||
} | ||
|
||
/// Create a new Moonshot client from the `MOONSHOT_API_KEY` environment variable. | ||
/// Panics if the environment variable is not set. | ||
pub fn from_env() -> Self { | ||
let api_key = std::env::var("MOONSHOT_API_KEY").expect("MOONSHOT_API_KEY not set"); | ||
Self::new(&api_key) | ||
} | ||
|
||
fn post(&self, path: &str) -> reqwest::RequestBuilder { | ||
let url = format!("{}/{}", self.base_url, path).replace("//", "/"); | ||
self.http_client.post(url) | ||
} | ||
|
||
/// Create a completion model with the given name. | ||
/// | ||
/// # Example | ||
/// ``` | ||
/// use rig::providers::moonshot::{Client, self}; | ||
/// | ||
/// // Initialize the Moonshot client | ||
/// let moonshot = Client::new("your-moonshot-api-key"); | ||
/// | ||
/// let completion_model = moonshot.completion_model(moonshot::MOONSHOT_CHAT); | ||
/// ``` | ||
pub fn completion_model(&self, model: &str) -> CompletionModel { | ||
CompletionModel::new(self.clone(), model) | ||
} | ||
|
||
/// Create an agent builder with the given completion model. | ||
pub fn agent(&self, model: &str) -> AgentBuilder<CompletionModel> { | ||
AgentBuilder::new(self.completion_model(model)) | ||
} | ||
|
||
/// Create an extractor builder with the given completion model. | ||
pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>( | ||
&self, | ||
model: &str, | ||
) -> ExtractorBuilder<T, CompletionModel> { | ||
ExtractorBuilder::new(self.completion_model(model)) | ||
} | ||
} | ||
|
||
#[derive(Debug, Deserialize)] | ||
struct ApiErrorResponse { | ||
error: MoonshotError, | ||
} | ||
|
||
#[derive(Debug, Deserialize)] | ||
struct MoonshotError { | ||
message: String, | ||
} | ||
|
||
#[derive(Debug, Deserialize)] | ||
#[serde(untagged)] | ||
enum ApiResponse<T> { | ||
Ok(T), | ||
Err(ApiErrorResponse), | ||
} | ||
|
||
// ================================================================ | ||
// Moonshot Completion API | ||
// ================================================================ | ||
pub const MOONSHOT_CHAT: &str = "moonshot-v1-128k"; | ||
|
||
#[derive(Debug, Deserialize)] | ||
pub struct CompletionResponse { | ||
pub id: String, | ||
pub object: String, | ||
pub created: u64, | ||
pub model: String, | ||
pub choices: Vec<Choice>, | ||
pub usage: Usage, | ||
} | ||
|
||
impl From<ApiErrorResponse> for CompletionError { | ||
fn from(err: ApiErrorResponse) -> Self { | ||
CompletionError::ProviderError(err.error.message) | ||
} | ||
} | ||
|
||
impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> { | ||
type Error = CompletionError; | ||
|
||
fn try_from(value: CompletionResponse) -> std::prelude::v1::Result<Self, Self::Error> { | ||
match value.choices.as_slice() { | ||
[Choice { | ||
message: | ||
Message { | ||
content: Some(content), | ||
.. | ||
}, | ||
.. | ||
}, ..] => Ok(completion::CompletionResponse { | ||
choice: completion::ModelChoice::Message(content.to_string()), | ||
raw_response: value, | ||
}), | ||
_ => Err(CompletionError::ResponseError( | ||
"Response did not contain a message".into(), | ||
)), | ||
} | ||
} | ||
} | ||
|
||
#[derive(Debug, Deserialize)] | ||
pub struct Choice { | ||
pub index: usize, | ||
pub message: Message, | ||
pub finish_reason: String, | ||
} | ||
|
||
#[derive(Debug, Deserialize)] | ||
pub struct Message { | ||
pub role: String, | ||
pub content: Option<String>, | ||
} | ||
|
||
#[derive(Clone, Debug, Deserialize)] | ||
pub struct Usage { | ||
pub prompt_tokens: usize, | ||
pub total_tokens: usize, | ||
} | ||
|
||
impl std::fmt::Display for Usage { | ||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { | ||
write!( | ||
f, | ||
"Prompt tokens: {} Total tokens: {}", | ||
self.prompt_tokens, self.total_tokens | ||
) | ||
} | ||
} | ||
|
||
#[derive(Clone)] | ||
pub struct CompletionModel { | ||
client: Client, | ||
pub model: String, | ||
} | ||
|
||
impl CompletionModel { | ||
pub fn new(client: Client, model: &str) -> Self { | ||
Self { | ||
client, | ||
model: model.to_string(), | ||
} | ||
} | ||
} | ||
|
||
impl completion::CompletionModel for CompletionModel { | ||
type Response = CompletionResponse; | ||
|
||
#[cfg_attr(feature = "worker", worker::send)] | ||
async fn completion( | ||
&self, | ||
mut completion_request: CompletionRequest, | ||
) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> { | ||
let mut full_history = if let Some(preamble) = &completion_request.preamble { | ||
vec![completion::Message { | ||
role: "system".into(), | ||
content: preamble.clone(), | ||
}] | ||
} else { | ||
vec![] | ||
}; | ||
|
||
full_history.append(&mut completion_request.chat_history); | ||
|
||
full_history.push(completion::Message { | ||
role: "user".into(), | ||
content: completion_request.prompt_with_context(), | ||
}); | ||
|
||
let request = json!({ | ||
"model": self.model, | ||
"messages": full_history, | ||
"temperature": completion_request.temperature, | ||
}); | ||
|
||
let response = self | ||
.client | ||
.post("/chat/completions") | ||
.json( | ||
&if let Some(params) = completion_request.additional_params { | ||
json_utils::merge(request, params) | ||
} else { | ||
request | ||
}, | ||
) | ||
.send() | ||
.await?; | ||
|
||
if response.status().is_success() { | ||
match response.json::<ApiResponse<CompletionResponse>>().await? { | ||
ApiResponse::Ok(response) => { | ||
tracing::info!(target: "rig", | ||
"Moonshot completion token usage: {}", | ||
response.usage | ||
); | ||
response.try_into() | ||
} | ||
ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.error.message)), | ||
} | ||
} else { | ||
Err(CompletionError::ProviderError(response.text().await?)) | ||
} | ||
} | ||
} |