From 6191482247d71e9e06d7db5af7eebbe5ed9c55a1 Mon Sep 17 00:00:00 2001
From: Sma1lboy <541898146chen@gmail.com>
Date: Mon, 13 Jan 2025 20:34:15 -0500
Subject: [PATCH 1/6] feat(chat): add support for Azure API with versioning and refactor model handling

---
 crates/http-api-bindings/src/chat/mod.rs | 67 +++++++++++++++---------
 crates/tabby-common/src/config.rs        |  5 ++
 crates/tabby-inference/src/chat.rs       | 17 ++++++
 3 files changed, 64 insertions(+), 25 deletions(-)

diff --git a/crates/http-api-bindings/src/chat/mod.rs b/crates/http-api-bindings/src/chat/mod.rs
index bed9a15a6fa2..b0bb1d9aa651 100644
--- a/crates/http-api-bindings/src/chat/mod.rs
+++ b/crates/http-api-bindings/src/chat/mod.rs
@@ -12,31 +12,48 @@ pub async fn create(model: &HttpModelConfig) -> Arc<dyn ChatCompletionStream> {
         .api_endpoint
         .as_deref()
         .expect("api_endpoint is required");
-    let config = OpenAIConfig::default()
-        .with_api_base(api_endpoint)
-        .with_api_key(model.api_key.clone().unwrap_or_default());
-
-    let mut builder = ExtendedOpenAIConfig::builder();
-
-    builder
-        .base(config)
-        .supported_models(model.supported_models.clone())
-        .model_name(model.model_name.as_deref().expect("Model name is required"));
-
-    if model.kind == "openai/chat" {
-        // Do nothing
-    } else if model.kind == "mistral/chat" {
-        builder.fields_to_remove(ExtendedOpenAIConfig::mistral_fields_to_remove());
-    } else {
-        panic!("Unsupported model kind: {}", model.kind);
-    }
-
-    let config = builder.build().expect("Failed to build config");
-
-    let engine = Box::new(
-        async_openai_alt::Client::with_config(config)
-            .with_http_client(create_reqwest_client(api_endpoint)),
-    );
+
+    let engine: Box<dyn ChatCompletionStream> = match model.kind.as_str() {
+        "azure/chat" => {
+            let config = async_openai_alt::config::AzureConfig::new()
+                .with_api_base(api_endpoint)
+                .with_api_key(model.api_key.clone().unwrap_or_default())
+                .with_api_version(
+                    model
+                        .api_version
+                        .clone()
+                        .unwrap_or("2024-05-01-preview".to_string()),
+                )
+                .with_deployment_id(model.model_name.as_deref().expect("Model name is required"));
+            Box::new(
+                async_openai_alt::Client::with_config(config)
+                    .with_http_client(create_reqwest_client(api_endpoint)),
+            )
+        }
+        "openai/chat" | "mistral/chat" => {
+            let config = OpenAIConfig::default()
+                .with_api_base(api_endpoint)
+                .with_api_key(model.api_key.clone().unwrap_or_default());
+
+            let mut builder = ExtendedOpenAIConfig::builder();
+            builder
+                .base(config)
+                .supported_models(model.supported_models.clone())
+                .model_name(model.model_name.as_deref().expect("Model name is required"));
+
+            if model.kind == "mistral/chat" {
+                builder.fields_to_remove(ExtendedOpenAIConfig::mistral_fields_to_remove());
+            }
+
+            Box::new(
+                async_openai_alt::Client::with_config(
+                    builder.build().expect("Failed to build config"),
+                )
+                .with_http_client(create_reqwest_client(api_endpoint)),
+            )
+        }
+        _ => panic!("Unsupported model kind: {}", model.kind),
+    };
 
     Arc::new(rate_limit::new_chat(
         engine,
diff --git a/crates/tabby-common/src/config.rs b/crates/tabby-common/src/config.rs
index 019e9233ffab..95b68f1082db 100644
--- a/crates/tabby-common/src/config.rs
+++ b/crates/tabby-common/src/config.rs
@@ -310,6 +310,11 @@ pub struct HttpModelConfig {
     #[builder(default)]
     pub additional_stop_words: Option<Vec<String>>,
+
+    /// Used for the Azure API to specify the API version.
+    #[builder(default)]
+    #[serde(default)]
+    pub api_version: Option<String>,
 }
 
 #[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
diff --git a/crates/tabby-inference/src/chat.rs b/crates/tabby-inference/src/chat.rs
index ff3b2d1672d3..2dd97116b293 100644
--- a/crates/tabby-inference/src/chat.rs
+++ b/crates/tabby-inference/src/chat.rs
@@ -125,3 +125,20 @@ impl ChatCompletionStream for async_openai_alt::Client<ExtendedOpenAIConfig> {
         self.chat().create_stream(request).await
     }
 }
+
+#[async_trait]
+impl ChatCompletionStream for async_openai_alt::Client<async_openai_alt::config::AzureConfig> {
+    async fn chat(
+        &self,
+        request: CreateChatCompletionRequest,
+    ) -> Result<CreateChatCompletionResponse, OpenAIError> {
+        self.chat().create(request).await
+    }
+
+    async fn chat_stream(
+        &self,
+        request: CreateChatCompletionRequest,
+    ) -> Result<ChatCompletionResponseStream, OpenAIError> {
+        self.chat().create_stream(request).await
+    }
+}
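Usage note for this patch: the new "azure/chat" kind is selected through the `kind` field of Tabby's HTTP model config. The TOML below is a minimal sketch, assuming the `[model.chat.http]` section Tabby already uses for `openai/chat`; the resource URL, deployment name, and key are placeholders. `api_version` is optional and defaults to "2024-05-01-preview" here (a later patch in this series pins the version in code and removes the field again):

    [model.chat.http]
    kind = "azure/chat"
    model_name = "gpt-4o"                                  # placeholder Azure deployment id
    api_endpoint = "https://my-resource.openai.azure.com"  # placeholder resource URL
    api_key = "secret-api-key"                             # placeholder
    api_version = "2024-05-01-preview"                     # optional; shown at its default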

From 089dad0a10b580157e6261cc466f85c70b269a60 Mon Sep 17 00:00:00 2001
From: Sma1lboy <541898146chen@gmail.com>
Date: Tue, 14 Jan 2025 00:24:27 -0500
Subject: [PATCH 2/6] feat(embedding): add AzureEmbeddingEngine for Azure API integration

---
 .../http-api-bindings/src/embedding/azure.rs  | 87 +++++++++++++++++++
 crates/http-api-bindings/src/embedding/mod.rs | 11 +++
 2 files changed, 98 insertions(+)
 create mode 100644 crates/http-api-bindings/src/embedding/azure.rs

diff --git a/crates/http-api-bindings/src/embedding/azure.rs b/crates/http-api-bindings/src/embedding/azure.rs
new file mode 100644
index 000000000000..c4c7c3cb3cb4
--- /dev/null
+++ b/crates/http-api-bindings/src/embedding/azure.rs
@@ -0,0 +1,87 @@
+use anyhow::Result;
+use async_trait::async_trait;
+use reqwest::Client;
+use serde::{Deserialize, Serialize};
+use std::sync::Arc;
+use tabby_inference::Embedding;
+
+#[derive(Clone)]
+pub struct AzureEmbeddingEngine {
+    client: Arc<Client>,
+    api_endpoint: String,
+    api_key: String,
+    api_version: String,
+}
+
+#[derive(Debug, Serialize)]
+struct EmbeddingRequest {
+    input: String,
+}
+
+#[derive(Debug, Deserialize)]
+struct EmbeddingResponse {
+    data: Vec<Data>,
+}
+
+#[derive(Debug, Deserialize)]
+struct Data {
+    embedding: Vec<f32>,
+}
+
+impl AzureEmbeddingEngine {
+    pub fn create(
+        api_endpoint: &str,
+        model_name: &str,
+        api_key: Option<&str>,
+        api_version: Option<&str>,
+    ) -> Box<dyn Embedding> {
+        let client = Client::new();
+        let deployment_id = model_name;
+        let azure_endpoint = format!(
+            "{}/openai/deployments/{}/embeddings",
+            api_endpoint.trim_end_matches('/'),
+            deployment_id
+        );
+
+        Box::new(Self {
+            client: Arc::new(client),
+            api_endpoint: azure_endpoint,
+            api_key: api_key.unwrap_or_default().to_owned(),
+            api_version: api_version.unwrap_or("2024-02-15-preview").to_owned(),
+        })
+    }
+}
+
+#[async_trait]
+impl Embedding for AzureEmbeddingEngine {
+    async fn embed(&self, prompt: &str) -> Result<Vec<f32>> {
+        let client = self.client.clone();
+        let api_endpoint = self.api_endpoint.clone();
+        let api_key = self.api_key.clone();
+        let api_version = self.api_version.clone();
+        let request = EmbeddingRequest {
+            input: prompt.to_owned(),
+        };
+
+        let response = client
+            .post(&api_endpoint)
+            .query(&[("api-version", &api_version)])
+            .header("api-key", &api_key)
+            .header("Content-Type", "application/json")
+            .json(&request)
+            .send()
+            .await?;
+
+        if !response.status().is_success() {
+            let error_text = response.text().await?;
+            anyhow::bail!("Azure API error: {}", error_text);
+        }
+
+        let embedding_response: EmbeddingResponse = response.json().await?;
+        embedding_response
+            .data
+            .first()
+            .map(|data| data.embedding.clone())
+            .ok_or_else(|| anyhow::anyhow!("No embedding data received"))
+    }
+}
diff --git a/crates/http-api-bindings/src/embedding/mod.rs b/crates/http-api-bindings/src/embedding/mod.rs
index c68781e3eb6e..dac3f90801b5 100644
--- a/crates/http-api-bindings/src/embedding/mod.rs
+++ b/crates/http-api-bindings/src/embedding/mod.rs
@@ -1,9 +1,11 @@
+mod azure;
 mod llama;
 mod openai;
 
 use core::panic;
 use std::sync::Arc;
 
+use azure::AzureEmbeddingEngine;
 use llama::LlamaCppEngine;
 use openai::OpenAIEmbeddingEngine;
 use tabby_common::config::HttpModelConfig;
@@ -40,6 +42,15 @@ pub async fn create(config: &HttpModelConfig) -> Arc<dyn Embedding> {
                 .expect("model_name must be set for voyage/embedding"),
             config.api_key.as_deref(),
         ),
+        "azure/embedding" => AzureEmbeddingEngine::create(
+            config
+                .api_endpoint
+                .as_deref()
+                .expect("api_endpoint is required for azure/embedding"),
+            config.model_name.as_deref().unwrap_or_default(), // Provide a default if model_name is optional
+            config.api_key.as_deref(),     // Pass the API key if available
+            config.api_version.as_deref(), // Pass the API version if available
+        ),
         unsupported_kind => panic!(
             "Unsupported kind for http embedding model: {}",
             unsupported_kind
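Usage note for this patch: the embedding side is configured the same way. A minimal sketch, assuming Tabby's `[model.embedding.http]` section, with placeholder endpoint, deployment, and key; the engine itself appends `/openai/deployments/{model_name}/embeddings` to the configured endpoint, so `api_endpoint` should be the bare resource URL:

    [model.embedding.http]
    kind = "azure/embedding"
    model_name = "text-embedding-3-small"                  # placeholder Azure deployment id
    api_endpoint = "https://my-resource.openai.azure.com"  # placeholder resource URL
    api_key = "secret-api-key"                             # placeholder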

From 5383bfc782cfd38d8f3fef85abd71a97d42a82ad Mon Sep 17 00:00:00 2001
From: Sma1lboy <541898146chen@gmail.com>
Date: Tue, 14 Jan 2025 00:46:12 -0500
Subject: [PATCH 3/6] fix(azure): update default API version for Azure integration

---
 crates/http-api-bindings/src/chat/mod.rs        | 7 +------
 crates/http-api-bindings/src/embedding/azure.rs | 2 +-
 crates/http-api-bindings/src/embedding/mod.rs   | 4 ++--
 crates/tabby-common/src/config.rs               | 5 -----
 4 files changed, 4 insertions(+), 14 deletions(-)

diff --git a/crates/http-api-bindings/src/chat/mod.rs b/crates/http-api-bindings/src/chat/mod.rs
index b0bb1d9aa651..939c775b4369 100644
--- a/crates/http-api-bindings/src/chat/mod.rs
+++ b/crates/http-api-bindings/src/chat/mod.rs
@@ -18,12 +18,7 @@ pub async fn create(model: &HttpModelConfig) -> Arc<dyn ChatCompletionStream> {
             let config = async_openai_alt::config::AzureConfig::new()
                 .with_api_base(api_endpoint)
                 .with_api_key(model.api_key.clone().unwrap_or_default())
-                .with_api_version(
-                    model
-                        .api_version
-                        .clone()
-                        .unwrap_or("2024-05-01-preview".to_string()),
-                )
+                .with_api_version("2024-08-01-preview")
                 .with_deployment_id(model.model_name.as_deref().expect("Model name is required"));
             Box::new(
                 async_openai_alt::Client::with_config(config)
diff --git a/crates/http-api-bindings/src/embedding/azure.rs b/crates/http-api-bindings/src/embedding/azure.rs
index c4c7c3cb3cb4..1aae67903fee 100644
--- a/crates/http-api-bindings/src/embedding/azure.rs
+++ b/crates/http-api-bindings/src/embedding/azure.rs
@@ -47,7 +47,7 @@ impl AzureEmbeddingEngine {
             client: Arc::new(client),
             api_endpoint: azure_endpoint,
             api_key: api_key.unwrap_or_default().to_owned(),
-            api_version: api_version.unwrap_or("2024-02-15-preview").to_owned(),
+            api_version: api_version.unwrap_or("2023-05-15").to_owned(),
         })
     }
 }
diff --git a/crates/http-api-bindings/src/embedding/mod.rs b/crates/http-api-bindings/src/embedding/mod.rs
index dac3f90801b5..2ac128a9da6a 100644
--- a/crates/http-api-bindings/src/embedding/mod.rs
+++ b/crates/http-api-bindings/src/embedding/mod.rs
@@ -48,8 +48,8 @@ pub async fn create(config: &HttpModelConfig) -> Arc<dyn Embedding> {
                 .as_deref()
                 .expect("api_endpoint is required for azure/embedding"),
             config.model_name.as_deref().unwrap_or_default(), // Provide a default if model_name is optional
-            config.api_key.as_deref(),     // Pass the API key if available
-            config.api_version.as_deref(), // Pass the API version if available
+            config.api_key.as_deref(),
+            Some("2023-05-15"),
         ),
         unsupported_kind => panic!(
             "Unsupported kind for http embedding model: {}",
diff --git a/crates/tabby-common/src/config.rs b/crates/tabby-common/src/config.rs
index 95b68f1082db..019e9233ffab 100644
--- a/crates/tabby-common/src/config.rs
+++ b/crates/tabby-common/src/config.rs
@@ -310,11 +310,6 @@ pub struct HttpModelConfig {
     #[builder(default)]
     pub additional_stop_words: Option<Vec<String>>,
-
-    /// Used for the Azure API to specify the API version.
-    #[builder(default)]
-    #[serde(default)]
-    pub api_version: Option<String>,
 }
 
 #[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]

From 9e4914576f523b41e490c1088f6f843412bf0c19 Mon Sep 17 00:00:00 2001
From: "autofix-ci[bot]" <114827586+autofix-ci[bot]@users.noreply.github.com>
Date: Tue, 14 Jan 2025 05:52:40 +0000
Subject: [PATCH 4/6] [autofix.ci] apply automated fixes

---
 crates/http-api-bindings/src/embedding/azure.rs | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/crates/http-api-bindings/src/embedding/azure.rs b/crates/http-api-bindings/src/embedding/azure.rs
index 1aae67903fee..fa6a958054eb 100644
--- a/crates/http-api-bindings/src/embedding/azure.rs
+++ b/crates/http-api-bindings/src/embedding/azure.rs
@@ -1,8 +1,9 @@
+use std::sync::Arc;
+
 use anyhow::Result;
 use async_trait::async_trait;
 use reqwest::Client;
 use serde::{Deserialize, Serialize};
-use std::sync::Arc;
 use tabby_inference::Embedding;
 
 #[derive(Clone)]

From 8ed31c6a142a3b54f24c10899d0c978dcf840b0e Mon Sep 17 00:00:00 2001
From: Sma1lboy <541898146chen@gmail.com>
Date: Tue, 14 Jan 2025 00:55:23 -0500
Subject: [PATCH 5/6] feat(embedding): enhance AzureEmbeddingEngine with detailed documentation and API version support

---
 .../http-api-bindings/src/embedding/azure.rs | 37 +++++++++++++++++++
 1 file changed, 37 insertions(+)

diff --git a/crates/http-api-bindings/src/embedding/azure.rs b/crates/http-api-bindings/src/embedding/azure.rs
index fa6a958054eb..cdd816237e49 100644
--- a/crates/http-api-bindings/src/embedding/azure.rs
+++ b/crates/http-api-bindings/src/embedding/azure.rs
@@ -6,6 +6,9 @@ use reqwest::Client;
 use serde::{Deserialize, Serialize};
 use tabby_inference::Embedding;
 
+/// `AzureEmbeddingEngine` is responsible for interacting with Azure's Embedding API.
+///
+/// **Note**: Currently, this implementation only supports the OpenAI API and specific API versions.
 #[derive(Clone)]
 pub struct AzureEmbeddingEngine {
     client: Arc<Client>,
@@ -14,22 +17,39 @@ pub struct AzureEmbeddingEngine {
     api_version: String,
 }
 
+/// Structure representing the request body for embedding.
 #[derive(Debug, Serialize)]
 struct EmbeddingRequest {
     input: String,
 }
 
+/// Structure representing the response from the embedding API.
 #[derive(Debug, Deserialize)]
 struct EmbeddingResponse {
     data: Vec<Data>,
 }
 
+/// Structure representing individual embedding data.
 #[derive(Debug, Deserialize)]
 struct Data {
     embedding: Vec<f32>,
 }
 
 impl AzureEmbeddingEngine {
+    /// Creates a new instance of `AzureEmbeddingEngine`.
+    ///
+    /// **Note**: Currently, this implementation only supports the OpenAI API and specific API versions.
+    ///
+    /// # Parameters
+    ///
+    /// - `api_endpoint`: The base URL of the Azure Embedding API.
+    /// - `model_name`: The name of the deployed model, used to construct the deployment ID.
+    /// - `api_key`: Optional API key for authentication.
+    /// - `api_version`: Optional API version, defaults to "2023-05-15".
+    ///
+    /// # Returns
+    ///
+    /// A boxed instance that implements the `Embedding` trait.
     pub fn create(
         api_endpoint: &str,
         model_name: &str,
@@ -38,6 +58,7 @@ impl AzureEmbeddingEngine {
     ) -> Box<dyn Embedding> {
         let client = Client::new();
         let deployment_id = model_name;
+        // Construct the full endpoint URL for the Azure Embedding API
         let azure_endpoint = format!(
             "{}/openai/deployments/{}/embeddings",
             api_endpoint.trim_end_matches('/'),
             deployment_id
@@ -48,6 +69,7 @@ impl AzureEmbeddingEngine {
             client: Arc::new(client),
             api_endpoint: azure_endpoint,
             api_key: api_key.unwrap_or_default().to_owned(),
+            // Use a specific API version; currently, only this version is supported
             api_version: api_version.unwrap_or("2023-05-15").to_owned(),
         })
     }
@@ -55,7 +77,19 @@ impl AzureEmbeddingEngine {
 
 #[async_trait]
 impl Embedding for AzureEmbeddingEngine {
+    /// Generates an embedding vector for the given prompt.
+    ///
+    /// **Note**: Currently, this implementation only supports the OpenAI API and specific API versions.
+    ///
+    /// # Parameters
+    ///
+    /// - `prompt`: The input text to generate embeddings for.
+    ///
+    /// # Returns
+    ///
+    /// A `Result` containing the embedding vector or an error.
     async fn embed(&self, prompt: &str) -> Result<Vec<f32>> {
+        // Clone all necessary fields to ensure thread safety across await points
         let client = self.client.clone();
         let api_endpoint = self.api_endpoint.clone();
         let api_key = self.api_key.clone();
@@ -64,6 +98,7 @@ impl Embedding for AzureEmbeddingEngine {
             input: prompt.to_owned(),
         };
 
+        // Send a POST request to the Azure Embedding API
         let response = client
             .post(&api_endpoint)
             .query(&[("api-version", &api_version)])
@@ -73,11 +108,13 @@ impl Embedding for AzureEmbeddingEngine {
             .send()
             .await?;
 
+        // Check if the response status indicates success
         if !response.status().is_success() {
             let error_text = response.text().await?;
             anyhow::bail!("Azure API error: {}", error_text);
         }
 
+        // Deserialize the response body into `EmbeddingResponse`
         let embedding_response: EmbeddingResponse = response.json().await?;
         embedding_response
             .data

From 27f36153c193550856ec831ecfd307c6839f3bef Mon Sep 17 00:00:00 2001
From: Sma1lboy <541898146chen@gmail.com>
Date: Tue, 14 Jan 2025 04:03:15 -0500
Subject: [PATCH 6/6] feat(chat): add support for Anthropic chat completion in chat module

---
 .../http-api-bindings/src/chat/anthropic.rs | 516 ++++++++++++++++++
 crates/http-api-bindings/src/chat/mod.rs    |  10 +
 2 files changed, 526 insertions(+)
 create mode 100644 crates/http-api-bindings/src/chat/anthropic.rs

diff --git a/crates/http-api-bindings/src/chat/anthropic.rs b/crates/http-api-bindings/src/chat/anthropic.rs
new file mode 100644
index 000000000000..1809901982ab
--- /dev/null
+++ b/crates/http-api-bindings/src/chat/anthropic.rs
@@ -0,0 +1,516 @@
+use std::sync::Arc;
+
+use crate::create_reqwest_client;
+use async_openai_alt::error::{ApiError, OpenAIError};
+use async_openai_alt::types::{
+    ChatChoiceStream, ChatCompletionRequestAssistantMessageContent, ChatCompletionRequestMessage,
+    ChatCompletionRequestUserMessageContent, ChatCompletionResponseStream,
+    ChatCompletionStreamResponseDelta, CreateChatCompletionRequest, CreateChatCompletionResponse,
+    CreateChatCompletionStreamResponse,
+};
+use async_trait::async_trait;
+use futures::TryStreamExt;
+use reqwest::Client;
+use serde::de::Error;
+use serde::{Deserialize, Serialize};
+use tabby_inference::ChatCompletionStream;
+use tokio::sync::mpsc;
+use tokio_stream::wrappers::ReceiverStream;
+use tokio_stream::StreamExt;
+
+use async_openai_alt::types::{
+    ChatChoice as OpenAIChatChoice,
+    ChatCompletionResponseMessage as OpenAIChatCompletionResponseMessage,
+    CreateChatCompletionResponse as OpenAICreateChatCompletionResponse,
+    FinishReason as OpenAIFinishReason, Role as OpenAIRole,
+};
+
+const HUMAN_PROMPT: &str = "<|Human|>";
+const AI_PROMPT: &str = "<|Assistant|>";
+
+#[derive(Debug, Serialize)]
+struct AnthropicCreateChatCompletionRequest {
+    prompt: String,
+    model: String,
+    max_tokens: Option<u32>,
+    temperature: Option<f32>,
+    stop_sequences: Option<Vec<String>>,
+    stream: bool,
+}
+
+#[derive(Debug, Deserialize)]
+struct AnthropicCreateChatCompletionStreamResponse {
+    completion: Option<String>,
+}
+
+#[derive(Debug, Deserialize)]
+// Anthropic stream events carry their fields inline, tagged only by "type"
+#[serde(tag = "type")]
+enum AnthropicEvent {
+    #[serde(rename = "message_start")]
+    MessageStart { message: Message },
+
+    #[serde(rename = "content_block_start")]
+    ContentBlockStart {
+        index: u32,
+        content_block: ContentBlock,
+    },
+
+    #[serde(rename = "ping")]
+    Ping,
+
+    #[serde(rename = "content_block_delta")]
+    ContentBlockDelta {
+        index: u32,
+        delta: ContentBlockDelta,
+    },
+
+    #[serde(rename = "content_block_stop")]
+    ContentBlockStop { index: u32 },
+
+    #[serde(rename = "message_delta")]
+    MessageDelta {
+        delta: MessageDelta,
+        usage: Option<Usage>,
+    },
+
+    #[serde(rename = "message_stop")]
+    MessageStop,
+}
+
+#[derive(Debug, Deserialize)]
+struct Message {
+    id: String,
+    #[serde(rename = "type")]
+    type_: String,
+    role: String,
+    model: String,
+    content: Vec<ContentBlock>,
+    stop_reason: Option<String>,
+    stop_sequence: Option<String>,
+    usage: Option<Usage>,
+}
+
+#[derive(Debug, Deserialize)]
+struct ContentBlock {
+    #[serde(rename = "type")]
+    type_: String,
+    text: String,
+}
+
+#[derive(Debug, Deserialize)]
+struct ContentBlockDelta {
+    #[serde(rename = "type")]
+    type_: String,
+    text: String,
+}
+
+#[derive(Debug, Deserialize)]
+struct MessageDelta {
+    stop_reason: Option<String>,
+    stop_sequence: Option<String>,
+}
+
+#[derive(Debug, Deserialize)]
+struct ContentItem {
+    #[serde(rename = "type")]
+    type_: String,
+    text: String,
+}
+
+#[derive(Debug, Deserialize, Default)]
+// Not every event reports all token counters, so default missing ones to zero
+#[serde(default)]
+struct Usage {
+    input_tokens: u32,
+    cache_creation_input_tokens: u32,
+    cache_read_input_tokens: u32,
+    output_tokens: u32,
+}
+
+#[derive(Debug, Deserialize)]
+struct AnthropicCreateChatCompletionResponse {
+    id: String,
+    #[serde(rename = "type")]
+    type_: String,
+    role: String,
+    model: String,
+    content: Vec<ContentItem>,
+    stop_reason: Option<String>,
+    stop_sequence: Option<String>,
+    usage: Option<Usage>,
+}
+
+#[derive(Debug, Deserialize, Clone, PartialEq, Serialize)]
+pub struct CompletionUsage {
+    pub prompt_tokens: u32,
+    pub completion_tokens: u32,
+    pub total_tokens: u32,
+}
+
+#[derive(Debug, Deserialize, Clone, PartialEq, Serialize)]
+pub struct ServiceTierResponse {
+    pub name: String,
+    pub resource_group: String,
+    pub location: String,
+}
+
+#[derive(Debug, Deserialize, Clone, PartialEq, Serialize)]
+pub struct CreateChatCompletionStreamChoice {
+    pub delta: Option<CreateChatCompletionDelta>,
+    pub index: u32,
+    pub finish_reason: Option<String>,
+}
+
+#[derive(Debug, Deserialize, Clone, PartialEq, Serialize, Default)]
+pub struct CreateChatCompletionDelta {
+    pub content: Option<String>,
+}
+
+#[derive(Clone)]
+pub struct AnthropicChatCompletion {
+    client: Arc<Client>,
+    api_endpoint: String,
+    api_key: String,
+    default_model: String,
+}
+
+impl AnthropicChatCompletion {
+    pub fn new(api_endpoint: &str, api_key: &str, default_model: &str) -> Self {
+        Self {
+            client: Arc::new(create_reqwest_client(api_endpoint)),
+            api_endpoint: api_endpoint.trim_end_matches('/').to_string(),
+            api_key: api_key.to_string(),
+            default_model: default_model.to_string(),
+        }
+    }
+
+    fn completion_url(&self) -> String {
+        format!("{}/complete", self.api_endpoint)
+    }
+
+    fn transform_request(
+        &self,
+        request: &CreateChatCompletionRequest,
+    ) -> Result<AnthropicCreateChatCompletionRequest, OpenAIError> {
+        let prompt = request
+            .messages
+            .iter()
+            .filter_map(|msg| match msg {
+                ChatCompletionRequestMessage::User(user_msg) => match &user_msg.content {
+                    ChatCompletionRequestUserMessageContent::Text(text) => Some(text.clone()),
+                    _ => None,
+                },
+                ChatCompletionRequestMessage::Assistant(assistant_msg) => {
+                    match &assistant_msg.content {
+                        Some(ChatCompletionRequestAssistantMessageContent::Text(text)) => {
+                            Some(text.clone())
+                        }
+                        _ => None,
+                    }
+                }
+                _ => None,
+            })
+            .collect::<Vec<String>>()
+            .join("\n");
+
+        Ok(AnthropicCreateChatCompletionRequest {
+            prompt: format!("{}{}\n{}", HUMAN_PROMPT, prompt, AI_PROMPT),
+            model: request.model.clone(),
+            max_tokens: request.max_tokens,
+            temperature: request.temperature,
+            stop_sequences: Some(vec![HUMAN_PROMPT.to_string()]),
+            stream: false,
+        })
+    }
+
+    fn transform_response(
+        &self,
+        anthropic_response: AnthropicCreateChatCompletionResponse,
+    ) -> Result<OpenAICreateChatCompletionResponse, OpenAIError> {
+        let completion_text = anthropic_response
+            .content
+            .iter()
+            .filter_map(|item| {
+                if item.type_ == "text" {
+                    Some(item.text.clone())
+                } else {
+                    None
+                }
+            })
+            .collect::<Vec<String>>()
+            .join("\n");
+
+        if completion_text.is_empty() {
+            return Err(OpenAIError::ApiError(ApiError {
+                message: "No completion found".to_string(),
+                r#type: None,
+                param: None,
+                code: None,
+            }));
+        }
+
+        Ok(OpenAICreateChatCompletionResponse {
+            id: anthropic_response.id.clone(),
+            choices: vec![OpenAIChatChoice {
+                message: OpenAIChatCompletionResponseMessage {
+                    role: OpenAIRole::Assistant,
+                    content: Some(completion_text),
+                    refusal: None,
+                    tool_calls: None,
+                    function_call: None,
+                },
+                finish_reason: match anthropic_response.stop_reason.as_deref() {
+                    Some("end_turn") => Some(OpenAIFinishReason::Stop),
+                    Some("length") => Some(OpenAIFinishReason::Length),
+                    Some("content_filter") => Some(OpenAIFinishReason::ContentFilter),
+                    _ => None,
+                },
+                index: 0,
+                logprobs: None,
+            }],
+            created: 0,
+            model: anthropic_response.model.clone(),
+            service_tier: None,
+            system_fingerprint: None,
+            object: "chat.completion".to_string(),
+            usage: None,
+        })
+    }
+
+    fn transform_stream_chunk(
+        &self,
+        anthropic_chunk: AnthropicCreateChatCompletionStreamResponse,
+    ) -> Option<CreateChatCompletionStreamChoice> {
+        anthropic_chunk
+            .completion
+            .map(|part| CreateChatCompletionStreamChoice {
+                delta: Some(CreateChatCompletionDelta {
+                    content: Some(part),
+                    ..Default::default()
+                }),
+                index: 0,
+                finish_reason: None,
+            })
+    }
+
+    fn transform_stream_chunk_static(
+        anthropic_chunk: &AnthropicCreateChatCompletionStreamResponse,
+    ) -> Option<CreateChatCompletionStreamChoice> {
+        anthropic_chunk
+            .completion
+            .as_ref()
+            .map(|part| CreateChatCompletionStreamChoice {
+                delta: Some(CreateChatCompletionDelta {
+                    content: Some(part.clone()),
+                    ..Default::default()
+                }),
+                index: 0,
+                finish_reason: None,
+            })
+    }
+}
+
+#[derive(Debug, Deserialize)]
+struct WrappedError {
+    pub error: ApiError,
+}
+
+#[async_trait]
+impl ChatCompletionStream for AnthropicChatCompletion {
+    async fn chat(
+        &self,
+        request: CreateChatCompletionRequest,
+    ) -> Result<CreateChatCompletionResponse, OpenAIError> {
+        let anthropic_request = self.transform_request(&request)?;
+
+        let response = self
+            .client
+            .post(&self.completion_url())
+            .header("Content-Type", "application/json")
+            .header("X-API-Key", &self.api_key)
+            .json(&anthropic_request)
+            .send()
+            .await
+            .map_err(OpenAIError::Reqwest)?;
+
+        if !response.status().is_success() {
+            let error_text = response.text().await.unwrap_or_default();
+            let api_error = serde_json::from_str::<WrappedError>(&error_text).map_err(|_| {
+                OpenAIError::ApiError(ApiError {
+                    message: error_text.clone(),
+                    r#type: None,
+                    param: None,
+                    code: None,
+                })
+            })?;
+            return Err(OpenAIError::ApiError(api_error.error));
+        }
+
+        let anthropic_response: AnthropicCreateChatCompletionResponse = response
+            .json()
+            .await
+            .map_err(|e| OpenAIError::JSONDeserialize(serde_json::Error::custom(e)))?;
+
+        self.transform_response(anthropic_response)
+    }
+
+    async fn chat_stream(
+        &self,
+        request: CreateChatCompletionRequest,
+    ) -> Result<ChatCompletionResponseStream, OpenAIError> {
+        let anthropic_request = self.transform_request(&request)?;
+
+        let anthropic_request_stream = AnthropicCreateChatCompletionRequest {
+            // transform_request already wrapped the prompt with HUMAN_PROMPT/AI_PROMPT;
+            // reuse it as-is instead of wrapping a second time.
+            prompt: anthropic_request.prompt,
+            model: anthropic_request.model,
+            max_tokens: anthropic_request.max_tokens,
+            temperature: anthropic_request.temperature,
+            stop_sequences: anthropic_request.stop_sequences,
+            stream: true,
+        };
+
+        let response = self
+            .client
+            .post(&self.completion_url())
+            .header("Content-Type", "application/json")
+            .header("x-api-key", &self.api_key)
+            .json(&anthropic_request_stream)
+            .send()
+            .await
+            .map_err(OpenAIError::Reqwest)?;
+
+        if !response.status().is_success() {
+            let error_text = response.text().await.unwrap_or_default();
+            let api_error = serde_json::from_str::<WrappedError>(&error_text).map_err(|_| {
+                OpenAIError::ApiError(ApiError {
+                    message: error_text.clone(),
+                    r#type: None,
+                    param: None,
+                    code: None,
+                })
+            })?;
+            return Err(OpenAIError::ApiError(api_error.error));
+        }
+
+        let (tx, rx) = mpsc::channel(100);
+
+        let mut stream = response
+            .bytes_stream()
+            .map_err(|e| OpenAIError::StreamError(e.to_string()));
+
+        let tx_clone = tx.clone();
+        tokio::spawn(async move {
+            let mut response_id: Option<String> = None;
+            let mut created_timestamp: Option<u32> = None;
+            let mut model_name: Option<String> = None;
+
+            while let Some(chunk_result) = stream.next().await {
+                match chunk_result {
+                    Ok(chunk) => {
+                        let chunk_str = match String::from_utf8(chunk.to_vec()) {
+                            Ok(s) => s,
+                            Err(_) => continue,
+                        };
+
+                        for line in chunk_str.lines() {
+                            if line.starts_with("event: ") {
+                            } else if line.starts_with("data: ") {
+                                let data_str = line["data: ".len()..].trim();
+                                if data_str.is_empty() {
+                                    continue;
+                                }
+
+                                match serde_json::from_str::<AnthropicEvent>(data_str) {
+                                    Ok(event) => {
+                                        match event {
+                                            AnthropicEvent::MessageStart { message } => {
+                                                response_id = Some(message.id.clone());
+                                                created_timestamp = Some(0);
+                                                model_name = Some(message.model.clone());
+                                            }
+                                            AnthropicEvent::ContentBlockDelta { delta, .. } => {
+                                                let text = delta.text.clone();
+                                                let chat_choice = ChatChoiceStream {
+                                                    delta: ChatCompletionStreamResponseDelta {
+                                                        content: Some(text),
+                                                        role: Some(async_openai_alt::types::Role::Assistant),
+                                                        refusal: None,
+                                                        tool_calls: None,
+                                                        function_call: None,
+                                                    },
+                                                    index: 0,
+                                                    finish_reason: None,
+                                                    logprobs: None,
+                                                };
+
+                                                let response = CreateChatCompletionStreamResponse {
+                                                    id: response_id.clone().unwrap_or_default(),
+                                                    choices: vec![chat_choice],
+                                                    created: created_timestamp.unwrap_or(0),
+                                                    model: model_name.clone().unwrap_or_default(),
+                                                    service_tier: None,
+                                                    system_fingerprint: None,
+                                                    object: "chat.completion.chunk".to_string(),
+                                                    usage: None,
+                                                };
+
+                                                if tx_clone.send(Ok(response)).await.is_err() {
+                                                    break;
+                                                }
+                                            }
+                                            AnthropicEvent::MessageDelta { delta, usage } => {
+                                                let chat_choice = ChatChoiceStream {
+                                                    delta: ChatCompletionStreamResponseDelta {
+                                                        content: None,
+                                                        role: Some(async_openai_alt::types::Role::Assistant),
+                                                        refusal: None,
+                                                        tool_calls: None,
+                                                        function_call: None,
+                                                    },
+                                                    index: 0,
+                                                    finish_reason: Some(async_openai_alt::types::FinishReason::Stop),
+                                                    logprobs: None,
+                                                };
+
+                                                let response = CreateChatCompletionStreamResponse {
+                                                    id: response_id.clone().unwrap_or_default(),
+                                                    choices: vec![chat_choice],
+                                                    created: created_timestamp.unwrap_or(0),
+                                                    model: model_name.clone().unwrap_or_default(),
+                                                    service_tier: None,
+                                                    system_fingerprint: None,
+                                                    object: "chat.completion.chunk".to_string(),
+                                                    usage: None,
+                                                };
+
+                                                if tx_clone.send(Ok(response)).await.is_err() {
+                                                    break;
+                                                }
+                                            }
+                                            _ => {}
+                                        }
+                                    }
+                                    Err(e) => {
+                                        let _ = tx_clone
+                                            .send(Err(OpenAIError::JSONDeserialize(e)))
+                                            .await;
+                                        break;
+                                    }
+                                }
+                            }
+                        }
+                    }
+                    Err(e) => {
+                        let _ = tx_clone
+                            .send(Err(OpenAIError::StreamError(e.to_string())))
+                            .await;
+                        break;
+                    }
+                }
+            }
+        });
+
+        Ok(Box::pin(ReceiverStream::new(rx)))
+    }
+}
diff --git a/crates/http-api-bindings/src/chat/mod.rs b/crates/http-api-bindings/src/chat/mod.rs
index 939c775b4369..08b442ec7649 100644
--- a/crates/http-api-bindings/src/chat/mod.rs
+++ b/crates/http-api-bindings/src/chat/mod.rs
@@ -1,5 +1,7 @@
+mod anthropic;
 use std::sync::Arc;
 
+use anthropic::AnthropicChatCompletion;
 use async_openai_alt::config::OpenAIConfig;
 use tabby_common::config::HttpModelConfig;
 use tabby_inference::{ChatCompletionStream, ExtendedOpenAIConfig};
@@ -25,6 +27,14 @@ pub async fn create(model: &HttpModelConfig) -> Arc<dyn ChatCompletionStream> {
                     .with_http_client(create_reqwest_client(api_endpoint)),
             )
         }
+        "anthropic/chat" => {
+            let anthropic = AnthropicChatCompletion::new(
+                api_endpoint,
+                &model.api_key.clone().unwrap_or_default(),
+                model.model_name.as_deref().expect("Model name is required"),
+            );
+            Box::new(anthropic)
+        }
         "openai/chat" | "mistral/chat" => {
             let config = OpenAIConfig::default()
                 .with_api_base(api_endpoint)
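Usage note for the series: the new "anthropic/chat" kind follows the same config shape as the other chat backends. A minimal sketch, assuming the `[model.chat.http]` section with placeholder values; note the engine posts to `{api_endpoint}/complete`, so the endpoint should be the API root without the `/complete` suffix:

    [model.chat.http]
    kind = "anthropic/chat"
    model_name = "claude-3-5-sonnet-latest"        # placeholder model name
    api_endpoint = "https://api.anthropic.com/v1"  # placeholder; engine appends /complete
    api_key = "secret-api-key"                     # placeholder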