Commit 53f4206

feat(models-http-api): add rate limit support for completion and chat (#3468)

* feat(models-http-api): add rate limit support for completion

* feat(models-http-api): add rate limit support for chat

* chore: use one mod for rate limit

* chore: only expose impl for rate limited
zwpaper authored Nov 26, 2024
1 parent 8e3e449 commit 53f4206
Showing 10 changed files with 204 additions and 103 deletions.
crates/http-api-bindings/src/chat/mod.rs (9 additions, 2 deletions)

@@ -4,6 +4,7 @@ use async_openai::config::OpenAIConfig;
 use tabby_common::config::HttpModelConfig;
 use tabby_inference::{ChatCompletionStream, ExtendedOpenAIConfig};
 
+use super::rate_limit;
 use crate::create_reqwest_client;
 
 pub async fn create(model: &HttpModelConfig) -> Arc<dyn ChatCompletionStream> {
@@ -16,6 +17,7 @@ pub async fn create(model: &HttpModelConfig) -> Arc<dyn ChatCompletionStream> {
         .with_api_key(model.api_key.clone().unwrap_or_default());
 
     let mut builder = ExtendedOpenAIConfig::builder();
+
     builder
         .base(config)
         .supported_models(model.supported_models.clone())
@@ -31,8 +33,13 @@ pub async fn create(model: &HttpModelConfig) -> Arc<dyn ChatCompletionStream> {
 
     let config = builder.build().expect("Failed to build config");
 
-    Arc::new(
+    let engine = Box::new(
         async_openai::Client::with_config(config)
             .with_http_client(create_reqwest_client(api_endpoint)),
-    )
+    );
+
+    Arc::new(rate_limit::new_chat(
+        engine,
+        model.rate_limit.request_per_minute,
+    ))
 }
crates/http-api-bindings/src/completion/llama.rs (3 additions, 3 deletions)

@@ -14,14 +14,14 @@ pub struct LlamaCppEngine {
 }
 
 impl LlamaCppEngine {
-    pub fn create(api_endpoint: &str, api_key: Option<String>) -> Self {
+    pub fn create(api_endpoint: &str, api_key: Option<String>) -> Box<dyn CompletionStream> {
         let client = create_reqwest_client(api_endpoint);
 
-        Self {
+        Box::new(Self {
             client,
             api_endpoint: format!("{}/completion", api_endpoint),
             api_key,
-        }
+        })
     }
 }
crates/http-api-bindings/src/completion/mistral.rs (3 additions, 3 deletions)

@@ -21,20 +21,20 @@ impl MistralFIMEngine {
         api_endpoint: Option<&str>,
         api_key: Option<String>,
         model_name: Option<String>,
-    ) -> Self {
+    ) -> Box<dyn CompletionStream> {
         let client = reqwest::Client::new();
         let model_name = model_name.unwrap_or("codestral-latest".into());
         let api_key = api_key.expect("API key is required for mistral/completion");
 
-        Self {
+        Box::new(Self {
             client,
             model_name,
             api_endpoint: format!(
                 "{}/v1/fim/completions",
                 api_endpoint.unwrap_or(DEFAULT_API_ENDPOINT)
             ),
             api_key,
-        }
+        })
     }
 }
crates/http-api-bindings/src/completion/mod.rs (39 additions, 44 deletions)

@@ -10,56 +10,51 @@ use openai::OpenAICompletionEngine;
 use tabby_common::config::HttpModelConfig;
 use tabby_inference::CompletionStream;
 
+use super::rate_limit;
+
 pub async fn create(model: &HttpModelConfig) -> Arc<dyn CompletionStream> {
-    match model.kind.as_str() {
-        "llama.cpp/completion" => {
-            let engine = LlamaCppEngine::create(
-                model
-                    .api_endpoint
-                    .as_deref()
-                    .expect("api_endpoint is required"),
-                model.api_key.clone(),
-            );
-            Arc::new(engine)
-        }
+    let engine = match model.kind.as_str() {
+        "llama.cpp/completion" => LlamaCppEngine::create(
+            model
+                .api_endpoint
+                .as_deref()
+                .expect("api_endpoint is required"),
+            model.api_key.clone(),
+        ),
         "ollama/completion" => ollama_api_bindings::create_completion(model).await,
-        "mistral/completion" => {
-            let engine = MistralFIMEngine::create(
-                model.api_endpoint.as_deref(),
-                model.api_key.clone(),
-                model.model_name.clone(),
-            );
-            Arc::new(engine)
-        }
-        x if OPENAI_LEGACY_COMPLETION_FIM_ALIASES.contains(&x) => {
-            let engine = OpenAICompletionEngine::create(
-                model.model_name.clone(),
-                model
-                    .api_endpoint
-                    .as_deref()
-                    .expect("api_endpoint is required"),
-                model.api_key.clone(),
-                true,
-            );
-            Arc::new(engine)
-        }
-        "openai/legacy_completion_no_fim" | "vllm/completion" => {
-            let engine = OpenAICompletionEngine::create(
-                model.model_name.clone(),
-                model
-                    .api_endpoint
-                    .as_deref()
-                    .expect("api_endpoint is required"),
-                model.api_key.clone(),
-                false,
-            );
-            Arc::new(engine)
-        }
+        "mistral/completion" => MistralFIMEngine::create(
+            model.api_endpoint.as_deref(),
+            model.api_key.clone(),
+            model.model_name.clone(),
+        ),
+        x if OPENAI_LEGACY_COMPLETION_FIM_ALIASES.contains(&x) => OpenAICompletionEngine::create(
+            model.model_name.clone(),
+            model
+                .api_endpoint
+                .as_deref()
+                .expect("api_endpoint is required"),
+            model.api_key.clone(),
+            true,
+        ),
+        "openai/legacy_completion_no_fim" | "vllm/completion" => OpenAICompletionEngine::create(
+            model.model_name.clone(),
+            model
+                .api_endpoint
+                .as_deref()
+                .expect("api_endpoint is required"),
+            model.api_key.clone(),
+            false,
+        ),
         unsupported_kind => panic!(
             "Unsupported model kind for http completion: {}",
             unsupported_kind
         ),
-    }
+    };
+
+    Arc::new(rate_limit::new_completion(
+        engine,
+        model.rate_limit.request_per_minute,
+    ))
 }
 
 const FIM_TOKEN: &str = "<|FIM|>";
crates/http-api-bindings/src/completion/openai.rs (3 additions, 3 deletions)

@@ -25,17 +25,17 @@ impl OpenAICompletionEngine {
         api_endpoint: &str,
         api_key: Option<String>,
         support_fim: bool,
-    ) -> Self {
+    ) -> Box<dyn CompletionStream> {
        let model_name = model_name.expect("model_name is required for openai/completion");
        let client = reqwest::Client::new();
 
-        Self {
+        Box::new(Self {
            client,
            model_name,
            api_endpoint: format!("{}/completions", api_endpoint),
            api_key,
            support_fim,
-        }
+        })
    }
 }
crates/http-api-bindings/src/embedding/mod.rs (5 additions, 11 deletions)

@@ -1,17 +1,16 @@
 mod llama;
 mod openai;
-mod rate_limit;
 mod voyage;
 
 use core::panic;
-use std::{sync::Arc, time::Duration};
+use std::sync::Arc;
 
 use llama::LlamaCppEngine;
-use ratelimit::Ratelimiter;
 use tabby_common::config::HttpModelConfig;
 use tabby_inference::Embedding;
 
 use self::{openai::OpenAIEmbeddingEngine, voyage::VoyageEmbeddingEngine};
+use super::rate_limit;
 
 pub async fn create(config: &HttpModelConfig) -> Arc<dyn Embedding> {
     let engine = match config.kind.as_str() {
@@ -48,13 +47,8 @@ pub async fn create(config: &HttpModelConfig) -> Arc<dyn Embedding> {
         ),
     };
 
-    let ratelimiter = Ratelimiter::builder(
+    Arc::new(rate_limit::new_embedding(
+        engine,
         config.rate_limit.request_per_minute,
-        Duration::from_secs(60),
-    )
-    .max_tokens(config.rate_limit.request_per_minute)
-    .build()
-    .expect("Failed to create ratelimiter, please check the rate limit configuration");
-
-    Arc::new(rate_limit::RateLimitedEmbedding::new(engine, ratelimiter))
+    ))
 }
crates/http-api-bindings/src/embedding/rate_limit.rs (0 additions, 33 deletions)

This file was deleted.
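
The logic that lived in this file moves into the new shared crates/http-api-bindings/src/rate_limit.rs (declared in lib.rs below), which is not among the files shown here. The following is only a rough sketch of what its embedding constructor plausibly looks like: the Ratelimiter setup is copied from the deleted call site above and the new_embedding name from embedding/mod.rs, while the Embedding trait stand-in and the wrapper internals are assumptions, not the actual file.

    // Sketch only: the real rate_limit module added by this commit is not shown above.
    // The Ratelimiter configuration mirrors the deleted code; the trait stand-in and
    // wrapper internals are assumptions.
    use std::time::Duration;

    use anyhow::Result;
    use async_trait::async_trait;
    use ratelimit::Ratelimiter;

    // Stand-in for tabby_inference::Embedding (assumed signature).
    #[async_trait]
    pub trait Embedding: Send + Sync {
        async fn embed(&self, prompt: &str) -> Result<Vec<f32>>;
    }

    // Kept private so callers only ever see `impl Embedding`
    // ("only expose impl for rate limited").
    struct RateLimitedEmbedding {
        embedding: Box<dyn Embedding>,
        rate_limit: Ratelimiter,
    }

    pub fn new_embedding(embedding: Box<dyn Embedding>, request_per_minute: u64) -> impl Embedding {
        // Same budget the deleted per-embedding limiter used: N requests per 60 seconds.
        let rate_limit = Ratelimiter::builder(request_per_minute, Duration::from_secs(60))
            .max_tokens(request_per_minute)
            .build()
            .expect("Failed to create ratelimiter, please check the rate limit configuration");

        RateLimitedEmbedding {
            embedding,
            rate_limit,
        }
    }

    #[async_trait]
    impl Embedding for RateLimitedEmbedding {
        async fn embed(&self, prompt: &str) -> Result<Vec<f32>> {
            // try_wait() returns Err(duration) when the bucket is empty; sleep that
            // long and retry before delegating to the wrapped engine.
            while let Err(sleep_for) = self.rate_limit.try_wait() {
                tokio::time::sleep(sleep_for).await;
            }
            self.embedding.embed(prompt).await
        }
    }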

crates/http-api-bindings/src/lib.rs (1 addition, 0 deletions)

@@ -1,6 +1,7 @@
 mod chat;
 mod completion;
 mod embedding;
+mod rate_limit;
 
 pub use chat::create as create_chat;
 pub use completion::{build_completion_prompt, create};
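
For callers of the crate nothing changes at the call site: the exported constructors still return the same trait objects, only now every engine is wrapped by the private rate-limit adapter before being handed out. A minimal usage sketch, assuming the crate is referenced as http_api_bindings and using the create_chat re-export visible above:

    // Minimal usage sketch; create_chat is the re-export from lib.rs above, and
    // HttpModelConfig (including rate_limit.request_per_minute) comes from tabby_common.
    use std::sync::Arc;

    use http_api_bindings::create_chat;
    use tabby_common::config::HttpModelConfig;
    use tabby_inference::ChatCompletionStream;

    async fn build_chat(model: &HttpModelConfig) -> Arc<dyn ChatCompletionStream> {
        // The returned engine transparently enforces model.rate_limit.request_per_minute
        // on every chat request; call sites stay unchanged.
        create_chat(model).await
    }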
The remaining two of the ten changed files are not expanded in this view.

0 comments on commit 53f4206