From dcba7915606ad52ee4900a65aa50fba67281d42d Mon Sep 17 00:00:00 2001
From: Wei Zhang
Date: Tue, 26 Nov 2024 12:21:28 +0800
Subject: [PATCH 1/4] feat(models-http-api): add rate limit support for completion

---
 .../http-api-bindings/src/completion/llama.rs |  6 +-
 .../src/completion/mistral.rs                 |  6 +-
 .../http-api-bindings/src/completion/mod.rs   | 88 +++++++++----------
 .../src/completion/openai.rs                  |  6 +-
 .../src/completion/rate_limit.rs              | 35 ++++++++
 crates/ollama-api-bindings/src/completion.rs  |  6 +-
 6 files changed, 89 insertions(+), 58 deletions(-)
 create mode 100644 crates/http-api-bindings/src/completion/rate_limit.rs

diff --git a/crates/http-api-bindings/src/completion/llama.rs b/crates/http-api-bindings/src/completion/llama.rs
index 62174fd2c591..45858d9fb240 100644
--- a/crates/http-api-bindings/src/completion/llama.rs
+++ b/crates/http-api-bindings/src/completion/llama.rs
@@ -14,14 +14,14 @@ pub struct LlamaCppEngine {
 }
 
 impl LlamaCppEngine {
-    pub fn create(api_endpoint: &str, api_key: Option<String>) -> Self {
+    pub fn create(api_endpoint: &str, api_key: Option<String>) -> Box<dyn CompletionStream> {
         let client = create_reqwest_client(api_endpoint);
 
-        Self {
+        Box::new(Self {
             client,
             api_endpoint: format!("{}/completion", api_endpoint),
             api_key,
-        }
+        })
     }
 }
 
diff --git a/crates/http-api-bindings/src/completion/mistral.rs b/crates/http-api-bindings/src/completion/mistral.rs
index fef89f49606e..f19ec4c3a3d0 100644
--- a/crates/http-api-bindings/src/completion/mistral.rs
+++ b/crates/http-api-bindings/src/completion/mistral.rs
@@ -21,12 +21,12 @@ impl MistralFIMEngine {
         api_endpoint: Option<&str>,
         api_key: Option<String>,
         model_name: Option<String>,
-    ) -> Self {
+    ) -> Box<dyn CompletionStream> {
         let client = reqwest::Client::new();
         let model_name = model_name.unwrap_or("codestral-latest".into());
         let api_key = api_key.expect("API key is required for mistral/completion");
 
-        Self {
+        Box::new(Self {
             client,
             model_name,
             api_endpoint: format!(
@@ -34,7 +34,7 @@ impl MistralFIMEngine {
                 api_endpoint.unwrap_or(DEFAULT_API_ENDPOINT)
             ),
             api_key,
-        }
+        })
     }
 }
 
diff --git a/crates/http-api-bindings/src/completion/mod.rs b/crates/http-api-bindings/src/completion/mod.rs
index 18764d2bbbbf..9d7ceb3360ec 100644
--- a/crates/http-api-bindings/src/completion/mod.rs
+++ b/crates/http-api-bindings/src/completion/mod.rs
@@ -1,65 +1,63 @@
 mod llama;
 mod mistral;
 mod openai;
+mod rate_limit;
 
-use std::sync::Arc;
+use std::{sync::Arc, time::Duration};
 
 use llama::LlamaCppEngine;
 use mistral::MistralFIMEngine;
 use openai::OpenAICompletionEngine;
+use ratelimit::Ratelimiter;
 use tabby_common::config::HttpModelConfig;
 use tabby_inference::CompletionStream;
 
 pub async fn create(model: &HttpModelConfig) -> Arc<dyn CompletionStream> {
-    match model.kind.as_str() {
-        "llama.cpp/completion" => {
-            let engine = LlamaCppEngine::create(
-                model
-                    .api_endpoint
-                    .as_deref()
-                    .expect("api_endpoint is required"),
-                model.api_key.clone(),
-            );
-            Arc::new(engine)
-        }
+    let engine = match model.kind.as_str() {
+        "llama.cpp/completion" => LlamaCppEngine::create(
+            model
+                .api_endpoint
+                .as_deref()
+                .expect("api_endpoint is required"),
+            model.api_key.clone(),
+        ),
         "ollama/completion" => ollama_api_bindings::create_completion(model).await,
-        "mistral/completion" => {
-            let engine = MistralFIMEngine::create(
-                model.api_endpoint.as_deref(),
-                model.api_key.clone(),
-                model.model_name.clone(),
-            );
-            Arc::new(engine)
-        }
-        x if OPENAI_LEGACY_COMPLETION_FIM_ALIASES.contains(&x) => {
-            let engine = OpenAICompletionEngine::create(
-                model.model_name.clone(),
-                model
-                    .api_endpoint
-                    .as_deref()
-                    .expect("api_endpoint is required"),
-                model.api_key.clone(),
-                true,
-            );
-            Arc::new(engine)
-        }
-        "openai/legacy_completion_no_fim" | "vllm/completion" => {
-            let engine = OpenAICompletionEngine::create(
-                model.model_name.clone(),
-                model
-                    .api_endpoint
-                    .as_deref()
-                    .expect("api_endpoint is required"),
-                model.api_key.clone(),
-                false,
-            );
-            Arc::new(engine)
-        }
+        "mistral/completion" => MistralFIMEngine::create(
+            model.api_endpoint.as_deref(),
+            model.api_key.clone(),
+            model.model_name.clone(),
+        ),
+        x if OPENAI_LEGACY_COMPLETION_FIM_ALIASES.contains(&x) => OpenAICompletionEngine::create(
+            model.model_name.clone(),
+            model
+                .api_endpoint
+                .as_deref()
+                .expect("api_endpoint is required"),
+            model.api_key.clone(),
+            true,
+        ),
+        "openai/legacy_completion_no_fim" | "vllm/completion" => OpenAICompletionEngine::create(
+            model.model_name.clone(),
+            model
+                .api_endpoint
+                .as_deref()
+                .expect("api_endpoint is required"),
+            model.api_key.clone(),
+            false,
+        ),
         unsupported_kind => panic!(
             "Unsupported model kind for http completion: {}",
             unsupported_kind
         ),
-    }
+    };
+
+    let ratelimiter =
+        Ratelimiter::builder(model.rate_limit.request_per_minute, Duration::from_secs(60))
+            .max_tokens(model.rate_limit.request_per_minute)
+            .build()
+            .expect("Failed to create ratelimiter, please check the rate limit configuration");
+
+    Arc::new(rate_limit::RateLimitedCompletion::new(engine, ratelimiter))
 }
 
 const FIM_TOKEN: &str = "<|FIM|>";
diff --git a/crates/http-api-bindings/src/completion/openai.rs b/crates/http-api-bindings/src/completion/openai.rs
index d6e8fa572b86..6f1bfd024183 100644
--- a/crates/http-api-bindings/src/completion/openai.rs
+++ b/crates/http-api-bindings/src/completion/openai.rs
@@ -25,17 +25,17 @@ impl OpenAICompletionEngine {
         api_endpoint: &str,
         api_key: Option<String>,
         support_fim: bool,
-    ) -> Self {
+    ) -> Box<dyn CompletionStream> {
         let model_name = model_name.expect("model_name is required for openai/completion");
         let client = reqwest::Client::new();
 
-        Self {
+        Box::new(Self {
             client,
             model_name,
             api_endpoint: format!("{}/completions", api_endpoint),
             api_key,
             support_fim,
-        }
+        })
     }
 }
 
