From 33fb9f0f88309e6756c25a2998c4702d6099f896 Mon Sep 17 00:00:00 2001
From: Ellie Huxtable
Date: Tue, 17 Oct 2023 21:48:55 +0100
Subject: [PATCH] hook it all up

---
 Cargo.lock                     |   1 +
 capture-server/Cargo.toml      |   1 +
 capture-server/src/main.rs     |  28 +++++-
 capture/src/api.rs             |  12 +++
 capture/src/billing_limits.rs  | 173 +++++++++++++++++----------------
 capture/src/capture.rs         |  24 ++++-
 capture/src/redis.rs           |  14 ++-
 capture/src/router.rs          |   9 +-
 capture/tests/django_compat.rs |  11 ++-
 9 files changed, 176 insertions(+), 97 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index f6c95f4..7ea2f9d 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -233,6 +233,7 @@ version = "0.1.0"
 dependencies = [
  "axum",
  "capture",
+ "time",
  "tokio",
  "tracing",
  "tracing-subscriber",
diff --git a/capture-server/Cargo.toml b/capture-server/Cargo.toml
index 04c6182..6378532 100644
--- a/capture-server/Cargo.toml
+++ b/capture-server/Cargo.toml
@@ -9,3 +9,4 @@ axum = { workspace = true }
 tokio = { workspace = true }
 tracing-subscriber = { workspace = true }
 tracing = { workspace = true }
+time = { workspace = true }
diff --git a/capture-server/src/main.rs b/capture-server/src/main.rs
index e2232d5..24036a8 100644
--- a/capture-server/src/main.rs
+++ b/capture-server/src/main.rs
@@ -1,7 +1,9 @@
 use std::env;
 use std::net::SocketAddr;
+use std::sync::Arc;
 
-use capture::{router, sink, time};
+use capture::{billing_limits::BillingLimiter, redis::RedisClusterClient, router, sink};
+use time::Duration;
 use tokio::signal;
 
 async fn shutdown() {
@@ -23,16 +25,36 @@ async fn shutdown() {
 async fn main() {
     let use_print_sink = env::var("PRINT_SINK").is_ok();
     let address = env::var("ADDRESS").unwrap_or(String::from("127.0.0.1:3000"));
+    let redis_addr = env::var("REDIS").expect("redis required; please set the REDIS env var to a comma-separated list of addresses ('one,two,three')");
+
+    let redis_nodes: Vec<String> = redis_addr.split(',').map(str::to_string).collect();
+    let redis_client =
+        Arc::new(RedisClusterClient::new(redis_nodes).expect("failed to create redis client"));
+
+    let billing = BillingLimiter::new(Duration::seconds(5), redis_client.clone())
+        .expect("failed to create billing limiter");
 
     let app = if use_print_sink {
-        router::router(time::SystemTime {}, sink::PrintSink {}, true)
+        router::router(
+            capture::time::SystemTime {},
+            sink::PrintSink {},
+            redis_client,
+            billing,
+            true,
+        )
     } else {
         let brokers = env::var("KAFKA_BROKERS").expect("Expected KAFKA_BROKERS");
         let topic = env::var("KAFKA_TOPIC").expect("Expected KAFKA_TOPIC");
 
         let sink = sink::KafkaSink::new(topic, brokers).unwrap();
 
-        router::router(time::SystemTime {}, sink, true)
+        router::router(
+            capture::time::SystemTime {},
+            sink,
+            redis_client,
+            billing,
+            true,
+        )
     };
 
     // initialize tracing
diff --git a/capture/src/api.rs b/capture/src/api.rs
index 319056c..ff245b5 100644
--- a/capture/src/api.rs
+++ b/capture/src/api.rs
@@ -52,6 +52,12 @@ pub enum CaptureError {
     EventTooBig,
     #[error("invalid event could not be processed")]
     NonRetryableSinkError,
+
+    #[error("billing limit reached")]
+    BillingLimit,
+
+    #[error("rate limited")]
+    RateLimited,
 }
 
 impl IntoResponse for CaptureError {
@@ -64,10 +70,16 @@ impl IntoResponse for CaptureError {
             | CaptureError::MissingDistinctId
             | CaptureError::EventTooBig
             | CaptureError::NonRetryableSinkError => (StatusCode::BAD_REQUEST, self.to_string()),
+
             CaptureError::NoTokenError
             | CaptureError::MultipleTokensError
             | CaptureError::TokenValidationError(_) => (StatusCode::UNAUTHORIZED, self.to_string()),
+
             CaptureError::RetryableSinkError => (StatusCode::SERVICE_UNAVAILABLE, self.to_string()),
+
+            CaptureError::BillingLimit | CaptureError::RateLimited => {
+                (StatusCode::TOO_MANY_REQUESTS, self.to_string())
+            }
         }
         .into_response()
     }
diff --git a/capture/src/billing_limits.rs b/capture/src/billing_limits.rs
index 73ff07b..5d873e0 100644
--- a/capture/src/billing_limits.rs
+++ b/capture/src/billing_limits.rs
@@ -1,7 +1,7 @@
-use std::{collections::HashSet, sync::Arc, ops::Sub};
+use std::{collections::HashSet, ops::Sub, sync::Arc};
+
+use crate::redis::RedisClient;
 
-use crate::redis::{RedisClient, RedisClusterClient, MockRedisClient};
-use async_trait::async_trait;
 /// Limit accounts by team ID if they hit a billing limit
 ///
 /// We have an async celery worker that regularly checks on accounts + assesses if they are beyond
@@ -16,21 +16,32 @@ use async_trait::async_trait;
 /// 2. Capture should cope with redis being _totally down_, and fail open
 /// 3. We should not hit redis for every single request
 ///
-/// The solution here is
-///
-/// 1. A background task to regularly pull in the latest set from redis
-/// 2. A cached store of the most recently known set, so we don't need to keep hitting redis
+/// The solution here is to read from the cache until a time interval is hit, and then fetch new
+/// data. The write requires taking a lock that stalls all readers, though so long as redis reads
+/// stay fast we're ok.
 ///
 /// Some small delay between an account being limited and the limit taking effect is acceptable.
 /// However, ideally we should not allow requests from some pods but 429 from others.
 use thiserror::Error;
-use time::{Duration, OffsetDateTime, UtcOffset};
-use tokio::sync::{Mutex, RwLock};
-use tracing::Level;
+use time::{Duration, OffsetDateTime};
+use tokio::sync::RwLock;
 
 // todo: fetch from env
-const UPDATE_INTERVAL_SECS: u64 = 5;
-const QUOTA_LIMITER_CACHE_KEY: &'static str = "@posthog/quota-limits/";
+const QUOTA_LIMITER_CACHE_KEY: &str = "@posthog/quota-limits/";
+
+pub enum QuotaResource {
+    Events,
+    Recordings,
+}
+
+impl QuotaResource {
+    fn as_str(&self) -> &'static str {
+        match self {
+            Self::Events => "events",
+            Self::Recordings => "recordings",
+        }
+    }
+}
 
 #[derive(Error, Debug)]
 pub enum LimiterError {
@@ -38,8 +49,10 @@ pub enum LimiterError {
     UpdaterRunning,
 }
 
-struct BillingLimiter {
+#[derive(Clone)]
+pub struct BillingLimiter {
     limited: Arc<RwLock<HashSet<String>>>,
+    redis: Arc<dyn RedisClient + Send + Sync>,
     interval: Duration,
     updated: Arc<RwLock<OffsetDateTime>>,
 }
@@ -54,130 +67,120 @@ impl BillingLimiter {
     ///
     /// Pass an empty redis node list to only use this initial set.
     pub fn new(
-        limited: Option<HashSet<String>>,
-        interval: Option<Duration>,
+        interval: Duration,
+        redis: Arc<dyn RedisClient + Send + Sync>,
     ) -> anyhow::Result<BillingLimiter> {
-        let limited = limited.unwrap_or_else(|| HashSet::new());
-        let limited = Arc::new(RwLock::new(limited));
+        let limited = Arc::new(RwLock::new(HashSet::new()));
 
         // Force an update immediately if we have any reasonable interval
        let updated = OffsetDateTime::from_unix_timestamp(0)?;
         let updated = Arc::new(RwLock::new(updated));
 
-        // Default to an interval that's so long, we will never update. If this code is still
-        // running in 99yrs that's pretty cool.
-        let interval = interval.unwrap_or_else(||Duration::weeks(99 * 52));
-
         Ok(BillingLimiter {
             interval,
             limited,
             updated,
+            redis,
         })
     }
 
-    async fn fetch_limited(client: &impl RedisClient) -> anyhow::Result<Vec<String>> {
+    async fn fetch_limited(
+        client: &Arc<dyn RedisClient + Send + Sync>,
+        resource: QuotaResource,
+    ) -> anyhow::Result<Vec<String>> {
         let now = time::OffsetDateTime::now_utc().unix_timestamp();
+        // todo: timeout on external calls
         client
             .zrangebyscore(
-                format!("{QUOTA_LIMITER_CACHE_KEY}events"),
+                format!("{QUOTA_LIMITER_CACHE_KEY}{}", resource.as_str()),
                 now.to_string(),
                 String::from("+Inf"),
             )
             .await
     }
 
-    pub async fn is_limited(&self, key: &str, client: &impl RedisClient) -> bool {
+    pub async fn is_limited(&self, key: &str, resource: QuotaResource) -> bool {
         // hold the read lock to clone it, very briefly. clone is ok because it's very small 🤏
         // rwlock can have many readers, but one writer. the writer will wait in a queue with all
         // the readers, so we want to hold read locks for the smallest time possible to avoid
         // writers waiting for too long. and vice versa.
         let updated = {
             let updated = self.updated.read().await;
-            updated.clone()
+            *updated
         };
 
         let now = OffsetDateTime::now_utc();
         let since_update = now.sub(updated);
 
+        // If an update is due, fetch the set from redis + cache it until the next update is due.
+        // Otherwise, return a value from the cache
+        //
+        // This update will block readers! Keep it fast.
         if since_update > self.interval {
-            let set = Self::fetch_limited(client).await;
-            let set = HashSet::from_iter(set.unwrap().iter().cloned());
-
-            let mut limited = self.limited.write().await;
-            *limited = set;
-
-            return limited.contains(key);
-        }
-
-        let l = self.limited.read().await;
-
-        l.contains(key)
+            let set = Self::fetch_limited(&self.redis, resource).await;
+
+            // Update regardless of success here. It does mean we will keep trying to hit redis on
+            // our interval, but that's probably OK for now.
+            {
+                let mut updated = self.updated.write().await;
+                *updated = now;
+            }
+
+            if let Ok(set) = set {
+                let set = HashSet::from_iter(set.iter().cloned());
+
+                let mut limited = self.limited.write().await;
+                *limited = set;
+
+                limited.contains(key)
+            } else {
+                // If we fail to fetch the set, something really wrong is happening. To avoid
+                // dropping events that we don't mean to drop, fail open and accept data. Better
+                // than angry customers :)
+                //
+                // TODO: Consider backing off our redis checks
+                false
+            }
+        } else {
+            let l = self.limited.read().await;
+
+            l.contains(key)
+        }
     }
 }
 
 #[cfg(test)]
 mod tests {
+    use std::sync::Arc;
     use time::Duration;
 
-    use crate::{billing_limits::BillingLimiter, redis::{MockRedisClient, RedisClient}};
+    use crate::{
+        billing_limits::{BillingLimiter, QuotaResource},
+        redis::MockRedisClient,
+    };
 
-    // Test that a token _not_ limited has no restriction applied
-    // Avoid messing up and accidentally limiting everyone
     #[tokio::test]
-    async fn test_not_limited() {
-        let client = MockRedisClient::new();
-        let limiter = BillingLimiter::new(None, None).expect("Failed to create billing limiter");
-
-        assert_eq!(
-            limiter.is_limited("idk it doesn't matter", &client).await,
-            false
-        );
-    }
+    async fn test_dynamic_limited() {
+        let client = MockRedisClient::new().zrangebyscore_ret(vec![String::from("banana")]);
+        let client = Arc::new(client);
 
-    // Test that a token _not_ limited has no restriction applied
-    // Avoid messing up and accidentally limiting everyone
-    #[tokio::test]
-    async fn test_fixed_limited() {
-        let client = MockRedisClient::new();
-
-        let limiter = BillingLimiter::new(
-            Some(
-                vec![String::from("some_org_hit_limits")]
-                    .into_iter()
-                    .collect(),
-            ),
-            None,
-        )
-        .expect("Failed to create billing limiter");
+        let limiter = BillingLimiter::new(Duration::microseconds(1), client)
+            .expect("Failed to create billing limiter");
 
         assert_eq!(
-            limiter.is_limited("idk it doesn't matter", &client).await,
+            limiter
+                .is_limited("idk it doesn't matter", QuotaResource::Events)
+                .await,
             false
         );
-        assert!(limiter.is_limited("some_org_hit_limits", &client).await);
-    }
-
-    #[tokio::test]
-    async fn test_dynamic_limited() {
-        let client = MockRedisClient::new().zrangebyscore_ret(vec![String::from("banana")]);
-
-        let limiter = BillingLimiter::new(
-            Some(
-                vec![String::from("some_org_hit_limits")]
-                    .into_iter()
-                    .collect(),
-            ),
-            Some(Duration::microseconds(1)),
-        )
-        .expect("Failed to create billing limiter");
 
         assert_eq!(
-            limiter.is_limited("idk it doesn't matter", &client).await,
+            limiter
+                .is_limited("some_org_hit_limits", QuotaResource::Events)
+                .await,
             false
         );
-
-        assert_eq!(limiter.is_limited("some_org_hit_limits", &client).await, false);
-        assert!(limiter.is_limited("banana", &client).await);
+        assert!(limiter.is_limited("banana", QuotaResource::Events).await);
     }
 }
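
How the pieces introduced above compose with the handler changes below can be summed up in a small illustrative sketch; this is not part of the patch: it borrows the MockRedisClient test helper in place of the production RedisClusterClient, assumes a tokio runtime, and the token string is a placeholder.

    use std::sync::Arc;

    use capture::billing_limits::{BillingLimiter, QuotaResource};
    use capture::redis::MockRedisClient;
    use time::Duration;

    #[tokio::main]
    async fn main() {
        // Stand-in for the production Arc<RedisClusterClient>; it reports one token as limited.
        let redis =
            Arc::new(MockRedisClient::new().zrangebyscore_ret(vec!["limited_token".to_string()]));

        // Refresh the cached limit set from redis at most once per interval; other calls read the cache.
        let billing = BillingLimiter::new(Duration::seconds(5), redis)
            .expect("failed to create billing limiter");

        if billing.is_limited("limited_token", QuotaResource::Events).await {
            // the capture endpoint drops the event here but still replies 200 OK (v0 behaviour)
        }
    }
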
diff --git a/capture/src/capture.rs b/capture/src/capture.rs
index 98a61d3..65e64c9 100644
--- a/capture/src/capture.rs
+++ b/capture/src/capture.rs
@@ -12,6 +12,7 @@ use axum_client_ip::InsecureClientIp;
 use base64::Engine;
 use time::OffsetDateTime;
 
+use crate::billing_limits::QuotaResource;
 use crate::event::ProcessingContext;
 use crate::token::validate_token;
 use crate::{
@@ -44,7 +45,7 @@ pub async fn event(
         _ => RawEvent::from_bytes(&meta, body),
     }?;
 
-    println!("Got events {:?}", &events);
+    tracing::debug!("got events {:?}", &events);
     if events.is_empty() {
         return Err(CaptureError::EmptyBatch);
     }
@@ -61,6 +62,7 @@ pub async fn event(
         }
         None
     });
+
     let context = ProcessingContext {
         lib_version: meta.lib_version.clone(),
         sent_at,
@@ -69,7 +71,25 @@ pub async fn event(
         client_ip: ip.to_string(),
     };
 
-    println!("Got context {:?}", &context);
+    let limited = state
+        .billing
+        .is_limited(context.token.as_str(), QuotaResource::Events)
+        .await;
+
+    if limited {
+        // for v0 we want to just return ok 🙃
+        // this is because the clients are pretty dumb and will just retry over and over and
+        // over...
+        //
+        // for v1, we'll return a meaningful error code and error, so that the clients can do
+        // something meaningful with that error
+
+        return Ok(Json(CaptureResponse {
+            status: CaptureResponseCode::Ok,
+        }));
+    }
+
+    tracing::debug!("got context {:?}", &context);
 
     process_events(state.sink.clone(), &events, &context).await?;
 
diff --git a/capture/src/redis.rs b/capture/src/redis.rs
index a379220..1feb01a 100644
--- a/capture/src/redis.rs
+++ b/capture/src/redis.rs
@@ -41,12 +41,12 @@ impl RedisClient for RedisClusterClient {
 // mockall got really annoying with async and results so I'm just gonna do my own
 #[derive(Clone)]
 pub struct MockRedisClient {
-    zrangebyscore_ret: Vec<String>
+    zrangebyscore_ret: Vec<String>,
 }
 
 impl MockRedisClient {
-    pub fn new() -> MockRedisClient{
-        MockRedisClient{
+    pub fn new() -> MockRedisClient {
+        MockRedisClient {
             zrangebyscore_ret: Vec::new(),
         }
     }
@@ -58,10 +58,16 @@ impl MockRedisClient {
     }
 }
 
+impl Default for MockRedisClient {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
 #[async_trait]
 impl RedisClient for MockRedisClient {
     // A very simplified wrapper, but works for our usage
-    async fn zrangebyscore(&self, k: String, min: String, max: String) -> Result<Vec<String>>{
+    async fn zrangebyscore(&self, _k: String, _min: String, _max: String) -> Result<Vec<String>> {
         Ok(self.zrangebyscore_ret.clone())
     }
 }
diff --git a/capture/src/router.rs b/capture/src/router.rs
index 0c40658..2a0bf24 100644
--- a/capture/src/router.rs
+++ b/capture/src/router.rs
@@ -7,7 +7,7 @@ use axum::{
 };
 use tower_http::trace::TraceLayer;
 
-use crate::{capture, sink, time::TimeSource};
+use crate::{billing_limits::BillingLimiter, capture, redis::RedisClient, sink, time::TimeSource};
 
 use crate::prometheus::{setup_metrics_recorder, track_metrics};
 
@@ -15,6 +15,8 @@ use crate::prometheus::{setup_metrics_recorder, track_metrics};
 pub struct State {
     pub sink: Arc<dyn sink::EventSink + Send + Sync>,
     pub timesource: Arc<dyn TimeSource + Send + Sync>,
+    pub redis: Arc<dyn RedisClient + Send + Sync>,
+    pub billing: BillingLimiter,
 }
 
 async fn index() -> &'static str {
@@ -24,14 +26,19 @@ async fn index() -> &'static str {
 pub fn router<
     TZ: TimeSource + Send + Sync + 'static,
     S: sink::EventSink + Send + Sync + 'static,
+    R: RedisClient + Send + Sync + 'static,
 >(
     timesource: TZ,
     sink: S,
+    redis: Arc<R>,
+    billing: BillingLimiter,
     metrics: bool,
 ) -> Router {
     let state = State {
         sink: Arc::new(sink),
         timesource: Arc::new(timesource),
+        redis,
+        billing,
     };
 
     let router = Router::new()
diff --git a/capture/tests/django_compat.rs b/capture/tests/django_compat.rs
index 119777d..d418996 100644
--- a/capture/tests/django_compat.rs
+++ b/capture/tests/django_compat.rs
@@ -5,7 +5,9 @@ use axum_test_helper::TestClient;
 use base64::engine::general_purpose;
 use base64::Engine;
 use capture::api::{CaptureError, CaptureResponse, CaptureResponseCode};
+use capture::billing_limits::BillingLimiter;
 use capture::event::ProcessedEvent;
+use capture::redis::MockRedisClient;
 use capture::router::router;
 use capture::sink::EventSink;
 use capture::time::TimeSource;
@@ -15,7 +17,7 @@ use std::fs::File;
 use std::io::{BufRead, BufReader};
 use std::sync::{Arc, Mutex};
 use time::format_description::well_known::{Iso8601, Rfc3339};
-use time::OffsetDateTime;
+use time::{Duration, OffsetDateTime};
 
 #[derive(Debug, Deserialize)]
 struct RequestDump {
@@ -93,7 +95,12 @@ async fn it_matches_django_capture_behaviour() -> anyhow::Result<()> {
 
         let sink = MemorySink::default();
         let timesource = FixedTime { time: case.now };
-        let app = router(timesource, sink.clone(), false);
+
+        let redis = Arc::new(MockRedisClient::new());
+        let billing = BillingLimiter::new(Duration::weeks(1), redis.clone())
+            .expect("failed to create billing limiter");
+
+        let app = router(timesource, sink.clone(), redis, billing, false);
         let client = TestClient::new(app);
 
         let mut req = client.post(&format!("/i/v0{}", case.path)).body(raw_body);