This repository has been archived by the owner on Feb 8, 2024. It is now read-only.

hook it all up
Ellie Huxtable committed Oct 17, 2023
1 parent e158e58 commit 33fb9f0
Showing 9 changed files with 176 additions and 97 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions capture-server/Cargo.toml
@@ -9,3 +9,4 @@ axum = { workspace = true }
tokio = { workspace = true }
tracing-subscriber = { workspace = true }
tracing = { workspace = true }
time = { workspace = true }
28 changes: 25 additions & 3 deletions capture-server/src/main.rs
@@ -1,7 +1,9 @@
use std::env;
use std::net::SocketAddr;
use std::sync::Arc;

use capture::{router, sink, time};
use capture::{billing_limits::BillingLimiter, redis::RedisClusterClient, router, sink};
use time::Duration;
use tokio::signal;

async fn shutdown() {
@@ -23,16 +25,36 @@ async fn shutdown() {
async fn main() {
let use_print_sink = env::var("PRINT_SINK").is_ok();
let address = env::var("ADDRESS").unwrap_or(String::from("127.0.0.1:3000"));
let redis_addr = env::var("REDIS").expect("redis required; please set the REDIS env var to a comma-separated list of addresses ('one,two,three')");

let redis_nodes: Vec<String> = redis_addr.split(',').map(str::to_string).collect();
let redis_client =
Arc::new(RedisClusterClient::new(redis_nodes).expect("failed to create redis client"));

let billing = BillingLimiter::new(Duration::seconds(5), redis_client.clone())
.expect("failed to create billing limiter");

let app = if use_print_sink {
router::router(time::SystemTime {}, sink::PrintSink {}, true)
router::router(
capture::time::SystemTime {},
sink::PrintSink {},
redis_client,
billing,
true,
)
} else {
let brokers = env::var("KAFKA_BROKERS").expect("Expected KAFKA_BROKERS");
let topic = env::var("KAFKA_TOPIC").expect("Expected KAFKA_TOPIC");

let sink = sink::KafkaSink::new(topic, brokers).unwrap();

router::router(time::SystemTime {}, sink, true)
router::router(
capture::time::SystemTime {},
sink,
redis_client,
billing,
true,
)
};

// initialize tracing
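With this change the server requires Redis configuration even when the print sink is selected. As a minimal sketch of the environment the new `main` expects, written in Rust purely for illustration (addresses and topic name are placeholders, not taken from this commit):

```rust
use std::env;

// Illustrative only: the variables capture-server reads at startup.
fn main() {
    env::set_var("ADDRESS", "127.0.0.1:3000");
    // REDIS is now mandatory: a comma-separated list of cluster node addresses.
    env::set_var("REDIS", "redis-1:6379,redis-2:6379,redis-3:6379");
    // Either request the stdout sink...
    env::set_var("PRINT_SINK", "1");
    // ...or leave PRINT_SINK unset and configure Kafka instead.
    // env::set_var("KAFKA_BROKERS", "kafka-1:9092");
    // env::set_var("KAFKA_TOPIC", "events");
}
```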
12 changes: 12 additions & 0 deletions capture/src/api.rs
@@ -52,6 +52,12 @@ pub enum CaptureError {
EventTooBig,
#[error("invalid event could not be processed")]
NonRetryableSinkError,

#[error("billing limit reached")]
BillingLimit,

#[error("rate limited")]
RateLimited,
}

impl IntoResponse for CaptureError {
@@ -64,10 +70,16 @@ impl IntoResponse for CaptureError {
| CaptureError::MissingDistinctId
| CaptureError::EventTooBig
| CaptureError::NonRetryableSinkError => (StatusCode::BAD_REQUEST, self.to_string()),

CaptureError::NoTokenError
| CaptureError::MultipleTokensError
| CaptureError::TokenValidationError(_) => (StatusCode::UNAUTHORIZED, self.to_string()),

CaptureError::RetryableSinkError => (StatusCode::SERVICE_UNAVAILABLE, self.to_string()),

CaptureError::BillingLimit | CaptureError::RateLimited => {
(StatusCode::TOO_MANY_REQUESTS, self.to_string())
}
}
.into_response()
}
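The router and handler changes that actually produce these new variants live in files not expanded on this page. As a rough sketch of the intended interaction (the function name and shape are illustrative, not from this commit), a handler holding the limiter can reject events for over-quota teams and let `IntoResponse` turn that into a 429:

```rust
use capture::api::CaptureError;
use capture::billing_limits::{BillingLimiter, QuotaResource};

// Hypothetical handler fragment: not part of this commit.
async fn check_quota(billing: &BillingLimiter, token: &str) -> Result<(), CaptureError> {
    if billing.is_limited(token, QuotaResource::Events).await {
        // IntoResponse maps BillingLimit to StatusCode::TOO_MANY_REQUESTS.
        return Err(CaptureError::BillingLimit);
    }
    Ok(())
}
```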
173 changes: 88 additions & 85 deletions capture/src/billing_limits.rs
@@ -1,7 +1,7 @@
use std::{collections::HashSet, sync::Arc, ops::Sub};
use std::{collections::HashSet, ops::Sub, sync::Arc};

use crate::redis::RedisClient;

use crate::redis::{RedisClient, RedisClusterClient, MockRedisClient};
use async_trait::async_trait;
/// Limit accounts by team ID if they hit a billing limit
///
/// We have an async celery worker that regularly checks on accounts + assesses if they are beyond
@@ -16,30 +16,43 @@ use async_trait::async_trait;
/// 2. Capture should cope with redis being _totally down_, and fail open
/// 3. We should not hit redis for every single request
///
/// The solution here is
///
/// 1. A background task to regularly pull in the latest set from redis
/// 2. A cached store of the most recently known set, so we don't need to keep hitting redis
/// The solution here is to read from the cache until a time interval is hit, and then fetch new
/// data. The write requires taking a lock that stalls all readers, though so long as redis reads
/// stay fast we're ok.
///
/// Some small delay between an account being limited and the limit taking effect is acceptable.
/// However, ideally we should not allow requests from some pods but 429 from others.
use thiserror::Error;
use time::{Duration, OffsetDateTime, UtcOffset};
use tokio::sync::{Mutex, RwLock};
use tracing::Level;
use time::{Duration, OffsetDateTime};
use tokio::sync::RwLock;

// todo: fetch from env
const UPDATE_INTERVAL_SECS: u64 = 5;
const QUOTA_LIMITER_CACHE_KEY: &'static str = "@posthog/quota-limits/";
const QUOTA_LIMITER_CACHE_KEY: &str = "@posthog/quota-limits/";

pub enum QuotaResource {
Events,
Recordings,
}

impl QuotaResource {
fn as_str(&self) -> &'static str {
match self {
Self::Events => "events",
Self::Recordings => "recordings",
}
}
}

#[derive(Error, Debug)]
pub enum LimiterError {
#[error("updater already running - there can only be one")]
UpdaterRunning,
}

struct BillingLimiter {
#[derive(Clone)]
pub struct BillingLimiter {
limited: Arc<RwLock<HashSet<String>>>,
redis: Arc<dyn RedisClient + Send + Sync>,
interval: Duration,
updated: Arc<RwLock<time::OffsetDateTime>>,
}
@@ -54,130 +67,120 @@ impl BillingLimiter {
///
/// Pass an empty redis node list to only use this initial set.
pub fn new(
limited: Option<HashSet<String>>,
interval: Option<Duration >,
interval: Duration,
redis: Arc<dyn RedisClient + Send + Sync>,
) -> anyhow::Result<BillingLimiter> {
let limited = limited.unwrap_or_else(|| HashSet::new());
let limited = Arc::new(RwLock::new(limited));
let limited = Arc::new(RwLock::new(HashSet::new()));

// Force an update immediately if we have any reasonable interval
let updated = OffsetDateTime::from_unix_timestamp(0)?;
let updated = Arc::new(RwLock::new(updated));

// Default to an interval that's so long, we will never update. If this code is still
// running in 99yrs that's pretty cool.
let interval = interval.unwrap_or_else(||Duration::weeks(99 * 52));

Ok(BillingLimiter {
interval,
limited,
updated,
redis,
})
}

async fn fetch_limited(client: &impl RedisClient) -> anyhow::Result<Vec<String>> {
async fn fetch_limited(
client: &Arc<dyn RedisClient + Send + Sync>,
resource: QuotaResource,
) -> anyhow::Result<Vec<String>> {
let now = time::OffsetDateTime::now_utc().unix_timestamp();

// todo: timeout on external calls
client
.zrangebyscore(
format!("{QUOTA_LIMITER_CACHE_KEY}events"),
format!("{QUOTA_LIMITER_CACHE_KEY}{}", resource.as_str()),
now.to_string(),
String::from("+Inf"),
)
.await
}

pub async fn is_limited(&self, key: &str, client: &impl RedisClient) -> bool {
pub async fn is_limited(&self, key: &str, resource: QuotaResource) -> bool {
// hold the read lock to clone it, very briefly. clone is ok because it's very small 🤏
// rwlock can have many readers, but one writer. the writer will wait in a queue with all
// the readers, so we want to hold read locks for the smallest time possible to avoid
// writers waiting for too long. and vice versa.
let updated = {
let updated = self.updated.read().await;
updated.clone()
*updated
};

let now = OffsetDateTime::now_utc();
let since_update = now.sub(updated);

// If an update is due, fetch the set from redis + cache it until the next update is due.
// Otherwise, return a value from the cache
//
// This update will block readers! Keep it fast.
if since_update > self.interval {
let set = Self::fetch_limited(client).await;
let set = HashSet::from_iter(set.unwrap().iter().cloned());

let mut limited = self.limited.write().await;
*limited = set;

return limited.contains(key);
}

let l = self.limited.read().await;

l.contains(key)
let set = Self::fetch_limited(&self.redis, resource).await;

// Update regardless of success here. It does mean we will keep trying to hit redis on
// our interval, but that's probably OK for now.
{
let mut updated = self.updated.write().await;
*updated = now;
}

if let Ok(set) = set {
let set = HashSet::from_iter(set.iter().cloned());

let mut limited = self.limited.write().await;
*limited = set;

limited.contains(key)
} else {
// If we fail to fetch the set, something really wrong is happening. To avoid
// dropping events that we don't mean to drop, fail open and accept data. Better
// than angry customers :)
//
// TODO: Consider backing off our redis checks
false
}
} else {
let l = self.limited.read().await;

l.contains(key)
}
}
}

#[cfg(test)]
mod tests {
use std::sync::Arc;
use time::Duration;

use crate::{billing_limits::BillingLimiter, redis::{MockRedisClient, RedisClient}};
use crate::{
billing_limits::{BillingLimiter, QuotaResource},
redis::MockRedisClient,
};

// Test that a token _not_ limited has no restriction applied
// Avoid messing up and accidentally limiting everyone
#[tokio::test]
async fn test_not_limited() {
let client = MockRedisClient::new();
let limiter = BillingLimiter::new(None, None).expect("Failed to create billing limiter");

assert_eq!(
limiter.is_limited("idk it doesn't matter", &client).await,
false
);
}
async fn test_dynamic_limited() {
let client = MockRedisClient::new().zrangebyscore_ret(vec![String::from("banana")]);
let client = Arc::new(client);

// Test that a token _not_ limited has no restriction applied
// Avoid messing up and accidentally limiting everyone
#[tokio::test]
async fn test_fixed_limited() {
let client = MockRedisClient::new();

let limiter = BillingLimiter::new(
Some(
vec![String::from("some_org_hit_limits")]
.into_iter()
.collect(),
),
None,
)
.expect("Failed to create billing limiter");
let limiter = BillingLimiter::new(Duration::microseconds(1), client)
.expect("Failed to create billing limiter");

assert_eq!(
limiter.is_limited("idk it doesn't matter", &client).await,
limiter
.is_limited("idk it doesn't matter", QuotaResource::Events)
.await,
false
);
assert!(limiter.is_limited("some_org_hit_limits", &client).await);
}

#[tokio::test]
async fn test_dynamic_limited() {
let client = MockRedisClient::new().zrangebyscore_ret(vec![String::from("banana")]);

let limiter = BillingLimiter::new(
Some(
vec![String::from("some_org_hit_limits")]
.into_iter()
.collect(),
),
Some(Duration::microseconds(1)),
)
.expect("Failed to create billing limiter");

assert_eq!(
limiter.is_limited("idk it doesn't matter", &client).await,
limiter
.is_limited("some_org_hit_limits", QuotaResource::Events)
.await,
false
);

assert_eq!(limiter.is_limited("some_org_hit_limits", &client).await, false);
assert!(limiter.is_limited("banana", &client).await);
assert!(limiter.is_limited("banana", QuotaResource::Events).await);
}
}
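For context on what the limiter reads: `fetch_limited` issues a `zrangebyscore` over a per-resource sorted set (`@posthog/quota-limits/events` or `.../recordings`), so the upstream worker is expected to store limited tokens as members scored by when their limit expires. A small extra test, not part of this commit, sketching the `Recordings` path inside the same test module (assuming the mock returns its canned set regardless of the key queried):

```rust
#[tokio::test]
async fn test_dynamic_limited_recordings() {
    // The mock returns "banana" for any zrangebyscore call, so the token
    // should also appear limited for the recordings resource.
    let client = MockRedisClient::new().zrangebyscore_ret(vec![String::from("banana")]);
    let client = Arc::new(client);

    let limiter = BillingLimiter::new(Duration::microseconds(1), client)
        .expect("Failed to create billing limiter");

    assert!(limiter.is_limited("banana", QuotaResource::Recordings).await);
    assert!(
        !limiter
            .is_limited("some_other_token", QuotaResource::Recordings)
            .await
    );
}
```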
