Enforce embedder rate limit margin for upserts
tdraier committed Mar 4, 2025
1 parent 89bbd48 commit be1d380
Showing 5 changed files with 135 additions and 8 deletions.
50 changes: 49 additions & 1 deletion core/Cargo.lock

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions core/Cargo.toml
@@ -152,3 +152,4 @@ unicode-normalization = "0.1.24"
dateparser = "0.2.1"
once_cell = "1.18"
redis = { version = "0.24.0", features = ["tokio-comp"] }
+parse_duration = "2.1.1"
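The new parse_duration dependency is what lets core turn OpenAI's x-ratelimit-reset-tokens header, a human-readable duration, into a deadline. A minimal sketch of that conversion, assuming header values in the "1s" / "6m12s" style described in OpenAI's rate-limit docs (the values below are invented):

use parse_duration::parse;

// Converts a reset header such as "6m12s" into milliseconds.
fn reset_to_millis(header: &str) -> Option<u64> {
    parse(header).ok().map(|d| d.as_millis() as u64)
}

fn main() {
    assert_eq!(reset_to_millis("1s"), Some(1_000));
    assert_eq!(reset_to_millis("6m12s"), Some(372_000));
    assert_eq!(reset_to_millis("not-a-duration"), None);
}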
13 changes: 8 additions & 5 deletions core/src/data_sources/data_source.rs
@@ -22,7 +22,7 @@ use qdrant_client::qdrant::vectors::VectorsOptions;
use qdrant_client::qdrant::{PointId, RetrievedPoint, ScoredPoint};
use qdrant_client::{prelude::Payload, qdrant};
use serde::{Deserialize, Serialize};
-use serde_json::Value;
+use serde_json::{json, Value};
use std::collections::HashMap;
use std::fmt;
use std::str::FromStr;
@@ -98,13 +98,13 @@ pub struct Chunk {
/// corresponding table)
///
/// For some sources, this is well embodied by the parent's external id,
-/// provided by the managed datasources API: the Notion id (notionPageId column
+/// provided by the managed datasource's API: the Notion id (notionPageId column
/// in `notion_pages`) for Notion pages and databases, the Google drive id
/// (driveFileId column in `google_drive_documents`).
///
/// For other sources, such as github: github issues / discussions do not have a
/// proper external id, so we use our computed document id. The repo is
-/// considered a parent, and has a proper external repo id, which is stored at
+/// considered a parent, and has a proper external "repo id", which is stored at
/// 2nd place in the array
///
/// Additional note: in cases where selection of elements to sync is done on
@@ -129,7 +129,7 @@ pub struct Chunk {
///
/// The id of the document itself is stored at index 0 because the field is used
/// in filtering search to search only parts of the hierarchy: it is natural
-/// that if the documents id is selected as a parent filter, the document
+/// that if the document's id is selected as a parent filter, the document
/// itself shows up in the search.
///
///
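To make the hierarchy convention concrete, here is a hypothetical parents array for a Notion page; every id below is invented:

fn main() {
    // Index 0: the document itself; index 1: its direct parent's external id;
    // later entries: higher ancestors up to the root. All ids are made up.
    let parents: Vec<String> = vec![
        "notion-f3a9c2".to_string(),
        "notion-page-8b1d44".to_string(),
        "notion-workspace-root".to_string(),
    ];
    assert_eq!(parents[0], "notion-f3a9c2");
}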
@@ -986,13 +986,16 @@ impl DataSource {
.map(|chunk| chunk.to_vec())
.collect::<Vec<_>>();

+        let mut extras = self.config.extras.clone().unwrap_or(json!({}));
+        extras["enforce_rate_limit_margin"] = json!(true);
+
// Embed batched chunks sequentially.
for chunk in chunked_splits {
let r = EmbedderRequest::new(
embedder_config.provider_id.clone(),
&embedder_config.model_id,
chunk.iter().map(|ci| ci.text.as_str()).collect::<Vec<_>>(),
-                self.config.extras.clone(),
+                Some(extras.clone()),
);

let v = match r.execute(credentials.clone()).await {
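The data_source.rs side of the change is deliberately small: upserts copy the configured extras (defaulting to an empty object) and force an enforce_rate_limit_margin flag into them before each embedder request. A sketch of the resulting payload, assuming extras that already carry an openai_user entry (the value is invented):

use serde_json::json;

fn main() {
    let mut extras = json!({ "openai_user": "workspace-1234" });
    // Indexing a serde_json object with a missing key inserts it on assignment.
    extras["enforce_rate_limit_margin"] = json!(true);
    assert_eq!(
        extras,
        json!({
            "openai_user": "workspace-1234",
            "enforce_rate_limit_margin": true
        })
    );
}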
11 changes: 9 additions & 2 deletions core/src/providers/azure_openai.rs
@@ -3,9 +3,9 @@ use crate::providers::embedder::{Embedder, EmbedderVector};
use crate::providers::llm::ChatFunction;
use crate::providers::llm::Tokens;
use crate::providers::llm::{LLMChatGeneration, LLMGeneration, LLMTokenUsage, LLM};
-use crate::providers::openai::completion;
use crate::providers::openai::embed;
use crate::providers::openai::streamed_completion;
+use crate::providers::openai::{completion, REMAINING_TOKENS_MARGIN};
use crate::providers::provider::{Provider, ProviderID};
use crate::providers::tiktoken::tiktoken::{batch_tokenize_async, decode_async, encode_async};
use crate::providers::tiktoken::tiktoken::{
@@ -601,13 +601,20 @@ impl Embedder for AzureOpenAIEmbedder {
None,
Some(self.model_id.clone()),
text,
-            match extras {
+            match &extras {
Some(e) => match e.get("openai_user") {
Some(u) => Some(u.to_string()),
None => None,
},
None => None,
},
+            match &extras {
+                Some(e) => match e.get("enforce_rate_limit_margin") {
+                    Some(Value::Bool(true)) => Some(REMAINING_TOKENS_MARGIN),
+                    _ => None,
+                },
+                None => None,
+            },
)
.await?;

68 changes: 68 additions & 0 deletions core/src/providers/openai.rs
@@ -19,7 +19,9 @@ use futures::TryStreamExt;
use hyper::StatusCode;
use hyper::{body::Buf, Uri};
use itertools::izip;
+use lazy_static::lazy_static;
use parking_lot::{Mutex, RwLock};
+use parse_duration::parse;
use serde::{Deserialize, Serialize};
use serde_json::json;
use serde_json::Value;
@@ -35,6 +37,18 @@ use super::openai_compatible_helpers::{
openai_compatible_chat_completion, OpenAIError, TransformSystemMessages,
};

+pub const REMAINING_TOKENS_MARGIN: u64 = 500_000;
+#[derive(Debug)]
+struct RateLimitDetails {
+    pub remaining_tokens: u64,
+    pub reset_tokens: u64, // Unix timestamp in milliseconds when the rate limit resets
+}
+
+lazy_static! {
+    // Map of API key to rate limit details
+    static ref RATE_LIMITS: Mutex<HashMap<String, RateLimitDetails>> = Mutex::new(HashMap::new());
+}

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Usage {
pub prompt_tokens: u64,
@@ -523,7 +537,31 @@ pub async fn embed(
model_id: Option<String>,
text: Vec<&str>,
user: Option<String>,
+    min_remaining_tokens: Option<u64>,
) -> Result<Embeddings> {
+    if let Some(min_remaining_tokens) = min_remaining_tokens {
+        let mut rate_limits = RATE_LIMITS.lock();
+
+        // Clean up expired rate limits
+        let now = utils::now();
+        rate_limits.retain(|_, details| details.reset_tokens > now);
+
+        // Check rate limit for this API key
+        if let Some(details) = rate_limits.get(&api_key) {
+            if details.reset_tokens > now && details.remaining_tokens < min_remaining_tokens {
+                Err(ModelError {
+                    request_id: None,
+                    message: "Rate limit exceeded".to_string(),
+                    retryable: Some(ModelErrorRetryOptions {
+                        sleep: Duration::from_millis(details.reset_tokens - now),
+                        factor: 2,
+                        retries: 3,
+                    }),
+                })?;
+            }
+        }
+    }

let mut body = json!({
"input": text,
});
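Taken on its own, the pre-flight check above says: if the cached window for this API key has not reset yet and its remaining token budget sits below the requested margin, fail fast with an error that is retryable once the window resets. A self-contained sketch of that predicate, with the repo's types stubbed out (should_back_off and its shape are illustrative, not the repo's API):

use std::collections::HashMap;
use std::time::Duration;

struct RateLimitDetails {
    remaining_tokens: u64,
    reset_tokens: u64, // unix epoch millis when the provider window resets
}

// Returns Some(sleep) when the caller should back off, None when it may proceed.
fn should_back_off(
    limits: &HashMap<String, RateLimitDetails>,
    api_key: &str,
    min_remaining_tokens: u64,
    now_millis: u64,
) -> Option<Duration> {
    limits.get(api_key).and_then(|d| {
        if d.reset_tokens > now_millis && d.remaining_tokens < min_remaining_tokens {
            Some(Duration::from_millis(d.reset_tokens - now_millis))
        } else {
            None
        }
    })
}

fn main() {
    let mut limits = HashMap::new();
    limits.insert(
        "sk-test".to_string(),
        RateLimitDetails { remaining_tokens: 100_000, reset_tokens: 2_000 },
    );
    // 100k remaining < 500k margin and the window is still open: back off ~1.5s.
    assert_eq!(
        should_back_off(&limits, "sk-test", 500_000, 500),
        Some(Duration::from_millis(1_500))
    );
}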
@@ -569,6 +607,29 @@
None => None,
};

+    let remaining_tokens = match res_headers.get("x-ratelimit-remaining-tokens") {
+        Some(remaining_tokens) => remaining_tokens.to_str()?.to_string().parse::<u64>().ok(),
+        None => None,
+    };
+
+    let reset_tokens = match res_headers.get("x-ratelimit-reset-tokens") {
+        Some(reset_tokens) => parse(reset_tokens.to_str()?).ok().map(|d| d.as_millis()),
+        None => None,
+    };
+    match (remaining_tokens, reset_tokens) {
+        (Some(remaining_tokens), Some(reset_tokens)) => {
+            let now = utils::now();
+            RATE_LIMITS.lock().insert(
+                api_key.clone(),
+                RateLimitDetails {
+                    remaining_tokens,
+                    reset_tokens: now + reset_tokens as u64,
+                },
+            );
+        }
+        _ => (),
+    }
+
let body = match timeout(Duration::new(60, 0), res.bytes()).await {
Ok(Ok(body)) => body,
Ok(Err(e)) => Err(e)?,
@@ -1099,6 +1160,13 @@ impl Embedder for OpenAIEmbedder {
},
None => None,
},
+            match &extras {
+                Some(e) => match e.get("enforce_rate_limit_margin") {
+                    Some(Value::Bool(true)) => Some(REMAINING_TOKENS_MARGIN),
+                    _ => None,
+                },
+                None => None,
+            },
)
.await?;

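The second half of the loop: every embeddings response refreshes the per-key cache from the two x-ratelimit-* headers, so the next upsert's pre-flight check sees current numbers. A sketch of just the header decoding, assuming a reqwest/http HeaderMap and invented header values:

use parse_duration::parse;
use reqwest::header::{HeaderMap, HeaderValue};

// Decodes (remaining_tokens, reset_in_millis) from OpenAI-style headers.
fn read_rate_limit_headers(headers: &HeaderMap) -> (Option<u64>, Option<u64>) {
    let remaining = headers
        .get("x-ratelimit-remaining-tokens")
        .and_then(|v| v.to_str().ok())
        .and_then(|s| s.parse::<u64>().ok());
    let reset_millis = headers
        .get("x-ratelimit-reset-tokens")
        .and_then(|v| v.to_str().ok())
        .and_then(|s| parse(s).ok())
        .map(|d| d.as_millis() as u64);
    (remaining, reset_millis)
}

fn main() {
    let mut headers = HeaderMap::new();
    headers.insert("x-ratelimit-remaining-tokens", HeaderValue::from_static("123456"));
    headers.insert("x-ratelimit-reset-tokens", HeaderValue::from_static("6m12s"));
    assert_eq!(read_rate_limit_headers(&headers), (Some(123_456), Some(372_000)));
}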