diff --git a/crates/http-api-bindings/src/completion/rate_limit.rs b/crates/http-api-bindings/src/completion/rate_limit.rs
new file mode 100644
index 000000000000..d9d87d25d961
--- /dev/null
+++ b/crates/http-api-bindings/src/completion/rate_limit.rs
@@ -0,0 +1,35 @@
+use async_trait::async_trait;
+use futures::stream::BoxStream;
+use ratelimit::Ratelimiter;
+use tabby_inference::{CompletionOptions, CompletionStream};
+
+pub struct RateLimitedCompletion {
+    completion: Box<dyn CompletionStream>,
+    rate_limiter: Ratelimiter,
+}
+
+impl RateLimitedCompletion {
+    pub fn new(completion: Box<dyn CompletionStream>, rate_limiter: Ratelimiter) -> Self {
+        Self {
+            completion,
+            rate_limiter,
+        }
+    }
+}
+
+#[async_trait]
+impl CompletionStream for RateLimitedCompletion {
+    async fn generate(&self, prompt: &str, options: CompletionOptions) -> BoxStream<String> {
+        for _ in 0..5 {
+            if let Err(sleep) = self.rate_limiter.try_wait() {
+                tokio::time::sleep(sleep).await;
+                continue;
+            }
+
+            return self.completion.generate(prompt, options).await;
+        }
+
+        // Return an empty stream if the rate limit is exceeded
+        Box::pin(futures::stream::empty())
+    }
+}
diff --git a/crates/ollama-api-bindings/src/completion.rs b/crates/ollama-api-bindings/src/completion.rs
index fe405e5c668c..fbaec8da9f45 100644
--- a/crates/ollama-api-bindings/src/completion.rs
+++ b/crates/ollama-api-bindings/src/completion.rs
@@ -1,5 +1,3 @@
-use std::sync::Arc;
-
 use async_stream::stream;
 use async_trait::async_trait;
 use futures::{stream::BoxStream, StreamExt};
