Skip to content

Commit

Permalink
chore(flags): cleaning up some flags stuff (#26444)
Browse files Browse the repository at this point in the history
  • Loading branch information
dmarticus authored Nov 27, 2024
1 parent c0bf8cf commit f6de867
Show file tree
Hide file tree
Showing 11 changed files with 545 additions and 699 deletions.
2 changes: 1 addition & 1 deletion rust/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -96,4 +96,4 @@ ahash = "0.8.11"
aws-config = { version = "1.1.7", features = ["behavior-version-latest"] }
aws-sdk-s3 = "1.58.0"
mockall = "0.13.0"
moka = { version = "0.12.8", features = ["sync"] }
moka = { version = "0.12.8", features = ["sync", "future"] }
2 changes: 1 addition & 1 deletion rust/feature-flags/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ common-metrics = { path = "../common/metrics" }
tower = { workspace = true }
derive_builder = "0.20.1"
petgraph = "0.6.5"
moka = { version = "0.12.8", features = ["future"] }
moka = { workspace = true }

[lints]
workspace = true
Expand Down
66 changes: 32 additions & 34 deletions rust/feature-flags/src/api/handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,8 @@ pub struct FeatureFlagEvaluationContext {
team_id: i32,
distinct_id: String,
feature_flags: FeatureFlagList,
postgres_reader: Arc<dyn Client + Send + Sync>,
postgres_writer: Arc<dyn Client + Send + Sync>,
reader: Arc<dyn Client + Send + Sync>,
writer: Arc<dyn Client + Send + Sync>,
cohort_cache: Arc<CohortCacheManager>,
#[builder(default)]
person_property_overrides: Option<HashMap<String, Value>>,
Expand All @@ -93,10 +93,10 @@ pub async fn process_request(context: RequestContext) -> Result<FlagsResponse, F

let request = decode_request(&headers, body)?;
let token = request
.extract_and_verify_token(state.redis.clone(), state.postgres_reader.clone())
.extract_and_verify_token(state.redis.clone(), state.reader.clone())
.await?;
let team = request
.get_team_from_cache_or_pg(&token, state.redis.clone(), state.postgres_reader.clone())
.get_team_from_cache_or_pg(&token, state.redis.clone(), state.reader.clone())
.await?;

let distinct_id = request.extract_distinct_id()?;
Expand All @@ -112,16 +112,16 @@ pub async fn process_request(context: RequestContext) -> Result<FlagsResponse, F
let hash_key_override = request.anon_distinct_id.clone();

let feature_flags_from_cache_or_pg = request
.get_flags_from_cache_or_pg(team_id, &state.redis, &state.postgres_reader)
.get_flags_from_cache_or_pg(team_id, &state.redis, &state.reader)
.await?;

let evaluation_context = FeatureFlagEvaluationContextBuilder::default()
.team_id(team_id)
.distinct_id(distinct_id)
.feature_flags(feature_flags_from_cache_or_pg)
.postgres_reader(state.postgres_reader.clone())
.postgres_writer(state.postgres_writer.clone())
.cohort_cache(state.cohort_cache.clone())
.reader(state.reader.clone())
.writer(state.writer.clone())
.cohort_cache(state.cohort_cache_manager.clone())
.person_property_overrides(person_property_overrides)
.group_property_overrides(group_property_overrides)
.groups(groups)
Expand Down Expand Up @@ -220,12 +220,12 @@ fn decode_request(headers: &HeaderMap, body: Bytes) -> Result<FlagRequest, FlagE
// which flags failed to evaluate
pub async fn evaluate_feature_flags(context: FeatureFlagEvaluationContext) -> FlagsResponse {
let group_type_mapping_cache =
GroupTypeMappingCache::new(context.team_id, context.postgres_reader.clone());
GroupTypeMappingCache::new(context.team_id, context.reader.clone());
let mut feature_flag_matcher = FeatureFlagMatcher::new(
context.distinct_id,
context.team_id,
context.postgres_reader,
context.postgres_writer,
context.reader,
context.writer,
context.cohort_cache,
Some(group_type_mapping_cache),
context.groups,
Expand Down Expand Up @@ -362,9 +362,9 @@ mod tests {

#[tokio::test]
async fn test_evaluate_feature_flags() {
let postgres_reader: Arc<dyn Client + Send + Sync> = setup_pg_reader_client(None).await;
let postgres_writer: Arc<dyn Client + Send + Sync> = setup_pg_writer_client(None).await;
let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None));
let reader: Arc<dyn Client + Send + Sync> = setup_pg_reader_client(None).await;
let writer: Arc<dyn Client + Send + Sync> = setup_pg_writer_client(None).await;
let cohort_cache = Arc::new(CohortCacheManager::new(reader.clone(), None, None));
let flag = FeatureFlag {
name: Some("Test Flag".to_string()),
id: 1,
Expand Down Expand Up @@ -402,8 +402,8 @@ mod tests {
.team_id(1)
.distinct_id("user123".to_string())
.feature_flags(feature_flag_list)
.postgres_reader(postgres_reader)
.postgres_writer(postgres_writer)
.reader(reader)
.writer(writer)
.cohort_cache(cohort_cache)
.person_property_overrides(Some(person_properties))
.build()
Expand Down Expand Up @@ -511,9 +511,9 @@ mod tests {

#[tokio::test]
async fn test_evaluate_feature_flags_multiple_flags() {
let postgres_reader: Arc<dyn Client + Send + Sync> = setup_pg_reader_client(None).await;
let postgres_writer: Arc<dyn Client + Send + Sync> = setup_pg_writer_client(None).await;
let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None));
let reader: Arc<dyn Client + Send + Sync> = setup_pg_reader_client(None).await;
let writer: Arc<dyn Client + Send + Sync> = setup_pg_writer_client(None).await;
let cohort_cache = Arc::new(CohortCacheManager::new(reader.clone(), None, None));
let flags = vec![
FeatureFlag {
name: Some("Flag 1".to_string()),
Expand Down Expand Up @@ -563,8 +563,8 @@ mod tests {
.team_id(1)
.distinct_id("user123".to_string())
.feature_flags(feature_flag_list)
.postgres_reader(postgres_reader)
.postgres_writer(postgres_writer)
.reader(reader)
.writer(writer)
.cohort_cache(cohort_cache)
.build()
.expect("Failed to build FeatureFlagEvaluationContext");
Expand Down Expand Up @@ -616,12 +616,10 @@ mod tests {

#[tokio::test]
async fn test_evaluate_feature_flags_with_overrides() {
let postgres_reader: Arc<dyn Client + Send + Sync> = setup_pg_reader_client(None).await;
let postgres_writer: Arc<dyn Client + Send + Sync> = setup_pg_writer_client(None).await;
let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None));
let team = insert_new_team_in_pg(postgres_reader.clone(), None)
.await
.unwrap();
let reader: Arc<dyn Client + Send + Sync> = setup_pg_reader_client(None).await;
let writer: Arc<dyn Client + Send + Sync> = setup_pg_writer_client(None).await;
let cohort_cache = Arc::new(CohortCacheManager::new(reader.clone(), None, None));
let team = insert_new_team_in_pg(reader.clone(), None).await.unwrap();

let flag = FeatureFlag {
name: Some("Test Flag".to_string()),
Expand Down Expand Up @@ -665,8 +663,8 @@ mod tests {
.team_id(team.id)
.distinct_id("user123".to_string())
.feature_flags(feature_flag_list)
.postgres_reader(postgres_reader)
.postgres_writer(postgres_writer)
.reader(reader)
.writer(writer)
.cohort_cache(cohort_cache)
.group_property_overrides(Some(group_property_overrides))
.groups(Some(groups))
Expand Down Expand Up @@ -699,9 +697,9 @@ mod tests {
#[tokio::test]
async fn test_long_distinct_id() {
let long_id = "a".repeat(1000);
let postgres_reader: Arc<dyn Client + Send + Sync> = setup_pg_reader_client(None).await;
let postgres_writer: Arc<dyn Client + Send + Sync> = setup_pg_writer_client(None).await;
let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None));
let reader: Arc<dyn Client + Send + Sync> = setup_pg_reader_client(None).await;
let writer: Arc<dyn Client + Send + Sync> = setup_pg_writer_client(None).await;
let cohort_cache = Arc::new(CohortCacheManager::new(reader.clone(), None, None));
let flag = FeatureFlag {
name: Some("Test Flag".to_string()),
id: 1,
Expand Down Expand Up @@ -729,8 +727,8 @@ mod tests {
.team_id(1)
.distinct_id(long_id)
.feature_flags(feature_flag_list)
.postgres_reader(postgres_reader)
.postgres_writer(postgres_writer)
.reader(reader)
.writer(writer)
.cohort_cache(cohort_cache)
.build()
.expect("Failed to build FeatureFlagEvaluationContext");
Expand Down
84 changes: 44 additions & 40 deletions rust/feature-flags/src/cohort/cohort_cache_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ use crate::api::errors::FlagError;
use crate::cohort::cohort_models::Cohort;
use crate::flags::flag_matching::{PostgresReader, TeamId};
use moka::future::Cache;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::Mutex;

/// CohortCacheManager manages the in-memory cache of cohorts using `moka` for caching.
///
Expand All @@ -12,8 +14,8 @@ use std::time::Duration;
///
/// ```text
/// CohortCacheManager {
/// postgres_reader: PostgresReader,
/// per_team_cohorts: Cache<TeamId, Vec<Cohort>> {
/// reader: PostgresReader,
/// cache: Cache<TeamId, Vec<Cohort>> {
/// // Example:
/// 2: [
/// Cohort { id: 1, name: "Power Users", filters: {...} },
Expand All @@ -22,50 +24,59 @@ use std::time::Duration;
/// 5: [
/// Cohort { id: 3, name: "Beta Users", filters: {...} }
/// ]
/// }
/// },
/// fetch_lock: Mutex<()> // Manager-wide lock
/// }
/// ```
///
#[derive(Clone)]
pub struct CohortCacheManager {
postgres_reader: PostgresReader,
per_team_cohort_cache: Cache<TeamId, Vec<Cohort>>,
reader: PostgresReader,
cache: Cache<TeamId, Vec<Cohort>>,
fetch_lock: Arc<Mutex<()>>, // Added fetch_lock
}

impl CohortCacheManager {
pub fn new(
postgres_reader: PostgresReader,
reader: PostgresReader,
max_capacity: Option<u64>,
ttl_seconds: Option<u64>,
) -> Self {
// We use the size of the cohort list (i.e., the number of cohorts for a given team)as the weight of the entry
let weigher =
|_: &TeamId, value: &Vec<Cohort>| -> u32 { value.len().try_into().unwrap_or(u32::MAX) };
// We use the size of the cohort list (i.e., the number of cohorts for a given team) as the weight of the entry
let weigher = |_: &TeamId, value: &Vec<Cohort>| -> u32 { value.len() as u32 };

let cache = Cache::builder()
.time_to_live(Duration::from_secs(ttl_seconds.unwrap_or(300))) // Default to 5 minutes
.weigher(weigher)
.max_capacity(max_capacity.unwrap_or(10_000)) // Default to 10,000 cohorts
.max_capacity(max_capacity.unwrap_or(100_000)) // Default to 100,000 cohorts
.build();

Self {
postgres_reader,
per_team_cohort_cache: cache,
reader,
cache,
fetch_lock: Arc::new(Mutex::new(())), // Initialize the lock
}
}

/// Retrieves cohorts for a given team.
///
/// If the cohorts are not present in the cache or have expired, it fetches them from the database,
/// caches the result upon successful retrieval, and then returns it.
pub async fn get_cohorts_for_team(&self, team_id: TeamId) -> Result<Vec<Cohort>, FlagError> {
if let Some(cached_cohorts) = self.per_team_cohort_cache.get(&team_id).await {
pub async fn get_cohorts(&self, team_id: TeamId) -> Result<Vec<Cohort>, FlagError> {
if let Some(cached_cohorts) = self.cache.get(&team_id).await {
return Ok(cached_cohorts.clone());
}
let fetched_cohorts = Cohort::list_from_pg(self.postgres_reader.clone(), team_id).await?;
self.per_team_cohort_cache
.insert(team_id, fetched_cohorts.clone())
.await;

// Acquire the lock before fetching
let _lock = self.fetch_lock.lock().await;

// Double-check the cache after acquiring the lock
if let Some(cached_cohorts) = self.cache.get(&team_id).await {
return Ok(cached_cohorts.clone());
}

let fetched_cohorts = Cohort::list_from_pg(self.reader.clone(), team_id).await?;
self.cache.insert(team_id, fetched_cohorts.clone()).await;

Ok(fetched_cohorts)
}
Expand Down Expand Up @@ -116,18 +127,18 @@ mod tests {
Some(1), // 1-second TTL
);

let cohorts = cohort_cache.get_cohorts_for_team(team_id).await?;
let cohorts = cohort_cache.get_cohorts(team_id).await?;
assert_eq!(cohorts.len(), 1);
assert_eq!(cohorts[0].team_id, team_id);

let cached_cohorts = cohort_cache.per_team_cohort_cache.get(&team_id).await;
let cached_cohorts = cohort_cache.cache.get(&team_id).await;
assert!(cached_cohorts.is_some());

// Wait for TTL to expire
sleep(Duration::from_secs(2)).await;

// Attempt to retrieve from cache again
let cached_cohorts = cohort_cache.per_team_cohort_cache.get(&team_id).await;
let cached_cohorts = cohort_cache.cache.get(&team_id).await;
assert!(cached_cohorts.is_none(), "Cache entry should have expired");

Ok(())
Expand All @@ -152,11 +163,11 @@ mod tests {
let team_id = team.id;
inserted_team_ids.push(team_id);
setup_test_cohort(writer_client.clone(), team_id, None).await?;
cohort_cache.get_cohorts_for_team(team_id).await?;
cohort_cache.get_cohorts(team_id).await?;
}

cohort_cache.per_team_cohort_cache.run_pending_tasks().await;
let cache_size = cohort_cache.per_team_cohort_cache.entry_count();
cohort_cache.cache.run_pending_tasks().await;
let cache_size = cohort_cache.cache.entry_count();
assert_eq!(
cache_size, max_capacity,
"Cache size should be equal to max_capacity"
Expand All @@ -165,26 +176,23 @@ mod tests {
let new_team = insert_new_team_in_pg(writer_client.clone(), None).await?;
let new_team_id = new_team.id;
setup_test_cohort(writer_client.clone(), new_team_id, None).await?;
cohort_cache.get_cohorts_for_team(new_team_id).await?;
cohort_cache.get_cohorts(new_team_id).await?;

cohort_cache.per_team_cohort_cache.run_pending_tasks().await;
let cache_size_after = cohort_cache.per_team_cohort_cache.entry_count();
cohort_cache.cache.run_pending_tasks().await;
let cache_size_after = cohort_cache.cache.entry_count();
assert_eq!(
cache_size_after, max_capacity,
"Cache size should remain equal to max_capacity after eviction"
);

let evicted_team_id = &inserted_team_ids[0];
let cached_cohorts = cohort_cache
.per_team_cohort_cache
.get(evicted_team_id)
.await;
let cached_cohorts = cohort_cache.cache.get(evicted_team_id).await;
assert!(
cached_cohorts.is_none(),
"Least recently used cache entry should have been evicted"
);

let cached_new_team = cohort_cache.per_team_cohort_cache.get(&new_team_id).await;
let cached_new_team = cohort_cache.cache.get(&new_team_id).await;
assert!(
cached_new_team.is_some(),
"Newly added cache entry should be present"
Expand All @@ -194,25 +202,21 @@ mod tests {
}

#[tokio::test]
async fn test_get_cohorts_for_team() -> Result<(), anyhow::Error> {
async fn test_get_cohorts() -> Result<(), anyhow::Error> {
let writer_client = setup_pg_writer_client(None).await;
let reader_client = setup_pg_reader_client(None).await;
let team_id = setup_test_team(writer_client.clone()).await?;
let _cohort = setup_test_cohort(writer_client.clone(), team_id, None).await?;
let cohort_cache = CohortCacheManager::new(reader_client.clone(), None, None);

let cached_cohorts = cohort_cache.per_team_cohort_cache.get(&team_id).await;
let cached_cohorts = cohort_cache.cache.get(&team_id).await;
assert!(cached_cohorts.is_none(), "Cache should initially be empty");

let cohorts = cohort_cache.get_cohorts_for_team(team_id).await?;
let cohorts = cohort_cache.get_cohorts(team_id).await?;
assert_eq!(cohorts.len(), 1);
assert_eq!(cohorts[0].team_id, team_id);

let cached_cohorts = cohort_cache
.per_team_cohort_cache
.get(&team_id)
.await
.unwrap();
let cached_cohorts = cohort_cache.cache.get(&team_id).await.unwrap();
assert_eq!(cached_cohorts.len(), 1);
assert_eq!(cached_cohorts[0].team_id, team_id);

Expand Down
Loading

0 comments on commit f6de867

Please sign in to comment.