feat(models-http-api): add rate limit support for completion and chat #3468

Merged: 4 commits, Nov 26, 2024
Changes from 2 commits
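
The change puts a token-bucket style limiter from the `ratelimit` crate in front of every HTTP chat and completion engine: each engine is boxed, wrapped in a RateLimited* adapter, and every request must first take a token, retrying a bounded number of times before giving up. The sketch below is not part of the diff; it only illustrates the builder call and the try_wait/sleep retry loop the wrappers rely on, with `request_per_minute` as a hypothetical stand-in for the value read from HttpModelConfig.

use std::time::Duration;

use ratelimit::Ratelimiter;

#[tokio::main]
async fn main() {
    // Hypothetical value; in Tabby it comes from model.rate_limit.request_per_minute.
    let request_per_minute = 120;

    // Refill `request_per_minute` tokens every 60 seconds, capped at the same amount.
    let limiter = Ratelimiter::builder(request_per_minute, Duration::from_secs(60))
        .max_tokens(request_per_minute)
        .build()
        .expect("Failed to create ratelimiter, please check the rate limit configuration");

    for i in 0..3 {
        // Same bounded retry pattern as the wrappers: sleep for the suggested
        // duration and try again, at most five times per request.
        for _ in 0..5 {
            if let Err(sleep) = limiter.try_wait() {
                tokio::time::sleep(sleep).await;
                continue;
            }
            println!("request {i} allowed");
            break;
        }
    }
}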
17 changes: 14 additions & 3 deletions crates/http-api-bindings/src/chat/mod.rs
@@ -1,6 +1,9 @@
-use std::sync::Arc;
+mod rate_limit;
+
+use std::{sync::Arc, time::Duration};

 use async_openai::config::OpenAIConfig;
+use ratelimit::Ratelimiter;
 use tabby_common::config::HttpModelConfig;
 use tabby_inference::{ChatCompletionStream, ExtendedOpenAIConfig};
@@ -31,8 +34,16 @@

     let config = builder.build().expect("Failed to build config");

-    Arc::new(
+    let engine = Box::new(
         async_openai::Client::with_config(config)
             .with_http_client(create_reqwest_client(api_endpoint)),
-    )
+    );
+
+    let ratelimiter =
+        Ratelimiter::builder(model.rate_limit.request_per_minute, Duration::from_secs(60))
+            .max_tokens(model.rate_limit.request_per_minute)
+            .build()
+            .expect("Failed to create ratelimiter, please check the rate limit configuration");
+
+    Arc::new(rate_limit::RateLimitedChatStream::new(engine, ratelimiter))
 }
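
The wrapper reads its budget from `model.rate_limit.request_per_minute`. The `rate_limit` field of `HttpModelConfig` is defined in `tabby_common::config` and is not shown in this diff; a plausible shape, sketched here purely for orientation, would be:

// Hypothetical sketch of the configuration shape the code above assumes;
// the real definition lives in tabby_common::config and may differ.
#[derive(Debug, Clone)]
pub struct RateLimit {
    pub request_per_minute: u64,
}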
64 changes: 64 additions & 0 deletions crates/http-api-bindings/src/chat/rate_limit.rs
@@ -0,0 +1,64 @@
use async_openai::{
    error::{ApiError, OpenAIError},
    types::{
        ChatCompletionResponseStream, CreateChatCompletionRequest, CreateChatCompletionResponse,
    },
};
use async_trait::async_trait;
use ratelimit::Ratelimiter;
use tabby_inference::ChatCompletionStream;

pub struct RateLimitedChatStream {
    completion: Box<dyn ChatCompletionStream>,
    rate_limiter: Ratelimiter,
}

impl RateLimitedChatStream {
    pub fn new(completion: Box<dyn ChatCompletionStream>, rate_limiter: Ratelimiter) -> Self {
        Self {
            completion,
            rate_limiter,
        }
    }
}

#[async_trait]
impl ChatCompletionStream for RateLimitedChatStream {
    async fn chat(
        &self,
        request: CreateChatCompletionRequest,
    ) -> Result<CreateChatCompletionResponse, OpenAIError> {
        for _ in 0..5 {
            if let Err(sleep) = self.rate_limiter.try_wait() {
                tokio::time::sleep(sleep).await;
                continue;
            }

            return self.completion.chat(request).await;
        }

        Err(OpenAIError::ApiError(ApiError {
            message: "Rate limit exceeded for chat completion".to_owned(),
            r#type: None,
            param: None,
            code: None,
        }))
    }

    async fn chat_stream(
        &self,
        request: CreateChatCompletionRequest,
    ) -> Result<ChatCompletionResponseStream, OpenAIError> {
        for _ in 0..5 {
            if let Err(sleep) = self.rate_limiter.try_wait() {
                tokio::time::sleep(sleep).await;
                continue;
            }

            return self.completion.chat_stream(request).await;
        }

        // Return an empty stream if the rate limit is exceeded
        Ok(Box::pin(futures::stream::empty()))
    }
}
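
Note the asymmetry in the fallback paths: `chat` surfaces an `OpenAIError::ApiError` after five failed token acquisitions, while `chat_stream` (and the completion wrapper later in this diff) falls back to an empty stream, so callers simply see a stream that ends immediately. A minimal, standalone illustration of that fallback, not taken from the diff:

use futures::StreamExt;

#[tokio::main]
async fn main() {
    // What a consumer observes when the wrapper returns futures::stream::empty():
    // the stream completes immediately with no items and no error.
    let mut fallback = futures::stream::empty::<String>();
    assert!(fallback.next().await.is_none());
    println!("empty fallback stream yielded nothing");
}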
6 changes: 3 additions & 3 deletions crates/http-api-bindings/src/completion/llama.rs
@@ -14,14 +14,14 @@
 }

 impl LlamaCppEngine {
-    pub fn create(api_endpoint: &str, api_key: Option<String>) -> Self {
+    pub fn create(api_endpoint: &str, api_key: Option<String>) -> Box<dyn CompletionStream> {
         let client = create_reqwest_client(api_endpoint);

-        Self {
+        Box::new(Self {
             client,
             api_endpoint: format!("{}/completion", api_endpoint),
             api_key,
-        }
+        })
     }
 }
6 changes: 3 additions & 3 deletions crates/http-api-bindings/src/completion/mistral.rs
@@ -21,20 +21,20 @@
         api_endpoint: Option<&str>,
         api_key: Option<String>,
         model_name: Option<String>,
-    ) -> Self {
+    ) -> Box<dyn CompletionStream> {
         let client = reqwest::Client::new();
         let model_name = model_name.unwrap_or("codestral-latest".into());
         let api_key = api_key.expect("API key is required for mistral/completion");

-        Self {
+        Box::new(Self {
             client,
             model_name,
             api_endpoint: format!(
                 "{}/v1/fim/completions",
                 api_endpoint.unwrap_or(DEFAULT_API_ENDPOINT)
             ),
             api_key,
-        }
+        })
     }
 }
88 changes: 43 additions & 45 deletions crates/http-api-bindings/src/completion/mod.rs
@@ -1,65 +1,63 @@
 mod llama;
 mod mistral;
 mod openai;
+mod rate_limit;

-use std::sync::Arc;
+use std::{sync::Arc, time::Duration};

 use llama::LlamaCppEngine;
 use mistral::MistralFIMEngine;
 use openai::OpenAICompletionEngine;
+use ratelimit::Ratelimiter;
 use tabby_common::config::HttpModelConfig;
 use tabby_inference::CompletionStream;

 pub async fn create(model: &HttpModelConfig) -> Arc<dyn CompletionStream> {
-    match model.kind.as_str() {
-        "llama.cpp/completion" => {
-            let engine = LlamaCppEngine::create(
-                model
-                    .api_endpoint
-                    .as_deref()
-                    .expect("api_endpoint is required"),
-                model.api_key.clone(),
-            );
-            Arc::new(engine)
-        }
+    let engine = match model.kind.as_str() {
+        "llama.cpp/completion" => LlamaCppEngine::create(
+            model
+                .api_endpoint
+                .as_deref()
+                .expect("api_endpoint is required"),
+            model.api_key.clone(),
+        ),
         "ollama/completion" => ollama_api_bindings::create_completion(model).await,
-        "mistral/completion" => {
-            let engine = MistralFIMEngine::create(
-                model.api_endpoint.as_deref(),
-                model.api_key.clone(),
-                model.model_name.clone(),
-            );
-            Arc::new(engine)
-        }
-        x if OPENAI_LEGACY_COMPLETION_FIM_ALIASES.contains(&x) => {
-            let engine = OpenAICompletionEngine::create(
-                model.model_name.clone(),
-                model
-                    .api_endpoint
-                    .as_deref()
-                    .expect("api_endpoint is required"),
-                model.api_key.clone(),
-                true,
-            );
-            Arc::new(engine)
-        }
-        "openai/legacy_completion_no_fim" | "vllm/completion" => {
-            let engine = OpenAICompletionEngine::create(
-                model.model_name.clone(),
-                model
-                    .api_endpoint
-                    .as_deref()
-                    .expect("api_endpoint is required"),
-                model.api_key.clone(),
-                false,
-            );
-            Arc::new(engine)
-        }
+        "mistral/completion" => MistralFIMEngine::create(
+            model.api_endpoint.as_deref(),
+            model.api_key.clone(),
+            model.model_name.clone(),
+        ),
+        x if OPENAI_LEGACY_COMPLETION_FIM_ALIASES.contains(&x) => OpenAICompletionEngine::create(
+            model.model_name.clone(),
+            model
+                .api_endpoint
+                .as_deref()
+                .expect("api_endpoint is required"),
+            model.api_key.clone(),
+            true,
+        ),
+        "openai/legacy_completion_no_fim" | "vllm/completion" => OpenAICompletionEngine::create(
+            model.model_name.clone(),
+            model
+                .api_endpoint
+                .as_deref()
+                .expect("api_endpoint is required"),
+            model.api_key.clone(),
+            false,
+        ),
         unsupported_kind => panic!(
             "Unsupported model kind for http completion: {}",
             unsupported_kind
         ),
-    }
+    };
+
+    let ratelimiter =
+        Ratelimiter::builder(model.rate_limit.request_per_minute, Duration::from_secs(60))
+            .max_tokens(model.rate_limit.request_per_minute)
+            .build()
+            .expect("Failed to create ratelimiter, please check the rate limit configuration");
+
+    Arc::new(rate_limit::RateLimitedCompletion::new(engine, ratelimiter))
 }

 const FIM_TOKEN: &str = "<|FIM|>";
6 changes: 3 additions & 3 deletions crates/http-api-bindings/src/completion/openai.rs
@@ -25,17 +25,17 @@
         api_endpoint: &str,
         api_key: Option<String>,
         support_fim: bool,
-    ) -> Self {
+    ) -> Box<dyn CompletionStream> {
         let model_name = model_name.expect("model_name is required for openai/completion");
         let client = reqwest::Client::new();

-        Self {
+        Box::new(Self {
             client,
             model_name,
             api_endpoint: format!("{}/completions", api_endpoint),
             api_key,
             support_fim,
-        }
+        })
     }
 }
35 changes: 35 additions & 0 deletions crates/http-api-bindings/src/completion/rate_limit.rs
@@ -0,0 +1,35 @@
use async_trait::async_trait;
use futures::stream::BoxStream;
use ratelimit::Ratelimiter;
use tabby_inference::{CompletionOptions, CompletionStream};

pub struct RateLimitedCompletion {
    completion: Box<dyn CompletionStream>,
    rate_limiter: Ratelimiter,
}

impl RateLimitedCompletion {
    pub fn new(completion: Box<dyn CompletionStream>, rate_limiter: Ratelimiter) -> Self {
        Self {
            completion,
            rate_limiter,
        }
    }
}

#[async_trait]
impl CompletionStream for RateLimitedCompletion {
    async fn generate(&self, prompt: &str, options: CompletionOptions) -> BoxStream<String> {
        for _ in 0..5 {
            if let Err(sleep) = self.rate_limiter.try_wait() {
                tokio::time::sleep(sleep).await;
                continue;
            }

            return self.completion.generate(prompt, options).await;
        }

        // Return an empty stream if the rate limit is exceeded
        Box::pin(futures::stream::empty())
    }
}
6 changes: 2 additions & 4 deletions crates/ollama-api-bindings/src/completion.rs
@@ -1,5 +1,3 @@
-use std::sync::Arc;
-
 use async_stream::stream;
 use async_trait::async_trait;
 use futures::{stream::BoxStream, StreamExt};

@@ -58,11 +56,11 @@
     }
 }

-pub async fn create(config: &HttpModelConfig) -> Arc<dyn CompletionStream> {
+pub async fn create(config: &HttpModelConfig) -> Box<dyn CompletionStream> {
     let connection = Ollama::try_new(config.api_endpoint.as_deref().unwrap().to_owned())
         .expect("Failed to create connection to Ollama, URL invalid");

     let model = connection.select_model_or_default(config).await.unwrap();

-    Arc::new(OllamaCompletion { connection, model })
+    Box::new(OllamaCompletion { connection, model })
 }