Skip to content

Commit

Permalink
Refactor read context to separate endpoints into clusters and
Browse files Browse the repository at this point in the history
destinations

This removes the biggest clone that we do on every packet request, which
should mean that we are much more efficient.
  • Loading branch information
XAMPPRocky committed Oct 22, 2023
1 parent 1da557a commit 7730eeb
Show file tree
Hide file tree
Showing 15 changed files with 121 additions and 70 deletions.
11 changes: 7 additions & 4 deletions src/cli/proxy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -403,17 +403,20 @@ impl DownstreamReceiveWorkerConfig {
config: &Arc<Config>,
sessions: &Arc<SessionPool>,
) -> Result<usize, PipelineError> {
let endpoints: Vec<_> = config.clusters.read().endpoints().collect();
if endpoints.is_empty() {
if config.clusters.read().endpoints().count() == 0 {
return Err(PipelineError::NoUpstreamEndpoints);
}

let filters = config.filters.load();
let mut context = ReadContext::new(endpoints, packet.source.into(), packet.contents);
let mut context = ReadContext::new(
config.clusters.clone_value(),
packet.source.into(),
packet.contents,
);
filters.read(&mut context).await?;
let mut bytes_written = 0;

for endpoint in context.endpoints.iter() {
for endpoint in context.destinations.iter() {
sessions::ADDRESS_MAP.get(&endpoint.address);
let session_key = SessionKey {
source: packet.source,
Expand Down
4 changes: 4 additions & 0 deletions src/config/watch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@ impl<T: Clone> Watch<T> {
pub fn watch(&self) -> watch::Receiver<T> {
self.watchers.subscribe()
}

pub fn clone_value(&self) -> std::sync::Arc<T> {
self.value.clone()
}
}

impl<T: Clone + PartialEq + std::fmt::Debug> Watch<T> {
Expand Down
11 changes: 7 additions & 4 deletions src/filters/capture.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ struct NoValueCaptured;
#[cfg(test)]
mod tests {
use crate::{
cluster::ClusterMap,
filters::metadata::CAPTURED_BYTES,
net::endpoint::{metadata::Value, Endpoint},
test::assert_write_no_change,
Expand Down Expand Up @@ -159,10 +160,11 @@ mod tests {
}),
};
let filter = Capture::from_config(config.into());
let endpoints = vec![Endpoint::new("127.0.0.1:81".parse().unwrap())];
let endpoints = ClusterMap::default();
endpoints.insert_default([Endpoint::new("127.0.0.1:81".parse().unwrap())].into());
assert!(filter
.read(&mut ReadContext::new(
endpoints,
endpoints.into(),
(std::net::Ipv4Addr::LOCALHOST, 80).into(),
"abc".to_string().into_bytes(),
))
Expand Down Expand Up @@ -235,9 +237,10 @@ mod tests {
where
F: Filter + ?Sized,
{
let endpoints = vec![Endpoint::new("127.0.0.1:81".parse().unwrap())];
let endpoints = ClusterMap::default();
endpoints.insert_default([Endpoint::new("127.0.0.1:81".parse().unwrap())].into());
let mut context = ReadContext::new(
endpoints,
endpoints.into(),
"127.0.0.1:80".parse().unwrap(),
"helloabc".to_string().into_bytes(),
);
Expand Down
28 changes: 19 additions & 9 deletions src/filters/chain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,13 @@ impl Filter for FilterChain {
}
}

// Special case to handle to allow for pass-through, if no filter
// has rejected, and the destinations is empty, we passthrough to all.
// Which mimics the old behaviour while avoid clones in most cases.
if ctx.destinations.is_empty() {
ctx.destinations = ctx.clusters.endpoints().collect();
}

Ok(())
}

Expand Down Expand Up @@ -340,28 +347,31 @@ mod tests {
assert!(result.is_err());
}

fn endpoints() -> Vec<Endpoint> {
vec![
fn endpoints() -> (crate::cluster::ClusterMap, Vec<Endpoint>) {
let clusters = crate::cluster::ClusterMap::default();
let endpoints = [
Endpoint::new("127.0.0.1:80".parse().unwrap()),
Endpoint::new("127.0.0.1:90".parse().unwrap()),
]
];
clusters.insert_default(endpoints.clone().into());
(clusters, endpoints.into())
}

#[tokio::test]
async fn chain_single_test_filter() {
crate::test::load_test_filters();
let config = new_test_config();
let endpoints_fixture = endpoints();
let (clusters, endpoints_fixture) = endpoints();
let mut context = ReadContext::new(
endpoints_fixture.clone(),
clusters.into(),
"127.0.0.1:70".parse().unwrap(),
b"hello".to_vec(),
);

config.filters.read(&mut context).await.unwrap();
let expected = endpoints_fixture.clone();

assert_eq!(expected, &*context.endpoints);
assert_eq!(expected, &*context.destinations);
assert_eq!(b"hello:odr:127.0.0.1:70", &*context.contents);
assert_eq!(
"receive",
Expand Down Expand Up @@ -396,16 +406,16 @@ mod tests {
])
.unwrap();

let endpoints_fixture = endpoints();
let (clusters, endpoints_fixture) = endpoints();
let mut context = ReadContext::new(
endpoints_fixture.clone(),
clusters.into(),
"127.0.0.1:70".parse().unwrap(),
b"hello".to_vec(),
);

chain.read(&mut context).await.unwrap();
let expected = endpoints_fixture.clone();
assert_eq!(expected, context.endpoints.to_vec());
assert_eq!(expected, context.destinations.to_vec());
assert_eq!(
b"hello:odr:127.0.0.1:70:odr:127.0.0.1:70",
&*context.contents
Expand Down
16 changes: 12 additions & 4 deletions src/filters/compress.rs
Original file line number Diff line number Diff line change
Expand Up @@ -176,8 +176,10 @@ mod tests {
let expected = contents_fixture();

// read compress
let endpoints = crate::cluster::ClusterMap::default();
endpoints.insert_default([Endpoint::new("127.0.0.1:81".parse().unwrap())].into());
let mut read_context = ReadContext::new(
vec![Endpoint::new("127.0.0.1:80".parse().unwrap())],
endpoints.into(),
"127.0.0.1:8080".parse().unwrap(),
expected.clone(),
);
Expand Down Expand Up @@ -238,9 +240,11 @@ mod tests {
Metrics::new(),
);

let endpoints = crate::cluster::ClusterMap::default();
endpoints.insert_default([Endpoint::new("127.0.0.1:81".parse().unwrap())].into());
assert!(compression
.read(&mut ReadContext::new(
vec![Endpoint::new("127.0.0.1:80".parse().unwrap())],
endpoints.into(),
"127.0.0.1:8080".parse().unwrap(),
b"hello".to_vec(),
))
Expand All @@ -259,8 +263,10 @@ mod tests {
Metrics::new(),
);

let endpoints = crate::cluster::ClusterMap::default();
endpoints.insert_default([Endpoint::new("127.0.0.1:81".parse().unwrap())].into());
let mut read_context = ReadContext::new(
vec![Endpoint::new("127.0.0.1:80".parse().unwrap())],
endpoints.into(),
"127.0.0.1:8080".parse().unwrap(),
b"hello".to_vec(),
);
Expand Down Expand Up @@ -345,8 +351,10 @@ mod tests {
);

// read decompress
let endpoints = crate::cluster::ClusterMap::default();
endpoints.insert_default([Endpoint::new("127.0.0.1:81".parse().unwrap())].into());
let mut read_context = ReadContext::new(
vec![Endpoint::new("127.0.0.1:80".parse().unwrap())],
endpoints.into(),
"127.0.0.1:8080".parse().unwrap(),
write_context.contents.clone(),
);
Expand Down
16 changes: 6 additions & 10 deletions src/filters/firewall.rs
Original file line number Diff line number Diff line change
Expand Up @@ -137,18 +137,14 @@ mod tests {
};

let local_ip = [192, 168, 75, 20];
let mut ctx = ReadContext::new(
vec![Endpoint::new((Ipv4Addr::LOCALHOST, 8080).into())],
(local_ip, 80).into(),
vec![],
);
let endpoints = crate::cluster::ClusterMap::default();
endpoints.insert_default([Endpoint::new((Ipv4Addr::LOCALHOST, 8080).into())].into());
let mut ctx = ReadContext::new(endpoints.into(), (local_ip, 80).into(), vec![]);
assert!(firewall.read(&mut ctx).await.is_ok());

let mut ctx = ReadContext::new(
vec![Endpoint::new((Ipv4Addr::LOCALHOST, 8080).into())],
(local_ip, 2000).into(),
vec![],
);
let endpoints = crate::cluster::ClusterMap::default();
endpoints.insert_default([Endpoint::new((Ipv4Addr::LOCALHOST, 8080).into())].into());
let mut ctx = ReadContext::new(endpoints.into(), (local_ip, 2000).into(), vec![]);
assert!(logs_contain("quilkin::filters::firewall")); // the given name to the the logger by tracing
assert!(logs_contain("Allow"));

Expand Down
10 changes: 4 additions & 6 deletions src/filters/load_balancer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,16 +68,14 @@ mod tests {
input_addresses: &[EndpointAddress],
source: EndpointAddress,
) -> Vec<EndpointAddress> {
let mut context = ReadContext::new(
Vec::from_iter(input_addresses.iter().cloned().map(Endpoint::new)),
source,
vec![],
);
let endpoints = crate::cluster::ClusterMap::default();
endpoints.insert_default(input_addresses.iter().cloned().map(Endpoint::new).collect());
let mut context = ReadContext::new(endpoints.into(), source, vec![]);

filter.read(&mut context).await.unwrap();

context
.endpoints
.destinations
.iter()
.map(|ep| ep.address.clone())
.collect::<Vec<_>>()
Expand Down
18 changes: 14 additions & 4 deletions src/filters/load_balancer/endpoint_chooser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,12 @@ impl EndpointChooser for RoundRobinEndpointChooser {
fn choose_endpoints(&self, ctx: &mut ReadContext) {
let count = self.next_endpoint.fetch_add(1, Ordering::Relaxed);
// Note: The index is guaranteed to be in range.
ctx.endpoints = vec![ctx.endpoints[count % ctx.endpoints.len()].clone()];
ctx.destinations = vec![ctx
.clusters
.endpoints()
.nth(count % ctx.clusters.endpoints().count())
.unwrap()
.clone()];
}
}

Expand All @@ -58,8 +63,8 @@ pub struct RandomEndpointChooser;
impl EndpointChooser for RandomEndpointChooser {
fn choose_endpoints(&self, ctx: &mut ReadContext) {
// The index is guaranteed to be in range.
let index = thread_rng().gen_range(0..ctx.endpoints.len());
ctx.endpoints = vec![ctx.endpoints[index].clone()];
let index = thread_rng().gen_range(0..ctx.clusters.endpoints().count());
ctx.destinations = vec![ctx.clusters.endpoints().nth(index).unwrap().clone()];
}
}

Expand All @@ -70,6 +75,11 @@ impl EndpointChooser for HashEndpointChooser {
fn choose_endpoints(&self, ctx: &mut ReadContext) {
let mut hasher = DefaultHasher::new();
ctx.source.hash(&mut hasher);
ctx.endpoints = vec![ctx.endpoints[hasher.finish() as usize % ctx.endpoints.len()].clone()];
ctx.destinations = vec![ctx
.clusters
.endpoints()
.nth(hasher.finish() as usize % ctx.clusters.endpoints().count())
.unwrap()
.clone()];
}
}
14 changes: 9 additions & 5 deletions src/filters/local_rate_limit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -222,11 +222,15 @@ mod tests {

/// Send a packet to the filter and assert whether or not it was processed.
async fn read(r: &LocalRateLimit, address: &EndpointAddress, should_succeed: bool) {
let endpoints = vec![crate::net::endpoint::Endpoint::new(
(Ipv4Addr::LOCALHOST, 8089).into(),
)];

let mut context = ReadContext::new(endpoints, address.clone(), vec![9]);
let endpoints = crate::cluster::ClusterMap::default();
endpoints.insert_default(
[crate::net::endpoint::Endpoint::new(
(Ipv4Addr::LOCALHOST, 8089).into(),
)]
.into(),
);

let mut context = ReadContext::new(endpoints.into(), address.clone(), vec![9]);
let result = r.read(&mut context).await;

if should_succeed {
Expand Down
12 changes: 6 additions & 6 deletions src/filters/match.rs
Original file line number Diff line number Diff line change
Expand Up @@ -205,8 +205,10 @@ mod tests {
assert_eq!(0, filter.metrics.packets_matched_total.get());

// config so we can test match and fallthrough.
let endpoints = crate::cluster::ClusterMap::default();
endpoints.insert_default([Endpoint::new("127.0.0.1:81".parse().unwrap())].into());
let mut ctx = ReadContext::new(
vec![Default::default()],
endpoints.into(),
([127, 0, 0, 1], 7000).into(),
contents.clone(),
);
Expand All @@ -216,11 +218,9 @@ mod tests {
assert_eq!(1, filter.metrics.packets_matched_total.get());
assert_eq!(0, filter.metrics.packets_fallthrough_total.get());

let mut ctx = ReadContext::new(
vec![Default::default()],
([127, 0, 0, 1], 7000).into(),
contents,
);
let endpoints = crate::cluster::ClusterMap::default();
endpoints.insert_default([Endpoint::new("127.0.0.1:81".parse().unwrap())].into());
let mut ctx = ReadContext::new(endpoints.into(), ([127, 0, 0, 1], 7000).into(), contents);
ctx.metadata.insert(key, "xyz".into());

let result = filter.read(&mut ctx).await;
Expand Down
16 changes: 12 additions & 4 deletions src/filters/read.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,22 @@
* limitations under the License.
*/

use std::sync::Arc;

#[cfg(doc)]
use crate::filters::Filter;
use crate::net::endpoint::{metadata::DynamicMetadata, Endpoint, EndpointAddress};
use crate::{
cluster::ClusterMap,
net::endpoint::{metadata::DynamicMetadata, Endpoint, EndpointAddress},
};

/// The input arguments to [`Filter::read`].
#[non_exhaustive]
pub struct ReadContext {
/// The upstream endpoints that the packet will be forwarded to.
pub endpoints: Vec<Endpoint>,
pub clusters: Arc<ClusterMap>,
/// The upstream endpoints that the packet will be forwarded to.
pub destinations: Vec<Endpoint>,
/// The source of the received packet.
pub source: EndpointAddress,
/// Contents of the received packet.
Expand All @@ -33,9 +40,10 @@ pub struct ReadContext {

impl ReadContext {
/// Creates a new [`ReadContext`].
pub fn new(endpoints: Vec<Endpoint>, source: EndpointAddress, contents: Vec<u8>) -> Self {
pub fn new(clusters: Arc<ClusterMap>, source: EndpointAddress, contents: Vec<u8>) -> Self {
Self {
endpoints,
clusters,
destinations: Vec::new(),
source,
contents,
metadata: DynamicMetadata::new(),
Expand Down
8 changes: 3 additions & 5 deletions src/filters/registry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,12 +104,10 @@ mod tests {
let addr: EndpointAddress = (Ipv4Addr::LOCALHOST, 8080).into();
let endpoint = Endpoint::new(addr.clone());

let clusters = crate::cluster::ClusterMap::default();
clusters.insert_default([endpoint.clone()].into());
assert!(filter
.read(&mut ReadContext::new(
vec![endpoint.clone()],
addr.clone(),
vec![]
))
.read(&mut ReadContext::new(clusters.into(), addr.clone(), vec![]))
.await
.is_ok());
assert!(filter
Expand Down
6 changes: 4 additions & 2 deletions src/filters/timestamp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -168,8 +168,9 @@ mod tests {
async fn basic() {
const TIMESTAMP_KEY: &str = "BASIC";
let filter = Timestamp::from_config(Config::new(TIMESTAMP_KEY).into());
let endpoints = crate::cluster::ClusterMap::default();
let mut ctx = ReadContext::new(
vec![],
endpoints.into(),
(std::net::Ipv4Addr::UNSPECIFIED, 0).into(),
b"hello".to_vec(),
);
Expand Down Expand Up @@ -199,8 +200,9 @@ mod tests {
);
let timestamp = Timestamp::from_config(Config::new(TIMESTAMP_KEY).into());
let source = (std::net::Ipv4Addr::UNSPECIFIED, 0);
let endpoints = crate::cluster::ClusterMap::default();
let mut ctx = ReadContext::new(
vec![],
endpoints.into(),
source.into(),
[0, 0, 0, 0, 99, 81, 55, 181].to_vec(),
);
Expand Down
Loading

0 comments on commit 7730eeb

Please sign in to comment.