diff --git a/Cargo.toml b/Cargo.toml
index d90dd1807e..5bc96bbbff 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -58,6 +58,11 @@ name = "misc"
 harness = false
 test = false
 
+[[bench]]
+name = "token_router"
+harness = false
+test = false
+
 [dependencies]
 # Local
 quilkin-macros = { version = "0.9.0-dev", path = "./macros" }
diff --git a/benches/cluster_map.rs b/benches/cluster_map.rs
index 9ce2cd775a..d96ff46263 100644
--- a/benches/cluster_map.rs
+++ b/benches/cluster_map.rs
@@ -3,6 +3,8 @@ use quilkin::net::cluster::ClusterMap;
 
 mod shared;
 
+use shared::TokenKind;
+
 #[divan::bench_group(sample_count = 10)]
 mod serde {
     use super::*;
@@ -66,21 +68,21 @@ mod serde {
 
     #[divan::bench(consts = SEEDS)]
     fn serialize_proto<const S: u64>(b: Bencher) {
-        let gc = gen_cluster_map::<S>();
+        let gc = gen_cluster_map::<S>(TokenKind::None);
         b.counter(gc.total_endpoints)
             .bench(|| divan::black_box(serialize_to_protobuf(&gc.cm)));
     }
 
     #[divan::bench(consts = SEEDS)]
     fn serialize_json<const S: u64>(b: Bencher) {
-        let gc = gen_cluster_map::<S>();
+        let gc = gen_cluster_map::<S>(TokenKind::None);
         b.counter(gc.total_endpoints)
             .bench(|| divan::black_box(serialize_to_json(&gc.cm)));
     }
 
     #[divan::bench(consts = SEEDS)]
     fn deserialize_json<const S: u64>(b: Bencher) {
-        let gc = gen_cluster_map::<S>();
+        let gc = gen_cluster_map::<S>(TokenKind::None);
         let json = serialize_to_json(&gc.cm);
 
         b.with_inputs(|| json.clone())
@@ -90,7 +92,7 @@ mod serde {
 
     #[divan::bench(consts = SEEDS)]
     fn deserialize_proto<const S: u64>(b: Bencher) {
-        let gc = gen_cluster_map::<S>();
+        let gc = gen_cluster_map::<S>(TokenKind::None);
         let pv = serialize_to_protobuf(&gc.cm);
 
         b.with_inputs(|| pv.clone())
@@ -125,7 +127,7 @@ mod ops {
 
     #[divan::bench(consts = SEEDS)]
     fn iterate<const S: u64>(b: Bencher) {
-        let cm = gen_cluster_map::<S>();
+        let cm = gen_cluster_map::<S>(TokenKind::None);
 
         b.counter(cm.total_endpoints)
             .bench_local(|| divan::black_box(compute_hash::<S>(&cm)));
@@ -135,7 +137,7 @@ mod ops {
 
     #[divan::bench(consts = SEEDS)]
     fn iterate_par<const S: u64>(b: Bencher) {
-        let cm = gen_cluster_map::<S>();
+        let cm = gen_cluster_map::<S>(TokenKind::None);
 
         b.counter(cm.total_endpoints)
             .bench(|| divan::black_box(compute_hash::<S>(&cm)))
diff --git a/benches/misc.rs b/benches/misc.rs
index cfdfd5840a..5e001bd2cb 100644
--- a/benches/misc.rs
+++ b/benches/misc.rs
@@ -226,7 +226,7 @@ impl GenResource for Cluster {
         let mut rng = rand::rngs::SmallRng::seed_from_u64(self.counter as u64);
         let mut hasher = xxhash_rust::xxh3::Xxh3::new();
 
-        let endpoints = shared::gen_endpoints(&mut rng, &mut hasher);
+        let endpoints = shared::gen_endpoints(&mut rng, &mut hasher, None);
 
         let msg = quilkin::generated::quilkin::config::v1alpha1::Cluster {
             locality: Some(quilkin::generated::quilkin::config::v1alpha1::Locality {
diff --git a/benches/shared.rs b/benches/shared.rs
index 4acabb8c24..e6748c19aa 100644
--- a/benches/shared.rs
+++ b/benches/shared.rs
@@ -497,11 +497,20 @@ pub const LOCALITIES: &[&str] = &[
     "us:west4:c",
 ];
 
-pub fn gen_endpoints(rng: &mut rand::rngs::SmallRng, hasher: &mut Hasher) -> BTreeSet<Endpoint> {
+pub fn gen_endpoints(
+    rng: &mut rand::rngs::SmallRng,
+    hasher: &mut Hasher,
+    mut tg: Option<&mut TokenGenerator>,
+) -> BTreeSet<Endpoint> {
     let num_endpoints = rng.gen_range(100..10_000);
     hasher.write_u16(num_endpoints);
 
     let mut endpoints = BTreeSet::new();
+    if let Some(tg) = &mut tg {
+        if let Some(prev) = &mut tg.previous {
+            prev.clear();
+        }
+    }
 
     for i in 0..num_endpoints {
         let ep_addr = match i % 3 {
@@ -514,7 +523,20 @@ pub fn gen_endpoints(rng: &mut rand::rngs::SmallRng, hasher: &mut Hasher) -> BTreeSet<Endpoint> {
             _ => unreachable!(),
         };
 
-        endpoints.insert(Endpoint::new(ep_addr));
+        let ep = if let Some(tg) = &mut tg {
+            let set = tg.next().unwrap();
+
+            Endpoint::with_metadata(
+                ep_addr,
+                quilkin::net::endpoint::EndpointMetadata::new(quilkin::net::endpoint::Metadata {
+                    tokens: set,
+                }),
+            )
+        } else {
+            Endpoint::new(ep_addr)
+        };
+
+        endpoints.insert(ep);
     }
 
     for ep in &endpoints {
@@ -541,7 +563,98 @@ fn write_locality(hasher: &mut Hasher, loc: &Option<Locality>) {
     }
 }
 
-pub fn gen_cluster_map<const S: u64>() -> GenCluster {
+pub enum TokenKind {
+    None,
+    Single {
+        duplicates: bool,
+    },
+    Multi {
+        range: std::ops::Range<usize>,
+        duplicates: bool,
+    },
+}
+
+impl std::str::FromStr for TokenKind {
+    type Err = eyre::Error;
+
+    fn from_str(s: &str) -> Result<Self, Self::Err> {
+        let dupes = |s: &str| match s {
+            "duplicates" => Ok(true),
+            "unique" => Ok(false),
+            _ => eyre::bail!("must be `duplicates` or `unique`"),
+        };
+
+        if let Some(rest) = s.strip_prefix("single:") {
+            Ok(Self::Single {
+                duplicates: dupes(rest)?,
+            })
+        } else if let Some(rest) = s.strip_prefix("multi:") {
+            let (r, rest) = rest
+                .split_once(':')
+                .ok_or_else(|| eyre::format_err!("multi must specify 'range:duplicates'"))?;
+
+            let (start, end) = r
+                .split_once("..")
+                .ok_or_else(|| eyre::format_err!("range must be specified as '<start>..<end>'"))?;
+
+            let range = start.parse()?..end.parse()?;
+
+            Ok(Self::Multi {
+                range,
+                duplicates: dupes(rest)?,
+            })
+        } else {
+            eyre::bail!("unknown token kind");
+        }
+    }
+}
+
+pub struct TokenGenerator {
+    rng: rand::rngs::SmallRng,
+    previous: Option<Vec<Vec<u8>>>,
+    range: Option<std::ops::Range<usize>>,
+}
+
+impl Iterator for TokenGenerator {
+    type Item = quilkin::net::endpoint::Set;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        use rand::RngCore;
+        let mut set = Self::Item::new();
+
+        let count = if let Some(range) = self.range.clone() {
+            self.rng.gen_range(range)
+        } else {
+            1
+        };
+
+        if let Some(prev) = &mut self.previous {
+            for _ in 0..count {
+                if !prev.is_empty() && self.rng.gen_ratio(1, 10) {
+                    let prev = &prev[self.rng.gen_range(0..prev.len())];
+                    set.insert(prev.clone());
+                } else {
+                    let count = self.rng.gen_range(4..20);
+                    let mut v = vec![0; count];
+                    self.rng.fill_bytes(&mut v);
+                    prev.push(v.clone());
+                    set.insert(v);
+                }
+            }
+        } else {
+            for _ in 0..count {
+                let count = self.rng.gen_range(4..20);
+                let mut v = vec![0; count];
+                self.rng.fill_bytes(&mut v);
+                set.insert(v);
+            }
+        }
+
+        Some(set)
+    }
+}
+
+pub fn gen_cluster_map<const S: u64>(token_kind: TokenKind) -> GenCluster {
     use rand::prelude::*;
 
     let mut rng = rand::rngs::SmallRng::seed_from_u64(S);
@@ -566,10 +679,24 @@ pub fn gen_cluster_map<const S: u64>() -> GenCluster {
     let keys: Vec<_> = cm.iter().map(|kv| kv.key().clone()).collect();
     let mut sets = std::collections::BTreeMap::new();
 
+    let mut token_generator = match token_kind {
+        TokenKind::None => None,
+        TokenKind::Multi { range, duplicates } => Some(TokenGenerator {
+            rng: rand::rngs::SmallRng::seed_from_u64(S),
+            previous: duplicates.then_some(Vec::new()),
+            range: Some(range),
+        }),
+        TokenKind::Single { duplicates } => Some(TokenGenerator {
+            rng: rand::rngs::SmallRng::seed_from_u64(S),
+            previous: duplicates.then_some(Vec::new()),
+            range: None,
+        }),
+    };
+
    for key in keys {
         write_locality(&mut hasher, &key);
 
-        let ep = gen_endpoints(&mut rng, &mut hasher);
+        let ep = gen_endpoints(&mut rng, &mut hasher, token_generator.as_mut());
         total_endpoints += ep.len();
         cm.insert(key.clone(), ep.clone());
         sets.insert(key, ep);
diff --git a/benches/token_router.rs b/benches/token_router.rs
new file mode 100644
index 0000000000..5b599cd604
--- /dev/null
+++ b/benches/token_router.rs
@@ -0,0 +1,60 @@
+use divan::Bencher;
+use quilkin::filters::token_router::{HashedTokenRouter, Router, TokenRouter};
+use rand::SeedableRng;
+
+mod shared;
+
+#[divan::bench(types = [TokenRouter, HashedTokenRouter], args = ["single:duplicates", "single:unique", "multi:2..128:duplicates", "multi:2..128:unique"])]
+fn token_router<T>(b: Bencher, token_kind: &str)
+where
+    T: Router + Sync,
+{
+    let filter = <T as Router>::new();
+    let gc = shared::gen_cluster_map::<42>(token_kind.parse().unwrap());
+
+    let mut tokens = Vec::new();
+
+    let cm = std::sync::Arc::new(gc.cm);
+    cm.build_token_maps();
+
+    // Calculate the amount of bytes for all the tokens
+    for eps in cm.iter() {
+        for ep in &eps.value().endpoints {
+            for tok in &ep.metadata.known.tokens {
+                tokens.push(tok.clone());
+            }
+        }
+    }
+
+    let total_token_size: usize = tokens.iter().map(|t| t.len()).sum();
+    let pool = std::sync::Arc::new(quilkin::pool::BufferPool::new(1, 1));
+
+    let mut rand = rand::rngs::SmallRng::seed_from_u64(42);
+
+    b.with_inputs(|| {
+        use rand::seq::SliceRandom as _;
+        let tok = tokens.choose(&mut rand).unwrap();
+
+        let mut rc = quilkin::filters::ReadContext::new(
+            cm.clone(),
+            quilkin::net::EndpointAddress::LOCALHOST,
+            pool.clone().alloc(),
+        );
+        rc.metadata.insert(
+            quilkin::net::endpoint::metadata::Key::from_static(
+                quilkin::filters::capture::CAPTURED_BYTES,
+            ),
+            quilkin::net::endpoint::metadata::Value::Bytes((*tok).clone().into()),
+        );
+
+        rc
+    })
+    .counter(divan::counter::BytesCount::new(total_token_size))
+    .bench_local_values(|mut rc| {
+        let _ = divan::black_box(filter.sync_read(&mut rc));
+    })
+}
+
+fn main() {
+    divan::main();
+}
diff --git a/src/components/proxy/packet_router.rs b/src/components/proxy/packet_router.rs
index b7f4bdb7cd..86c5790772 100644
--- a/src/components/proxy/packet_router.rs
+++ b/src/components/proxy/packet_router.rs
@@ -228,10 +228,10 @@ impl DownstreamReceiveWorkerConfig {
             // cheaply and returned to the pool once all references are dropped
             let contents = contents.freeze();
 
-            for endpoint in destinations.iter() {
+            for epa in destinations {
                 let session_key = SessionKey {
                     source: packet.source,
-                    dest: endpoint.address.to_socket_addr().await?,
+                    dest: epa.to_socket_addr().await?,
                 };
 
                 sessions
diff --git a/src/filters/chain.rs b/src/filters/chain.rs
index dca64dc4be..0f5ee0cd1c 100644
--- a/src/filters/chain.rs
+++ b/src/filters/chain.rs
@@ -297,7 +297,12 @@ impl Filter for FilterChain {
         // has rejected, and the destinations is empty, we passthrough to all.
         // Which mimics the old behaviour while avoid clones in most cases.
         if ctx.destinations.is_empty() {
-            ctx.destinations = ctx.endpoints.endpoints();
+            ctx.destinations = ctx
+                .endpoints
+                .endpoints()
+                .into_iter()
+                .map(|ep| ep.address)
+                .collect();
         }
 
         Ok(())
diff --git a/src/filters/compress/proto.rs b/src/filters/compress/proto.rs
deleted file mode 100644
index 5af8d9bf2a..0000000000
--- a/src/filters/compress/proto.rs
+++ /dev/null
@@ -1,29 +0,0 @@
-/*
- * Copyright 2021 Google LLC
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *      http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-/// Protobuf config for this filter.
-pub(super) mod quilkin {
-    pub mod extensions {
-        pub mod filters {
-            pub mod compress {
-                pub mod v1alpha1 {
-                    #![doc(hidden)]
-                    tonic::include_proto!("quilkin.filters.compress.v1alpha1");
-                }
-            }
-        }
-    }
-}
diff --git a/src/filters/load_balancer.rs b/src/filters/load_balancer.rs
index 6f426e32e0..7e3287c015 100644
--- a/src/filters/load_balancer.rs
+++ b/src/filters/load_balancer.rs
@@ -80,11 +80,7 @@ mod tests {
 
         filter.read(&mut context).await.unwrap();
 
-        context
-            .destinations
-            .iter()
-            .map(|ep| ep.address.clone())
-            .collect::<Vec<_>>()
+        context.destinations
     }
 
     #[tokio::test]
diff --git a/src/filters/load_balancer/endpoint_chooser.rs b/src/filters/load_balancer/endpoint_chooser.rs
index 028304ae55..14426dc3f3 100644
--- a/src/filters/load_balancer/endpoint_chooser.rs
+++ b/src/filters/load_balancer/endpoint_chooser.rs
@@ -52,6 +52,7 @@ impl EndpointChooser for RoundRobinEndpointChooser {
             .endpoints
             .nth_endpoint(count % ctx.endpoints.num_of_endpoints())
             .unwrap()
+            .address
             .clone()];
     }
 }
@@ -63,7 +64,7 @@ impl EndpointChooser for RandomEndpointChooser {
     fn choose_endpoints(&self, ctx: &mut ReadContext) {
         // The index is guaranteed to be in range.
         let index = thread_rng().gen_range(0..ctx.endpoints.num_of_endpoints());
-        ctx.destinations = vec![ctx.endpoints.nth_endpoint(index).unwrap().clone()];
+        ctx.destinations = vec![ctx.endpoints.nth_endpoint(index).unwrap().address.clone()];
     }
 }
 
@@ -78,6 +79,7 @@ impl EndpointChooser for HashEndpointChooser {
             .endpoints
             .nth_endpoint(hasher.finish() as usize % ctx.endpoints.num_of_endpoints())
             .unwrap()
+            .address
             .clone()];
     }
 }
diff --git a/src/filters/load_balancer/proto.rs b/src/filters/load_balancer/proto.rs
deleted file mode 100644
index 012571c8e2..0000000000
--- a/src/filters/load_balancer/proto.rs
+++ /dev/null
@@ -1,29 +0,0 @@
-/*
- * Copyright 2021 Google LLC
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *      http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-/// Protobuf config for this filter.
-pub(super) mod quilkin {
-    pub mod extensions {
-        pub mod filters {
-            pub mod load_balancer {
-                pub mod v1alpha1 {
-                    #![doc(hidden)]
-                    tonic::include_proto!("quilkin.filters.load_balancer.v1alpha1");
-                }
-            }
-        }
-    }
-}
diff --git a/src/filters/local_rate_limit/proto.rs b/src/filters/local_rate_limit/proto.rs
deleted file mode 100644
index eaa305ef2a..0000000000
--- a/src/filters/local_rate_limit/proto.rs
+++ /dev/null
@@ -1,29 +0,0 @@
-/*
- * Copyright 2021 Google LLC
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *      http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-/// Protobuf config for this filter.
-pub(super) mod quilkin {
-    pub mod extensions {
-        pub mod filters {
-            pub mod local_rate_limit {
-                pub mod v1alpha1 {
-                    #![doc(hidden)]
-                    tonic::include_proto!("quilkin.filters.local_rate_limit.v1alpha1");
-                }
-            }
-        }
-    }
-}
diff --git a/src/filters/read.rs b/src/filters/read.rs
index 649809331e..d488829142 100644
--- a/src/filters/read.rs
+++ b/src/filters/read.rs
@@ -20,7 +20,7 @@ use std::sync::Arc;
 use crate::filters::Filter;
 use crate::{
     net::{
-        endpoint::{metadata::DynamicMetadata, Endpoint, EndpointAddress},
+        endpoint::{metadata::DynamicMetadata, EndpointAddress},
         ClusterMap,
     },
     pool::PoolBuffer,
@@ -32,7 +32,7 @@ pub struct ReadContext {
     /// The upstream endpoints that the packet will be forwarded to.
     pub endpoints: Arc<ClusterMap>,
     /// The upstream endpoints that the packet will be forwarded to.
-    pub destinations: Vec<Endpoint>,
+    pub destinations: Vec<EndpointAddress>,
     /// The source of the received packet.
     pub source: EndpointAddress,
     /// Contents of the received packet.
@@ -53,9 +53,4 @@ impl ReadContext {
             metadata: DynamicMetadata::new(),
         }
     }
-
-    pub fn metadata(mut self, metadata: DynamicMetadata) -> Self {
-        self.metadata = metadata;
-        self
-    }
 }
diff --git a/src/filters/token_router.rs b/src/filters/token_router.rs
index 1336b68704..3e56163246 100644
--- a/src/filters/token_router.rs
+++ b/src/filters/token_router.rs
@@ -48,9 +48,15 @@ impl StaticFilter for TokenRouter {
 #[async_trait::async_trait]
 impl Filter for TokenRouter {
     async fn read(&self, ctx: &mut ReadContext) -> Result<(), FilterError> {
+        self.sync_read(ctx)
+    }
+}
+
+impl Router for TokenRouter {
+    fn sync_read(&self, ctx: &mut ReadContext) -> Result<(), FilterError> {
         match ctx.metadata.get(&self.config.metadata_key) {
             Some(metadata::Value::Bytes(token)) => {
-                ctx.destinations = ctx.endpoints.filter_endpoints(|endpoint| {
+                let destinations = ctx.endpoints.filter_endpoints(|endpoint| {
                     if endpoint.metadata.known.tokens.contains(&**token) {
                         tracing::trace!(%endpoint.address, token = &*crate::codec::base64::encode(token), "Endpoint matched");
                         true
@@ -59,6 +65,8 @@ impl Filter for TokenRouter {
                     }
                 });
 
+                ctx.destinations = destinations.into_iter().map(|ep| ep.address).collect();
+
                 if ctx.destinations.is_empty() {
                     Err(FilterError::new(Error::NoEndpointMatch(
                         self.config.metadata_key,
@@ -77,6 +85,84 @@ impl Filter for TokenRouter {
             ))),
         }
     }
+
+    fn new() -> Self {
+        Self::from_config(None)
+    }
+}
+
+pub struct HashedTokenRouter {
+    config: Config,
+}
+
+impl HashedTokenRouter {
+    fn new(config: Config) -> Self {
+        Self { config }
+    }
+}
+
+impl StaticFilter for HashedTokenRouter {
+    const NAME: &'static str = "quilkin.filters.token_router.v1alpha1.HashedTokenRouter";
+    type Configuration = Config;
+    type BinaryConfiguration = proto::TokenRouter;
+
+    fn try_from_config(config: Option<Self::Configuration>) -> Result<Self, CreationError> {
+        Ok(Self::new(config.unwrap_or_default()))
+    }
+}
+
+#[async_trait::async_trait]
+impl Filter for HashedTokenRouter {
+    async fn read(&self, ctx: &mut ReadContext) -> Result<(), FilterError> {
+        self.sync_read(ctx)
+    }
+}
+
+impl Router for HashedTokenRouter {
+    fn sync_read(&self, ctx: &mut ReadContext) -> Result<(), FilterError> {
+        match ctx.metadata.get(&self.config.metadata_key) {
+            Some(metadata::Value::Bytes(token)) => {
+                let mut destinations = Vec::new();
+
+                let tok = crate::net::cluster::Token::new(token);
+
+                for ep in ctx.endpoints.iter() {
+                    ep.value().addresses_for_token(tok, &mut destinations);
+
+                    if !destinations.is_empty() {
+                        break;
+                    }
+                }
+
+                ctx.destinations = destinations;
+
+                if ctx.destinations.is_empty() {
+                    Err(FilterError::new(Error::NoEndpointMatch(
+                        self.config.metadata_key,
+                        crate::codec::base64::encode(token),
+                    )))
+                } else {
+                    Ok(())
+                }
+            }
+            Some(value) => Err(FilterError::new(Error::InvalidType(
+                self.config.metadata_key,
+                value.clone(),
+            ))),
+            None => Err(FilterError::new(Error::NoTokenFound(
+                self.config.metadata_key,
+            ))),
+        }
+    }
+
+    fn new() -> Self {
+        Self::from_config(None)
+    }
+}
+
+pub trait Router {
+    fn sync_read(&self, ctx: &mut ReadContext) -> Result<(), FilterError>;
+    fn new() -> Self;
 }
 
 #[derive(Debug, thiserror::Error)]
diff --git a/src/filters/token_router/proto.rs b/src/filters/token_router/proto.rs
deleted file mode 100644
index e6f72733fb..0000000000
--- a/src/filters/token_router/proto.rs
+++ /dev/null
@@ -1,29 +0,0 @@
-/*
- * Copyright 2021 Google LLC
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *      http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-/// Protobuf config for this filter.
-pub(super) mod quilkin {
-    pub mod extensions {
-        pub mod filters {
-            pub mod token_router {
-                pub mod v1alpha1 {
-                    #![doc(hidden)]
-                    tonic::include_proto!("quilkin.filters.token_router.v1alpha1");
-                }
-            }
-        }
-    }
-}
diff --git a/src/net/cluster.rs b/src/net/cluster.rs
index 4024109653..d3ee6ad271 100644
--- a/src/net/cluster.rs
+++ b/src/net/cluster.rs
@@ -24,7 +24,7 @@ use dashmap::DashMap;
 use once_cell::sync::Lazy;
 use serde::{Deserialize, Serialize};
 
-use crate::net::endpoint::{Endpoint, Locality};
+use crate::net::endpoint::{Endpoint, EndpointAddress, Locality};
 
 const SUBSYSTEM: &str = "cluster";
 
@@ -84,9 +84,22 @@ impl std::str::FromStr for EndpointSetVersion {
     }
 }
 
+pub type TokenAddressMap = std::collections::BTreeMap<u64, Vec<EndpointAddress>>;
+
+#[derive(Copy, Clone)]
+pub struct Token(u64);
+
+impl Token {
+    #[inline]
+    pub fn new(token: &[u8]) -> Self {
+        Self(seahash::hash(token))
+    }
+}
+
 #[derive(Debug, Clone)]
 pub struct EndpointSet {
     pub endpoints: BTreeSet<Endpoint>,
+    pub token_map: TokenAddressMap,
     /// The hash of all of the endpoints in this set
     hash: u64,
     /// Version of this set of endpoints. Any mutatation of the endpoints
@@ -100,6 +113,7 @@ impl EndpointSet {
     pub fn new(endpoints: BTreeSet<Endpoint>) -> Self {
         let mut this = Self {
             endpoints,
+            token_map: TokenAddressMap::new(),
             hash: 0,
             version: 0,
         };
@@ -115,11 +129,15 @@ impl EndpointSet {
     /// across machines
     #[inline]
     pub fn with_version(endpoints: BTreeSet<Endpoint>, hash: EndpointSetVersion) -> Self {
-        Self {
+        let mut this = Self {
             endpoints,
+            token_map: TokenAddressMap::new(),
             hash: hash.0,
             version: 1,
-        }
+        };
+
+        this.build_token_map();
+        this
     }
 
     #[inline]
@@ -137,6 +155,13 @@ impl EndpointSet {
         self.endpoints.contains(ep)
     }
 
+    #[inline]
+    pub fn addresses_for_token(&self, token: Token, addresses: &mut Vec<EndpointAddress>) {
+        if let Some(addrs) = self.token_map.get(&token.0) {
+            addresses.extend_from_slice(addrs);
+        }
+    }
+
     /// Unique version for this endpoint set
     #[inline]
     pub fn version(&self) -> EndpointSetVersion {
@@ -159,6 +184,22 @@ impl EndpointSet {
         self.version += 1;
     }
 
+    /// Creates a map of tokens -> address for the current set
+    #[inline]
+    pub fn build_token_map(&mut self) {
+        let mut token_map = TokenAddressMap::new();
+
+        // This is only called on proxies, so calculate a token map
+        for ep in &self.endpoints {
+            for tok in &ep.metadata.known.tokens {
+                let hash = seahash::hash(tok);
+                token_map.entry(hash).or_default().push(ep.address.clone());
+            }
+        }
+
+        self.token_map = token_map;
+    }
+
     #[inline]
     pub fn replace(&mut self, replacement: Self) -> BTreeSet<Endpoint> {
         let old = std::mem::replace(&mut self.endpoints, replacement.endpoints);
@@ -170,6 +211,10 @@ impl EndpointSet {
             self.version += 1;
         }
 
+        if !self.token_map.is_empty() {
+            self.build_token_map();
+        }
+
         old
     }
 }
@@ -419,6 +464,14 @@ where
 
         ret
     }
+
+    /// Builds token maps for every locality. Only used by testing/benching
+    #[doc(hidden)]
+    pub fn build_token_maps(&self) {
+        for mut eps in self.map.iter_mut() {
+            eps.build_token_map();
+        }
+    }
 }
 
 impl crate::config::watch::Watchable for ClusterMap {
diff --git a/src/net/endpoint.rs b/src/net/endpoint.rs
index 1a69c09f93..73a78dc1eb 100644
--- a/src/net/endpoint.rs
+++ b/src/net/endpoint.rs
@@ -31,6 +31,7 @@ pub use self::{
 };
 
 pub type EndpointMetadata = metadata::MetadataView<Metadata>;
+pub use base64_set::Set;
 
 /// A destination endpoint with any associated metadata.
 #[derive(Debug, Deserialize, Serialize, PartialEq, Clone, Eq, schemars::JsonSchema)]
@@ -298,36 +299,50 @@ pub enum MetadataError {
 mod base64_set {
     use serde::de::Error;
 
-    pub type Set<T = Vec<u8>> = std::collections::BTreeSet<T>;
+    pub type Set = std::collections::BTreeSet<Vec<u8>>;
 
     pub fn serialize<S>(set: &Set, ser: S) -> Result<S::Ok, S::Error>
     where
         S: serde::Serializer,
     {
-        serde::Serialize::serialize(
-            &set.iter()
-                .map(crate::codec::base64::encode)
-                .collect::<Vec<_>>(),
-            ser,
-        )
+        ser.collect_seq(set.iter().map(crate::codec::base64::encode))
     }
 
     pub fn deserialize<'de, D>(de: D) -> Result<Set, D::Error>
     where
         D: serde::Deserializer<'de>,
     {
-        let items = <Vec<String> as serde::Deserialize>::deserialize(de)?;
-        let set = items.iter().cloned().collect::<Set<String>>();
+        struct TokenVisitor;
 
-        if set.len() != items.len() {
-            Err(D::Error::custom(
-                "Found duplicate tokens in endpoint metadata.",
-            ))
-        } else {
-            set.into_iter()
-                .map(|string| crate::codec::base64::decode(string).map_err(D::Error::custom))
-                .collect()
+        impl<'de> serde::de::Visitor<'de> for TokenVisitor {
+            type Value = Set;
+
+            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
+                formatter.write_str("an array of base64 encoded tokens")
+            }
+
+            fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
+            where
+                A: serde::de::SeqAccess<'de>,
+            {
+                let mut set = Set::new();
+
+                while let Some(token) = seq.next_element::<std::borrow::Cow<'_, str>>()? {
+                    let decoded =
+                        crate::codec::base64::decode(token.as_ref()).map_err(Error::custom)?;
+
+                    if !set.insert(decoded) {
+                        return Err(Error::custom(
+                            "Found duplicate tokens in endpoint metadata.",
+                        ));
+                    }
+                }
+
+                Ok(set)
+            }
         }
+
+        de.deserialize_seq(TokenVisitor)
     }
 }
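
For reference, a minimal sketch (not part of the patch) of how the new benchmark argument strings map onto the TokenKind parser added in benches/shared.rs above; the strings mirror the args list of the #[divan::bench] attribute in benches/token_router.rs, and the exact variants produced follow from the FromStr impl in the patch:

    // Illustration only: parsing the divan bench args via the FromStr impl added above.
    let kinds: Vec<shared::TokenKind> = [
        "single:duplicates",       // TokenKind::Single { duplicates: true }
        "single:unique",           // TokenKind::Single { duplicates: false }
        "multi:2..128:duplicates", // TokenKind::Multi { range: 2..128, duplicates: true }
        "multi:2..128:unique",     // TokenKind::Multi { range: 2..128, duplicates: false }
    ]
    .iter()
    .map(|s| s.parse().unwrap())
    .collect();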