This repository has been archived by the owner on Feb 8, 2024. It is now read-only.

Commit f6e3d9c: Add non-cluster client
Ellie Huxtable committed Oct 19, 2023
1 parent 4ae29b9
Showing 5 changed files with 227 additions and 14 deletions.
capture-server/src/main.rs (3 additions, 4 deletions)
@@ -2,7 +2,7 @@ use std::env;
use std::net::SocketAddr;
use std::sync::Arc;

-use capture::{billing_limits::BillingLimiter, redis::RedisClusterClient, router, sink};
+use capture::{billing_limits::BillingLimiter, redis::RedisClient, router, sink};
use time::Duration;
use tokio::signal;

@@ -25,11 +25,10 @@ async fn shutdown() {
async fn main() {
let use_print_sink = env::var("PRINT_SINK").is_ok();
let address = env::var("ADDRESS").unwrap_or(String::from("127.0.0.1:3000"));
-let redis_addr = env::var("REDIS").expect("redis required; please set the REDIS env var to a comma-separated list of addresses ('one,two,three')");
+let redis_addr = env::var("REDIS").expect("redis required; please set the REDIS env var");

-let redis_nodes: Vec<String> = redis_addr.split(',').map(str::to_string).collect();
let redis_client =
-Arc::new(RedisClusterClient::new(redis_nodes).expect("failed to create redis client"));
+Arc::new(RedisClient::new(redis_addr).expect("failed to create redis client"));

let billing = BillingLimiter::new(Duration::seconds(5), redis_client.clone())
.expect("failed to create billing limiter");
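Note on the change above: REDIS now takes a single address instead of a comma-separated node list, and main() builds the new non-cluster RedisClient directly. If both flavours ever needed to coexist behind one entry point, a selection helper could look roughly like the sketch below; build_redis_client is hypothetical and not part of this commit, which simply hard-codes the single-node client.

use std::sync::Arc;

use capture::redis::{Client, RedisClient, RedisClusterClient};

// Hypothetical helper: pick an implementation based on the REDIS value.
fn build_redis_client(redis_addr: String) -> Arc<dyn Client + Send + Sync> {
    if redis_addr.contains(',') {
        // Comma-separated list: fall back to the cluster client, as before this commit.
        let nodes: Vec<String> = redis_addr.split(',').map(str::to_string).collect();
        Arc::new(RedisClusterClient::new(nodes).expect("failed to create redis cluster client"))
    } else {
        // Single address: the non-cluster client added in capture/src/redis.rs below.
        Arc::new(RedisClient::new(redis_addr).expect("failed to create redis client"))
    }
}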
capture/src/billing_limits.rs (4 additions, 4 deletions)
@@ -1,6 +1,6 @@
use std::{collections::HashSet, ops::Sub, sync::Arc};

-use crate::redis::RedisClient;
+use crate::redis::Client;

/// Limit accounts by team ID if they hit a billing limit
///
@@ -49,7 +49,7 @@ pub enum LimiterError {
#[derive(Clone)]
pub struct BillingLimiter {
limited: Arc<RwLock<HashSet<String>>>,
-redis: Arc<dyn RedisClient + Send + Sync>,
+redis: Arc<dyn Client + Send + Sync>,
interval: Duration,
updated: Arc<RwLock<time::OffsetDateTime>>,
}
@@ -65,7 +65,7 @@ impl BillingLimiter {
/// Pass an empty redis node list to only use this initial set.
pub fn new(
interval: Duration,
-redis: Arc<dyn RedisClient + Send + Sync>,
+redis: Arc<dyn Client + Send + Sync>,
) -> anyhow::Result<BillingLimiter> {
let limited = Arc::new(RwLock::new(HashSet::new()));

@@ -82,7 +82,7 @@ }
}

async fn fetch_limited(
-client: &Arc<dyn RedisClient + Send + Sync>,
+client: &Arc<dyn Client + Send + Sync>,
resource: QuotaResource,
) -> anyhow::Result<Vec<String>> {
let now = time::OffsetDateTime::now_utc().unix_timestamp();
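The limiter now depends on Arc<dyn Client + Send + Sync> rather than a concrete Redis client, so any implementation of the trait can back it. A minimal construction sketch, mirroring what the new integration test below does with the mock (the five-second interval is the value main.rs uses; nothing here needs a running Redis):

use std::sync::Arc;

use capture::billing_limits::BillingLimiter;
use capture::redis::MockRedisClient;
use time::Duration;

fn limiter_with_mock() -> BillingLimiter {
    // Arc<MockRedisClient> coerces to Arc<dyn Client + Send + Sync> at the call.
    let redis = Arc::new(MockRedisClient::new());
    BillingLimiter::new(Duration::seconds(5), redis).expect("failed to create billing limiter")
}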
capture/src/redis.rs (27 additions, 3 deletions)
@@ -16,7 +16,7 @@ const REDIS_TIMEOUT_MILLISECS: u64 = 10;
/// awkward to work with.
#[async_trait]
-pub trait RedisClient {
+pub trait Client {
// A very simplified wrapper, but works for our usage
async fn zrangebyscore(&self, k: String, min: String, max: String) -> Result<Vec<String>>;
}
@@ -34,7 +34,31 @@ impl RedisClusterClient {
}

#[async_trait]
-impl RedisClient for RedisClusterClient {
+impl Client for RedisClusterClient {
+async fn zrangebyscore(&self, k: String, min: String, max: String) -> Result<Vec<String>> {
+let mut conn = self.client.get_async_connection().await?;
+
+let results = conn.zrangebyscore(k, min, max);
+let fut = timeout(Duration::from_secs(REDIS_TIMEOUT_MILLISECS), results).await?;
+
+Ok(fut?)
+}
+}
+
+pub struct RedisClient {
+client: redis::Client,
+}
+
+impl RedisClient {
+pub fn new(addr: String) -> Result<RedisClient> {
+let client = redis::Client::open(addr)?;
+
+Ok(RedisClient { client })
+}
+}
+
+#[async_trait]
+impl Client for RedisClient {
async fn zrangebyscore(&self, k: String, min: String, max: String) -> Result<Vec<String>> {
let mut conn = self.client.get_async_connection().await?;

@@ -72,7 +96,7 @@ impl Default for MockRedisClient {
}

#[async_trait]
-impl RedisClient for MockRedisClient {
+impl Client for MockRedisClient {
// A very simplified wrapper, but works for our usage
async fn zrangebyscore(&self, _k: String, _min: String, _max: String) -> Result<Vec<String>> {
Ok(self.zrangebyscore_ret.clone())
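Callers only see the renamed Client trait, so the new single-node RedisClient drops in wherever RedisClusterClient or MockRedisClient already fit. A rough usage sketch, not part of the commit: the address, key, and score range are placeholders, and it assumes the module's Result alias is anyhow-compatible, as the error propagation above suggests.

use capture::redis::{Client, RedisClient};

async fn sample_query() -> anyhow::Result<Vec<String>> {
    // redis::Client::open only parses the connection string, so no server is
    // contacted until the zrangebyscore call below actually runs.
    let client = RedisClient::new("redis://127.0.0.1:6379".to_string())?;

    // The one method the Client trait exposes, with the wrapper's timeout applied.
    let members = client
        .zrangebyscore("some-zset".to_string(), "-inf".to_string(), "+inf".to_string())
        .await?;
    Ok(members)
}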
capture/src/router.rs (3 additions, 3 deletions)
@@ -7,15 +7,15 @@ use axum::{
};
use tower_http::trace::TraceLayer;

-use crate::{billing_limits::BillingLimiter, capture, redis::RedisClient, sink, time::TimeSource};
+use crate::{billing_limits::BillingLimiter, capture, redis::Client, sink, time::TimeSource};

use crate::prometheus::{setup_metrics_recorder, track_metrics};

#[derive(Clone)]
pub struct State {
pub sink: Arc<dyn sink::EventSink + Send + Sync>,
pub timesource: Arc<dyn TimeSource + Send + Sync>,
-pub redis: Arc<dyn RedisClient + Send + Sync>,
+pub redis: Arc<dyn Client + Send + Sync>,
pub billing: BillingLimiter,
}

@@ -26,7 +26,7 @@ async fn index() -> &'static str {
pub fn router<
TZ: TimeSource + Send + Sync + 'static,
S: sink::EventSink + Send + Sync + 'static,
-R: RedisClient + Send + Sync + 'static,
+R: Client + Send + Sync + 'static,
>(
timesource: TZ,
sink: S,
capture/tests/billing_limiter.rs (190 additions, 0 deletions)
@@ -0,0 +1,190 @@
use assert_json_diff::assert_json_matches_no_panic;
use async_trait::async_trait;
use axum::http::StatusCode;
use axum_test_helper::TestClient;
use base64::engine::general_purpose;
use base64::Engine;
use capture::api::{CaptureError, CaptureResponse, CaptureResponseCode};
use capture::billing_limits::BillingLimiter;
use capture::event::ProcessedEvent;
use capture::redis::MockRedisClient;
use capture::router::router;
use capture::sink::EventSink;
use capture::time::TimeSource;
use serde::Deserialize;
use serde_json::{json, Value};
use std::fs::File;
use std::io::{BufRead, BufReader};
use std::sync::{Arc, Mutex};
use time::format_description::well_known::{Iso8601, Rfc3339};
use time::{Duration, OffsetDateTime};

#[derive(Debug, Deserialize)]
struct RequestDump {
path: String,
method: String,
content_encoding: String,
content_type: String,
ip: String,
now: String,
body: String,
output: Vec<Value>,
}

static REQUESTS_DUMP_FILE_NAME: &str = "tests/requests_dump.jsonl";

#[derive(Clone)]
pub struct FixedTime {
pub time: String,
}

impl TimeSource for FixedTime {
fn current_time(&self) -> String {
self.time.to_string()
}
}

#[derive(Clone, Default)]
struct MemorySink {
events: Arc<Mutex<Vec<ProcessedEvent>>>,
}

impl MemorySink {
fn len(&self) -> usize {
self.events.lock().unwrap().len()
}

fn events(&self) -> Vec<ProcessedEvent> {
self.events.lock().unwrap().clone()
}
}

#[async_trait]
impl EventSink for MemorySink {
async fn send(&self, event: ProcessedEvent) -> Result<(), CaptureError> {
self.events.lock().unwrap().push(event);
Ok(())
}

async fn send_batch(&self, events: Vec<ProcessedEvent>) -> Result<(), CaptureError> {
self.events.lock().unwrap().extend_from_slice(&events);
Ok(())
}
}

#[tokio::test]
async fn it_matches_django_capture_behaviour() -> anyhow::Result<()> {
let file = File::open(REQUESTS_DUMP_FILE_NAME)?;
let reader = BufReader::new(file);

let mut mismatches = 0;

for (line_number, line_contents) in reader.lines().enumerate() {
let case: RequestDump = serde_json::from_str(&line_contents?)?;
if !case.path.starts_with("/e/") {
println!("Skipping {} test case", &case.path);
continue;
}

let raw_body = general_purpose::STANDARD.decode(&case.body)?;
assert_eq!(
case.method, "POST",
"update code to handle method {}",
case.method
);

let sink = MemorySink::default();
let timesource = FixedTime { time: case.now };

let redis = Arc::new(MockRedisClient::new());
let billing = BillingLimiter::new(Duration::weeks(1), redis.clone())
.expect("failed to create billing limiter");

let app = router(timesource, sink.clone(), redis, billing, false);

let client = TestClient::new(app);
let mut req = client.post(&format!("/i/v0{}", case.path)).body(raw_body);
if !case.content_encoding.is_empty() {
req = req.header("Content-encoding", case.content_encoding);
}
if !case.content_type.is_empty() {
req = req.header("Content-type", case.content_type);
}
if !case.ip.is_empty() {
req = req.header("X-Forwarded-For", case.ip);
}

let res = req.send().await;
assert_eq!(
res.status(),
StatusCode::OK,
"line {} rejected: {}",
line_number,
res.text().await
);
assert_eq!(
Some(CaptureResponse {
status: CaptureResponseCode::Ok
}),
res.json().await
);
assert_eq!(
sink.len(),
case.output.len(),
"event count mismatch on line {}",
line_number
);

for (event_number, (message, expected)) in
sink.events().iter().zip(case.output.iter()).enumerate()
{
// Normalizing the expected event to align with known django->rust inconsistencies
let mut expected = expected.clone();
if let Some(value) = expected.get_mut("sent_at") {
// Default ISO format is different between python and rust, both are valid
// Parse and re-print the value before comparison
let sent_at =
OffsetDateTime::parse(value.as_str().expect("empty"), &Iso8601::DEFAULT)?;
*value = Value::String(sent_at.format(&Rfc3339)?)
}
if let Some(expected_data) = expected.get_mut("data") {
// Data is a serialized JSON map. Unmarshall both and compare them,
// instead of expecting the serialized bytes to be equal
let expected_props: Value =
serde_json::from_str(expected_data.as_str().expect("not str"))?;
let found_props: Value = serde_json::from_str(&message.data)?;
let match_config =
assert_json_diff::Config::new(assert_json_diff::CompareMode::Strict);
if let Err(e) =
assert_json_matches_no_panic(&expected_props, &found_props, match_config)
{
println!(
"data field mismatch at line {}, event {}: {}",
line_number, event_number, e
);
mismatches += 1;
} else {
*expected_data = json!(&message.data)
}
}

if let Some(object) = expected.as_object_mut() {
// site_url is unused in the pipeline now, let's drop it
object.remove("site_url");
}

let match_config = assert_json_diff::Config::new(assert_json_diff::CompareMode::Strict);
if let Err(e) =
assert_json_matches_no_panic(&json!(expected), &json!(message), match_config)
{
println!(
"record mismatch at line {}, event {}: {}",
line_number, event_number, e
);
mismatches += 1;
}
}
}
assert_eq!(0, mismatches, "some events didn't match");
Ok(())
}
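Since the test drives the router with MockRedisClient and a MemorySink, it needs no external services beyond the tests/requests_dump.jsonl fixture. Assuming that fixture is checked in, the new target should run on its own with standard cargo flags (the target name comes from the test file name):

cargo test --package capture --test billing_limiter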
