Skip to content

Commit

Permalink
fix(message): rebase and fix modeling for new providers
Browse files Browse the repository at this point in the history
  • Loading branch information
0xMochan committed Feb 4, 2025
1 parent 7b9076c commit 39e1776
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 346 deletions.
159 changes: 31 additions & 128 deletions rig-core/src/providers/azure.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@ use crate::{
completion::{self, CompletionError, CompletionRequest},
embeddings::{self, EmbeddingError, EmbeddingsBuilder},
extractor::ExtractorBuilder,
json_utils, Embed,
json_utils,
providers::openai,
Embed,
};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
Expand Down Expand Up @@ -349,110 +351,6 @@ pub const GPT_35_TURBO_INSTRUCT: &str = "gpt-3.5-turbo-instruct";
/// `gpt-3.5-turbo-16k` completion model
pub const GPT_35_TURBO_16K: &str = "gpt-3.5-turbo-16k";

/// Deserialized body of a successful completion response.
///
/// NOTE(review): the field set looks like the OpenAI-compatible chat
/// completion schema; the exact meaning of each field is not established in
/// this file — confirm against the Azure OpenAI API reference.
#[derive(Debug, Deserialize)]
pub struct CompletionResponse {
pub id: String,
pub object: String,
// NOTE(review): presumably a Unix timestamp — TODO confirm.
pub created: u64,
pub model: String,
pub system_fingerprint: Option<String>,
// The `TryFrom` conversion into `completion::CompletionResponse` only
// inspects the first entry of this list.
pub choices: Vec<Choice>,
// Logged as "token usage" by the completion call when present.
pub usage: Option<Usage>,
}

impl From<ApiErrorResponse> for CompletionError {
fn from(err: ApiErrorResponse) -> Self {
CompletionError::ProviderError(err.message)
}
}

impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
type Error = CompletionError;

fn try_from(value: CompletionResponse) -> std::prelude::v1::Result<Self, Self::Error> {
match value.choices.as_slice() {
[Choice {
message:
Message {
tool_calls: Some(calls),
..
},
..
}, ..]
if !calls.is_empty() =>
{
let call = calls.first().unwrap();

Ok(completion::CompletionResponse {
choice: completion::ModelChoice::ToolCall(
call.function.name.clone(),
"".to_owned(),
serde_json::from_str(&call.function.arguments)?,
),
raw_response: value,
})
}
[Choice {
message:
Message {
content: Some(content),
..
},
..
}, ..] => Ok(completion::CompletionResponse {
choice: completion::ModelChoice::Message(content.to_string()),
raw_response: value,
}),
_ => Err(CompletionError::ResponseError(
"Response did not contain a message or tool call".into(),
)),
}
}
}

/// One entry of `CompletionResponse::choices`.
#[derive(Debug, Deserialize)]
pub struct Choice {
pub index: usize,
// The message the response conversion reads tool calls or content from.
pub message: Message,
// Kept as untyped JSON; no code visible here interprets it.
pub logprobs: Option<serde_json::Value>,
pub finish_reason: String,
}

/// Message carried by a [`Choice`] in a completion response.
///
/// Both `content` and `tool_calls` are optional; the response conversion
/// prefers a non-empty `tool_calls` list over `content`.
#[derive(Debug, Deserialize)]
pub struct Message {
pub role: String,
pub content: Option<String>,
pub tool_calls: Option<Vec<ToolCall>>,
}

/// A tool call found in a response [`Message`].
#[derive(Debug, Deserialize)]
pub struct ToolCall {
pub id: String,
// NOTE(review): presumably always "function", matching the request-side
// `ToolDefinition` — TODO confirm.
pub r#type: String,
// Name and JSON-encoded arguments of the function to invoke.
pub function: Function,
}

/// Request-side wrapper that pairs a `completion::ToolDefinition` with a
/// `type` tag; serialized into the `"tools"` array of the completion request.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ToolDefinition {
// Set to "function" by the `From<completion::ToolDefinition>` conversion.
pub r#type: String,
pub function: completion::ToolDefinition,
}

impl From<completion::ToolDefinition> for ToolDefinition {
fn from(tool: completion::ToolDefinition) -> Self {
Self {
r#type: "function".into(),
function: tool,
}
}
}

/// Function invocation carried by a [`ToolCall`].
#[derive(Debug, Deserialize)]
pub struct Function {
pub name: String,
// JSON-encoded text; the response conversion parses it with
// `serde_json::from_str`.
pub arguments: String,
}

#[derive(Clone)]
pub struct CompletionModel {
client: Client,
Expand All @@ -470,46 +368,48 @@ impl CompletionModel {
}

impl completion::CompletionModel for CompletionModel {
type Response = CompletionResponse;
type Response = openai::CompletionResponse;

#[cfg_attr(feature = "worker", worker::send)]
async fn completion(
&self,
mut completion_request: CompletionRequest,
) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
completion_request: CompletionRequest,
) -> Result<completion::CompletionResponse<openai::CompletionResponse>, CompletionError> {
// Add preamble to chat history (if available)
// NOTE: Azure o1-preview models do not support system messages
let mut full_history = if let Some(preamble) = &completion_request.preamble {
vec![completion::Message {
role: "system".into(),
content: preamble.clone(),
}]
} else {
vec![]
let mut full_history: Vec<openai::Message> = match &completion_request.preamble {
Some(preamble) => vec![openai::Message::system(preamble)],
None => vec![],
};

// Extend existing chat history
full_history.append(&mut completion_request.chat_history);
// Convert prompt to user message
let prompt: Vec<openai::Message> = completion_request.prompt_with_context().try_into()?;

// Add context documents to chat history
let prompt_with_context = completion_request.prompt_with_context();
// Convert existing chat history
let chat_history: Vec<openai::Message> = completion_request
.chat_history
.into_iter()
.map(|message| message.try_into())
.collect::<Result<Vec<Vec<openai::Message>>, _>>()?
.into_iter()
.flatten()
.collect();

// Add context documents to chat history
full_history.push(completion::Message {
role: "user".into(),
content: prompt_with_context,
});
// Combine all messages into a single history
full_history.extend(chat_history);
full_history.extend(prompt);

let request = if completion_request.tools.is_empty() {
json!({
"model": self.model,
"messages": full_history,
"temperature": completion_request.temperature,
})
} else {
json!({
"model": self.model,
"messages": full_history,
"temperature": completion_request.temperature,
"tools": completion_request.tools.into_iter().map(ToolDefinition::from).collect::<Vec<_>>(),
"tools": completion_request.tools.into_iter().map(openai::ToolDefinition::from).collect::<Vec<_>>(),
"tool_choice": "auto",
})
};
Expand All @@ -528,7 +428,10 @@ impl completion::CompletionModel for CompletionModel {
.await?;

if response.status().is_success() {
match response.json::<ApiResponse<CompletionResponse>>().await? {
let t = response.text().await?;
tracing::debug!(target: "rig", "Azure completion error: {}", t);

match serde_json::from_str::<ApiResponse<openai::CompletionResponse>>(&t)? {
ApiResponse::Ok(response) => {
tracing::info!(target: "rig",
"Azure completion token usage: {:?}",
Expand Down Expand Up @@ -577,7 +480,7 @@ mod azure_tests {
.completion(CompletionRequest {
preamble: Some("You are a helpful assistant.".to_string()),
chat_history: vec![],
prompt: "Hello, world!".to_string(),
prompt: "Hello, world!".into(),
documents: vec![],
max_tokens: Some(100),
temperature: Some(0.0),
Expand Down
Loading

0 comments on commit 39e1776

Please sign in to comment.