@@ -58,11 +56,11 @@ impl CompletionStream for OllamaCompletion {
     }
 }
 
-pub async fn create(config: &HttpModelConfig) -> Arc<dyn CompletionStream> {
+pub async fn create(config: &HttpModelConfig) -> Box<dyn CompletionStream> {
     let connection = Ollama::try_new(config.api_endpoint.as_deref().unwrap().to_owned())
         .expect("Failed to create connection to Ollama, URL invalid");
     let model = connection.select_model_or_default(config).await.unwrap();
 
-    Arc::new(OllamaCompletion { connection, model })
+    Box::new(OllamaCompletion { connection, model })
 }

From b6ec3a97f2d0fc105847201f12c5acdac09641a4 Mon Sep 17 00:00:00 2001
From: Wei Zhang
Date: Tue, 26 Nov 2024 12:36:00 +0800
Subject: [PATCH 2/4] feat(models-http-api): add rate limit support for chat

---
 crates/http-api-bindings/src/chat/mod.rs      | 17 ++++-
 .../http-api-bindings/src/chat/rate_limit.rs  | 64 +++++++++++++++++++
 2 files changed, 78 insertions(+), 3 deletions(-)
 create mode 100644 crates/http-api-bindings/src/chat/rate_limit.rs

diff --git a/crates/http-api-bindings/src/chat/mod.rs b/crates/http-api-bindings/src/chat/mod.rs
index b78bdec2c060..1a476d6f00c1 100644
--- a/crates/http-api-bindings/src/chat/mod.rs
+++ b/crates/http-api-bindings/src/chat/mod.rs
@@ -1,6 +1,9 @@
-use std::sync::Arc;
+mod rate_limit;
+
+use std::{sync::Arc, time::Duration};
 
 use async_openai::config::OpenAIConfig;
+use ratelimit::Ratelimiter;
 use tabby_common::config::HttpModelConfig;
 use tabby_inference::{ChatCompletionStream, ExtendedOpenAIConfig};
 
@@ -31,8 +34,16 @@ pub async fn create(model: &HttpModelConfig) -> Arc<dyn ChatCompletionStream> {
 
     let config = builder.build().expect("Failed to build config");
 
-    Arc::new(
+    let engine = Box::new(
         async_openai::Client::with_config(config)
             .with_http_client(create_reqwest_client(api_endpoint)),
-    )
+    );
+
+    let ratelimiter =
+        Ratelimiter::builder(model.rate_limit.request_per_minute, Duration::from_secs(60))
+            .max_tokens(model.rate_limit.request_per_minute)
+            .build()
+            .expect("Failed to create ratelimiter, please check the rate limit configuration");
+
+    Arc::new(rate_limit::RateLimitedChatStream::new(engine, ratelimiter))
 }
diff --git a/crates/http-api-bindings/src/chat/rate_limit.rs b/crates/http-api-bindings/src/chat/rate_limit.rs
new file mode 100644
index 000000000000..0324e186f572
--- /dev/null
+++ b/crates/http-api-bindings/src/chat/rate_limit.rs
@@ -0,0 +1,64 @@
+use async_openai::{
+    error::{ApiError, OpenAIError},
+    types::{
+        ChatCompletionResponseStream, CreateChatCompletionRequest, CreateChatCompletionResponse,
+    },
+};
+use async_trait::async_trait;
+use ratelimit::Ratelimiter;
+use tabby_inference::ChatCompletionStream;
+
+pub struct RateLimitedChatStream {
+    completion: Box<dyn ChatCompletionStream>,
+    rate_limiter: Ratelimiter,
+}
+
+impl RateLimitedChatStream {
+    pub fn new(completion: Box<dyn ChatCompletionStream>, rate_limiter: Ratelimiter) -> Self {
+        Self {
+            completion,
+            rate_limiter,
+        }
+    }
+}
+
+#[async_trait]
+impl ChatCompletionStream for RateLimitedChatStream {
+    async fn chat(
+        &self,
+        request: CreateChatCompletionRequest,
+    ) -> Result<CreateChatCompletionResponse, OpenAIError> {
+        for _ in 0..5 {
+            if let Err(sleep) = self.rate_limiter.try_wait() {
+                tokio::time::sleep(sleep).await;
+                continue;
+            }
+
+            return self.completion.chat(request).await;
+        }
+
+        Err(OpenAIError::ApiError(ApiError {
+            message: "Rate limit exceeded for chat completion".to_owned(),
+            r#type: None,
+            param: None,
+            code: None,
+        }))
+    }
+
+    async fn chat_stream(
+        &self,
+        request: CreateChatCompletionRequest,
+    ) -> Result<ChatCompletionResponseStream, OpenAIError> {
+        for _ in 0..5 {
+            if let Err(sleep) = self.rate_limiter.try_wait() {
+                tokio::time::sleep(sleep).await;
+                continue;
+            }
+
+            return self.completion.chat_stream(request).await;
+        }
+
+        // Return an empty stream if the rate limit is exceeded
+        Ok(Box::pin(futures::stream::empty()))
+    }
+}

From 2abb6cf52b325d02bf19a4fd839547dd19d71b0b Mon Sep 17 00:00:00 2001
From: Wei Zhang
Date: Tue, 26 Nov 2024 13:09:38 +0800
Subject: [PATCH 3/4] chore: use one mod for rate limit

---
 crates/http-api-bindings/src/chat/mod.rs      |  18 +--
 .../http-api-bindings/src/chat/rate_limit.rs  |  64 --------
 .../http-api-bindings/src/completion/mod.rs   |  17 +-
 .../src/completion/rate_limit.rs              |  35 -----
 crates/http-api-bindings/src/embedding/mod.rs |  16 +-
 .../src/embedding/rate_limit.rs               |  33 ----
 crates/http-api-bindings/src/lib.rs           |   1 +
 crates/http-api-bindings/src/rate_limit.rs    | 145 ++++++++++++++++++
 8 files changed, 165 insertions(+), 164 deletions(-)
 delete mode 100644 crates/http-api-bindings/src/chat/rate_limit.rs
 delete mode 100644 crates/http-api-bindings/src/completion/rate_limit.rs
 delete mode 100644 crates/http-api-bindings/src/embedding/rate_limit.rs
 create mode 100644 crates/http-api-bindings/src/rate_limit.rs

diff --git a/crates/http-api-bindings/src/chat/mod.rs b/crates/http-api-bindings/src/chat/mod.rs
index 1a476d6f00c1..604ed81fd476 100644
--- a/crates/http-api-bindings/src/chat/mod.rs
+++ b/crates/http-api-bindings/src/chat/mod.rs
@@ -1,12 +1,10 @@
-mod rate_limit;
-
-use std::{sync::Arc, time::Duration};
+use std::sync::Arc;
 
 use async_openai::config::OpenAIConfig;
-use ratelimit::Ratelimiter;
 use tabby_common::config::HttpModelConfig;
 use tabby_inference::{ChatCompletionStream, ExtendedOpenAIConfig};
 
+use super::rate_limit;
 use crate::create_reqwest_client;
 
 pub async fn create(model: &HttpModelConfig) -> Arc<dyn ChatCompletionStream> {
@@ -19,6 +17,7 @@ pub async fn create(model: &HttpModelConfig) -> Arc<dyn ChatCompletionStream> {
         .with_api_key(model.api_key.clone().unwrap_or_default());
 
     let mut builder = ExtendedOpenAIConfig::builder();
+
     builder
         .base(config)
         .supported_models(model.supported_models.clone())
@@ -39,11 +38,8 @@ pub async fn create(model: &HttpModelConfig) -> Arc<dyn ChatCompletionStream> {
             .with_http_client(create_reqwest_client(api_endpoint)),
     );
 
-    let ratelimiter =
-        Ratelimiter::builder(model.rate_limit.request_per_minute, Duration::from_secs(60))
-            .max_tokens(model.rate_limit.request_per_minute)
-            .build()
-            .expect("Failed to create ratelimiter, please check the rate limit configuration");
-
-    Arc::new(rate_limit::RateLimitedChatStream::new(engine, ratelimiter))
+    Arc::new(rate_limit::RateLimitedChatStream::new(
+        engine,
+        model.rate_limit.request_per_minute,
+    ))
 }
diff --git a/crates/http-api-bindings/src/chat/rate_limit.rs b/crates/http-api-bindings/src/chat/rate_limit.rs
deleted file mode 100644
index 0324e186f572..000000000000
--- a/crates/http-api-bindings/src/chat/rate_limit.rs
+++ /dev/null
@@ -1,64 +0,0 @@
-use async_openai::{
-    error::{ApiError, OpenAIError},
-    types::{
-        ChatCompletionResponseStream, CreateChatCompletionRequest, CreateChatCompletionResponse,
-    },
-};
-use async_trait::async_trait;
-use ratelimit::Ratelimiter;
-use tabby_inference::ChatCompletionStream;
-
-pub struct RateLimitedChatStream {
-    completion: Box<dyn ChatCompletionStream>,
-    rate_limiter: Ratelimiter,
-}
-
-impl RateLimitedChatStream {
-    pub fn new(completion: Box<dyn ChatCompletionStream>, rate_limiter: Ratelimiter) -> Self {
-        Self {
-            completion,
-            rate_limiter,
-        }
-    }
-}
-
-#[async_trait]
-impl ChatCompletionStream for RateLimitedChatStream {
-    async fn chat(
-        &self,
-        request: CreateChatCompletionRequest,
-    ) -> Result<CreateChatCompletionResponse, OpenAIError> {
-        for _ in 0..5 {
-            if let Err(sleep) = self.rate_limiter.try_wait() {
-                tokio::time::sleep(sleep).await;
-                continue;
-            }
-
-            return self.completion.chat(request).await;
-        }
-
-        Err(OpenAIError::ApiError(ApiError {
-            message: "Rate limit exceeded for chat completion".to_owned(),
-            r#type: None,
-            param: None,
-            code: None,
-        }))
-    }
-
-    async fn chat_stream(
-        &self,
-        request: CreateChatCompletionRequest,
-    ) -> Result<ChatCompletionResponseStream, OpenAIError> {
-        for _ in 0..5 {
-            if let Err(sleep) = self.rate_limiter.try_wait() {
-                tokio::time::sleep(sleep).await;
-                continue;
-            }
-
-            return self.completion.chat_stream(request).await;
-        }
-
-        // Return an empty stream if the rate limit is exceeded
-        Ok(Box::pin(futures::stream::empty()))
-    }
-}
diff --git a/crates/http-api-bindings/src/completion/mod.rs b/crates/http-api-bindings/src/completion/mod.rs
index 9d7ceb3360ec..d98caaef1100 100644
--- a/crates/http-api-bindings/src/completion/mod.rs
+++ b/crates/http-api-bindings/src/completion/mod.rs
@@ -1,17 +1,17 @@
 mod llama;
 mod mistral;
 mod openai;
-mod rate_limit;
 
-use std::{sync::Arc, time::Duration};
+use std::sync::Arc;
 
 use llama::LlamaCppEngine;
 use mistral::MistralFIMEngine;
 use openai::OpenAICompletionEngine;
-use ratelimit::Ratelimiter;
 use tabby_common::config::HttpModelConfig;
 use tabby_inference::CompletionStream;
 
+use super::rate_limit;
+
 pub async fn create(model: &HttpModelConfig) -> Arc<dyn CompletionStream> {
     let engine = match model.kind.as_str() {
         "llama.cpp/completion" => LlamaCppEngine::create(
@@ -51,13 +51,10 @@ pub async fn create(model: &HttpModelConfig) -> Arc<dyn CompletionStream> {
         ),
     };
 
-    let ratelimiter =
-        Ratelimiter::builder(model.rate_limit.request_per_minute, Duration::from_secs(60))
-            .max_tokens(model.rate_limit.request_per_minute)
-            .build()
-            .expect("Failed to create ratelimiter, please check the rate limit configuration");
-
-    Arc::new(rate_limit::RateLimitedCompletion::new(engine, ratelimiter))
+    Arc::new(rate_limit::RateLimitedCompletion::new(
+        engine,
+        model.rate_limit.request_per_minute,
+    ))
 }
 
 const FIM_TOKEN: &str = "<|FIM|>";
diff --git a/crates/http-api-bindings/src/completion/rate_limit.rs b/crates/http-api-bindings/src/completion/rate_limit.rs
deleted file mode 100644
index d9d87d25d961..000000000000
--- a/crates/http-api-bindings/src/completion/rate_limit.rs
+++ /dev/null
@@ -1,35 +0,0 @@
-use async_trait::async_trait;
-use futures::stream::BoxStream;
-use ratelimit::Ratelimiter;
-use tabby_inference::{CompletionOptions, CompletionStream};
-
-pub struct RateLimitedCompletion {
-    completion: Box<dyn CompletionStream>,
-    rate_limiter: Ratelimiter,
-}
-
-impl RateLimitedCompletion {
-    pub fn new(completion: Box<dyn CompletionStream>, rate_limiter: Ratelimiter) -> Self {
-        Self {
-            completion,
-            rate_limiter,
-        }
-    }
-}
-
-#[async_trait]
-impl CompletionStream for RateLimitedCompletion {
-    async fn generate(&self, prompt: &str, options: CompletionOptions) -> BoxStream<String> {
-        for _ in 0..5 {
-            if let Err(sleep) = self.rate_limiter.try_wait() {
-                tokio::time::sleep(sleep).await;
-                continue;
-            }
-
-            return self.completion.generate(prompt, options).await;
-        }
-
-        // Return an empty stream if the rate limit is exceeded
-        Box::pin(futures::stream::empty())
-    }
-}
diff --git a/crates/http-api-bindings/src/embedding/mod.rs b/crates/http-api-bindings/src/embedding/mod.rs
index 6d6bfaeca6ec..8e19d22e99fd 100644
--- a/crates/http-api-bindings/src/embedding/mod.rs
+++ b/crates/http-api-bindings/src/embedding/mod.rs
@@ -1,17 +1,16 @@
 mod llama;
 mod openai;
-mod rate_limit;
 mod voyage;
 
 use core::panic;
-use std::{sync::Arc, time::Duration};
+use std::sync::Arc;
 
 use llama::LlamaCppEngine;
-use ratelimit::Ratelimiter;
 use tabby_common::config::HttpModelConfig;
 use tabby_inference::Embedding;
 
 use self::{openai::OpenAIEmbeddingEngine, voyage::VoyageEmbeddingEngine};
+use super::rate_limit;
 
 pub async fn create(config: &HttpModelConfig) -> Arc<dyn Embedding> {
     let engine = match config.kind.as_str() {
@@ -48,13 +47,8 @@ pub async fn create(config: &HttpModelConfig) -> Arc<dyn Embedding> {
         ),
     };
 
-    let ratelimiter = Ratelimiter::builder(
+    Arc::new(rate_limit::RateLimitedEmbedding::new(
+        engine,
         config.rate_limit.request_per_minute,
-        Duration::from_secs(60),
-    )
-    .max_tokens(config.rate_limit.request_per_minute)
-    .build()
-    .expect("Failed to create ratelimiter, please check the rate limit configuration");
-
-    Arc::new(rate_limit::RateLimitedEmbedding::new(engine, ratelimiter))
+    ))
 }
diff --git a/crates/http-api-bindings/src/embedding/rate_limit.rs b/crates/http-api-bindings/src/embedding/rate_limit.rs
deleted file mode 100644
index 9084eee69dc5..000000000000
--- a/crates/http-api-bindings/src/embedding/rate_limit.rs
+++ /dev/null
@@ -1,33 +0,0 @@
-use async_trait::async_trait;
-use ratelimit::Ratelimiter;
-use tabby_inference::Embedding;
-
-pub struct RateLimitedEmbedding {
-    embedding: Box<dyn Embedding>,
-    rate_limiter: Ratelimiter,
-}
-
-impl RateLimitedEmbedding {
-    pub fn new(embedding: Box<dyn Embedding>, rate_limiter: Ratelimiter) -> Self {
-        Self {
-            embedding,
-            rate_limiter,
-        }
-    }
-}
-
-#[async_trait]
-impl Embedding for RateLimitedEmbedding {
-    async fn embed(&self, prompt: &str) -> anyhow::Result<Vec<f32>> {
-        for _ in 0..5 {
-            if let Err(sleep) = self.rate_limiter.try_wait() {
-                tokio::time::sleep(sleep).await;
-                continue;
-            }
-
-            return self.embedding.embed(prompt).await;
-        }
-
-        anyhow::bail!("Rate limit exceeded for embedding computation");
-    }
-}
diff --git a/crates/http-api-bindings/src/lib.rs b/crates/http-api-bindings/src/lib.rs
index 41e7811421be..bdbdce20d417 100644
--- a/crates/http-api-bindings/src/lib.rs
+++ b/crates/http-api-bindings/src/lib.rs
@@ -1,6 +1,7 @@
 mod chat;
 mod completion;
 mod embedding;
+mod rate_limit;
 
 pub use chat::create as create_chat;
 pub use completion::{build_completion_prompt, create};
diff --git a/crates/http-api-bindings/src/rate_limit.rs b/crates/http-api-bindings/src/rate_limit.rs
new file mode 100644
index 000000000000..30d009b805a7
--- /dev/null
+++ b/crates/http-api-bindings/src/rate_limit.rs
@@ -0,0 +1,145 @@
+use std::time::Duration;
+
+use async_openai::{
+    error::{ApiError, OpenAIError},
+    types::{
+        ChatCompletionResponseStream, CreateChatCompletionRequest, CreateChatCompletionResponse,
+    },
+};
+use async_trait::async_trait;
+use futures::stream::BoxStream;
+use ratelimit::Ratelimiter;
+use tabby_inference::{ChatCompletionStream, CompletionOptions, CompletionStream, Embedding};
+
+fn new_rate_limiter(rpm: u64) -> anyhow::Result<Ratelimiter> {
+    Ratelimiter::builder(rpm, Duration::from_secs(60))
+        .max_tokens(rpm)
+        .initial_available(rpm)
+        .build()
+        .map_err(|e| {
+            anyhow::anyhow!(
+                "Failed to create ratelimiter, please check the rate limit configuration: {}",
+                e,
+            )
+        })
+}
+
+pub struct RateLimitedEmbedding {
+    embedding: Box<dyn Embedding>,
+    rate_limiter: Ratelimiter,
+}
+
+impl RateLimitedEmbedding {
+    pub fn new(embedding: Box<dyn Embedding>, rpm: u64) -> Self {
+        Self {
+            embedding,
+            rate_limiter: new_rate_limiter(rpm).unwrap(),
+        }
+    }
+}
+
+#[async_trait]
+impl Embedding for RateLimitedEmbedding {
+    async fn embed(&self, prompt: &str) -> anyhow::Result<Vec<f32>> {
+        for _ in 0..5 {
+            if let Err(sleep) = self.rate_limiter.try_wait() {
+                tokio::time::sleep(sleep).await;
+                continue;
+            }
+
+            return self.embedding.embed(prompt).await;
+        }
+
anyhow::bail!("Rate limit exceeded for embedding computation"); + } +} + +pub struct RateLimitedCompletion { + completion: Box, + rate_limiter: Ratelimiter, +} + +impl RateLimitedCompletion { + pub fn new(completion: Box, rpm: u64) -> Self { + Self { + completion, + rate_limiter: new_rate_limiter(rpm).unwrap(), + } + } +} + +#[async_trait] +impl CompletionStream for RateLimitedCompletion { + async fn generate(&self, prompt: &str, options: CompletionOptions) -> BoxStream { + for _ in 0..5 { + if let Err(sleep) = self.rate_limiter.try_wait() { + tokio::time::sleep(sleep).await; + continue; + } + + return self.completion.generate(prompt, options).await; + } + + // Return an empty stream if the rate limit is exceeded + Box::pin(futures::stream::empty()) + } +} + +pub struct RateLimitedChatStream { + completion: Box, + rate_limiter: Ratelimiter, +} + +impl RateLimitedChatStream { + pub fn new(completion: Box, rpm: u64) -> Self { + Self { + completion, + rate_limiter: new_rate_limiter(rpm).unwrap(), + } + } +} + +#[async_trait] +impl ChatCompletionStream for RateLimitedChatStream { + async fn chat( + &self, + request: CreateChatCompletionRequest, + ) -> Result { + for _ in 0..5 { + if let Err(sleep) = self.rate_limiter.try_wait() { + tokio::time::sleep(sleep).await; + continue; + } + + return self.completion.chat(request).await; + } + + Err(OpenAIError::ApiError(ApiError { + message: "Rate limit exceeded for chat completion".to_owned(), + r#type: None, + param: None, + code: None, + })) + } + + async fn chat_stream( + &self, + request: CreateChatCompletionRequest, + ) -> Result { + for _ in 0..5 { + if let Err(sleep) = self.rate_limiter.try_wait() { + tokio::time::sleep(sleep).await; + continue; + } + + return self.completion.chat_stream(request).await; + } + + Err(OpenAIError::ApiError(ApiError { + message: "Rate limit exceeded for chat completion".to_owned(), + r#type: None, + param: None, + code: None, + })) + } +} From eadfd66e1d13c7d68b4efb05ed936cd5f4a9bea8 Mon Sep 17 00:00:00 2001 From: Wei Zhang Date: Tue, 26 Nov 2024 14:48:06 +0800 Subject: [PATCH 4/4] chore: only expose impl for rate limited --- crates/http-api-bindings/src/chat/mod.rs | 2 +- .../http-api-bindings/src/completion/mod.rs | 2 +- crates/http-api-bindings/src/embedding/mod.rs | 2 +- crates/http-api-bindings/src/rate_limit.rs | 30 ++++++++----------- 4 files changed, 15 insertions(+), 21 deletions(-) diff --git a/crates/http-api-bindings/src/chat/mod.rs b/crates/http-api-bindings/src/chat/mod.rs index 604ed81fd476..f30a36ed1dca 100644 --- a/crates/http-api-bindings/src/chat/mod.rs +++ b/crates/http-api-bindings/src/chat/mod.rs @@ -38,7 +38,7 @@ pub async fn create(model: &HttpModelConfig) -> Arc { .with_http_client(create_reqwest_client(api_endpoint)), ); - Arc::new(rate_limit::RateLimitedChatStream::new( + Arc::new(rate_limit::new_chat( engine, model.rate_limit.request_per_minute, )) diff --git a/crates/http-api-bindings/src/completion/mod.rs b/crates/http-api-bindings/src/completion/mod.rs index d98caaef1100..3f7244d9a230 100644 --- a/crates/http-api-bindings/src/completion/mod.rs +++ b/crates/http-api-bindings/src/completion/mod.rs @@ -51,7 +51,7 @@ pub async fn create(model: &HttpModelConfig) -> Arc { ), }; - Arc::new(rate_limit::RateLimitedCompletion::new( + Arc::new(rate_limit::new_completion( engine, model.rate_limit.request_per_minute, )) diff --git a/crates/http-api-bindings/src/embedding/mod.rs b/crates/http-api-bindings/src/embedding/mod.rs index 8e19d22e99fd..f6e3c9695b60 100644 --- 
a/crates/http-api-bindings/src/embedding/mod.rs +++ b/crates/http-api-bindings/src/embedding/mod.rs @@ -47,7 +47,7 @@ pub async fn create(config: &HttpModelConfig) -> Arc { ), }; - Arc::new(rate_limit::RateLimitedEmbedding::new( + Arc::new(rate_limit::new_embedding( engine, config.rate_limit.request_per_minute, )) diff --git a/crates/http-api-bindings/src/rate_limit.rs b/crates/http-api-bindings/src/rate_limit.rs index 30d009b805a7..c9cc0e269adc 100644 --- a/crates/http-api-bindings/src/rate_limit.rs +++ b/crates/http-api-bindings/src/rate_limit.rs @@ -29,12 +29,10 @@ pub struct RateLimitedEmbedding { rate_limiter: Ratelimiter, } -impl RateLimitedEmbedding { - pub fn new(embedding: Box, rpm: u64) -> Self { - Self { - embedding, - rate_limiter: new_rate_limiter(rpm).unwrap(), - } +pub fn new_embedding(embedding: Box, rpm: u64) -> impl Embedding { + RateLimitedEmbedding { + embedding, + rate_limiter: new_rate_limiter(rpm).unwrap(), } } @@ -59,12 +57,10 @@ pub struct RateLimitedCompletion { rate_limiter: Ratelimiter, } -impl RateLimitedCompletion { - pub fn new(completion: Box, rpm: u64) -> Self { - Self { - completion, - rate_limiter: new_rate_limiter(rpm).unwrap(), - } +pub fn new_completion(completion: Box, rpm: u64) -> impl CompletionStream { + RateLimitedCompletion { + completion, + rate_limiter: new_rate_limiter(rpm).unwrap(), } } @@ -90,12 +86,10 @@ pub struct RateLimitedChatStream { rate_limiter: Ratelimiter, } -impl RateLimitedChatStream { - pub fn new(completion: Box, rpm: u64) -> Self { - Self { - completion, - rate_limiter: new_rate_limiter(rpm).unwrap(), - } +pub fn new_chat(completion: Box, rpm: u64) -> impl ChatCompletionStream { + RateLimitedChatStream { + completion, + rate_limiter: new_rate_limiter(rpm).unwrap(), } }