Skip to content

Commit

Permalink
feat(s2n-quic-core): add task cooldown module (#1862)
Browse files Browse the repository at this point in the history
* feat(s2n-quic-core): add task cooldown module

* add cooldown disabled test
  • Loading branch information
camshaft authored Aug 11, 2023
1 parent 83c041d commit 2b5cb23
Show file tree
Hide file tree
Showing 26 changed files with 500 additions and 136 deletions.
5 changes: 5 additions & 0 deletions quic/s2n-quic-core/src/io/event_loop.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use crate::{
endpoint::Endpoint,
event::{self, EndpointPublisher},
io::{rx::Rx, tx::Tx},
task::cooldown::Cooldown,
time::clock::{ClockWithTimer, Timer},
};
use core::pin::Pin;
Expand All @@ -17,6 +18,7 @@ pub struct EventLoop<E, C, R, T> {
pub clock: C,
pub rx: R,
pub tx: T,
pub cooldown: Cooldown,
}

impl<E, C, R, T> EventLoop<E, C, R, T>
Expand All @@ -33,6 +35,7 @@ where
clock,
mut rx,
mut tx,
mut cooldown,
} = self;

/// Creates a event publisher with the endpoint's subscriber
Expand Down Expand Up @@ -78,6 +81,8 @@ where
// Concurrently poll all of the futures and wake up on the first one that's ready
let select = Select::new(rx_ready, tx_ready, wakeups, timer_ready);

let select = cooldown.wrap(select);

let select::Outcome {
rx_result,
tx_result,
Expand Down
1 change: 1 addition & 0 deletions quic/s2n-quic-core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ pub mod slice;
pub mod stateless_reset;
pub mod stream;
pub mod sync;
pub mod task;
pub mod time;
pub mod token;
pub mod transmission;
Expand Down
4 changes: 4 additions & 0 deletions quic/s2n-quic-core/src/task.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

pub mod cooldown;
178 changes: 178 additions & 0 deletions quic/s2n-quic-core/src/task/cooldown.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

use core::{
future::Future,
pin::Pin,
task::{Context, Poll},
};
use pin_project_lite::pin_project;

#[derive(Clone, Debug, Default)]
pub struct Cooldown {
credits: u16,
limit: u16,
}

#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Outcome {
/// The task should loop
Loop,
/// The task should return Pending and wait for an actual wake notification
Sleep,
}

impl Outcome {
#[inline]
pub fn is_loop(&self) -> bool {
matches!(self, Self::Loop)
}

#[inline]
pub fn is_sleep(&self) -> bool {
matches!(self, Self::Sleep)
}
}

impl Cooldown {
#[inline]
pub fn new(limit: u16) -> Self {
Self {
limit,
credits: limit,
}
}

#[inline]
pub fn state(&self) -> Outcome {
if self.credits > 0 {
Outcome::Loop
} else {
Outcome::Sleep
}
}

/// Notifies the cooldown that the poll operation was ready
///
/// This resets the cooldown period until another `Pending` result.
#[inline]
pub fn on_ready(&mut self) {
// reset the pending count
self.credits = self.limit;
}

/// Notifies the cooldown that the poll operation was pending
///
/// This consumes a cooldown credit until they are exhausted at which point the task should
/// sleep.
#[inline]
pub fn on_pending(&mut self) -> Outcome {
if self.credits > 0 {
self.credits -= 1;
return Outcome::Loop;
}

Outcome::Sleep
}

#[inline]
pub fn on_pending_task(&mut self, cx: &mut core::task::Context) -> Outcome {
let outcome = self.on_pending();

if outcome.is_loop() {
cx.waker().wake_by_ref();
}

outcome
}

#[inline]
pub async fn wrap<F>(&mut self, fut: F) -> F::Output
where
F: Future + Unpin,
{
Wrapped {
fut,
cooldown: self,
}
.await
}
}

pin_project!(
struct Wrapped<'a, F>
where
F: core::future::Future,
{
#[pin]
fut: F,
cooldown: &'a mut Cooldown,
}
);

impl<'a, F> Future for Wrapped<'a, F>
where
F: Future,
{
type Output = F::Output;

#[inline]
fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
let this = self.project();
match this.fut.poll(cx) {
Poll::Ready(v) => {
this.cooldown.on_ready();
Poll::Ready(v)
}
Poll::Pending => {
this.cooldown.on_pending_task(cx);
Poll::Pending
}
}
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn cooldown_test() {
let mut cooldown = Cooldown::new(2);

assert_eq!(cooldown.on_pending(), Outcome::Loop);
assert_eq!(cooldown.on_pending(), Outcome::Loop);
assert_eq!(cooldown.on_pending(), Outcome::Sleep);
assert_eq!(cooldown.on_pending(), Outcome::Sleep);

// call on ready to restore credits
cooldown.on_ready();

assert_eq!(cooldown.on_pending(), Outcome::Loop);
assert_eq!(cooldown.on_pending(), Outcome::Loop);
assert_eq!(cooldown.on_pending(), Outcome::Sleep);
assert_eq!(cooldown.on_pending(), Outcome::Sleep);

cooldown.on_ready();

// call on ready while we're still looping
assert_eq!(cooldown.on_pending(), Outcome::Loop);
cooldown.on_ready();

assert_eq!(cooldown.on_pending(), Outcome::Loop);
assert_eq!(cooldown.on_pending(), Outcome::Loop);
assert_eq!(cooldown.on_pending(), Outcome::Sleep);
assert_eq!(cooldown.on_pending(), Outcome::Sleep);
}

#[test]
fn disabled_test() {
let mut cooldown = Cooldown::new(0);

// with cooldown disabled, it should always return sleep
assert_eq!(cooldown.on_pending(), Outcome::Sleep);

cooldown.on_ready();
assert_eq!(cooldown.on_pending(), Outcome::Sleep);
}
}
1 change: 1 addition & 0 deletions quic/s2n-quic-platform/src/io/testing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,7 @@ impl Io {
clock,
tx,
rx,
cooldown: Default::default(),
};
let join = executor.spawn(event_loop.start());
Ok((join, handle))
Expand Down
4 changes: 2 additions & 2 deletions quic/s2n-quic-platform/src/io/testing/socket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use std::{io, sync::Arc};

/// A task to receive on a socket
pub async fn rx(socket: Socket, producer: ring::Producer<Message>) -> io::Result<()> {
let result = task::Receiver::new(producer, socket).await;
let result = task::Receiver::new(producer, socket, Default::default()).await;
if let Some(err) = result {
Err(err)
} else {
Expand All @@ -28,7 +28,7 @@ pub async fn rx(socket: Socket, producer: ring::Producer<Message>) -> io::Result

/// A task to send on a socket
pub async fn tx(socket: Socket, consumer: ring::Consumer<Message>, gso: Gso) -> io::Result<()> {
let result = task::Sender::new(consumer, socket, gso).await;
let result = task::Sender::new(consumer, socket, gso, Default::default()).await;
if let Some(err) = result {
Err(err)
} else {
Expand Down
29 changes: 25 additions & 4 deletions quic/s2n-quic-platform/src/io/tokio.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use s2n_quic_core::{
inet::{self, SocketAddress},
io::event_loop::EventLoop,
path::MaxMtu,
task::cooldown::Cooldown,
time::Clock as ClockTrait,
};
use std::{convert::TryInto, io, io::ErrorKind};
Expand Down Expand Up @@ -171,17 +172,21 @@ impl Io {

let rx_socket_count = parse_env("S2N_QUIC_UNSTABLE_RX_SOCKET_COUNT").unwrap_or(1);

// configure the number of self-wakes before "cooling down" and waiting for epoll to
// complete
let rx_cooldown = cooldown("RX");

for idx in 0usize..rx_socket_count {
let (producer, consumer) = socket::ring::pair(entries, payload_len);
consumers.push(consumer);

// spawn a task that actually reads from the socket into the ring buffer
if idx + 1 == rx_socket_count {
handle.spawn(task::rx(rx_socket, producer));
handle.spawn(task::rx(rx_socket, producer, rx_cooldown));
break;
} else {
let rx_socket = rx_socket.try_clone()?;
handle.spawn(task::rx(rx_socket, producer));
handle.spawn(task::rx(rx_socket, producer, rx_cooldown.clone()));
}
}

Expand Down Expand Up @@ -214,17 +219,26 @@ impl Io {

let tx_socket_count = parse_env("S2N_QUIC_UNSTABLE_TX_SOCKET_COUNT").unwrap_or(1);

// configure the number of self-wakes before "cooling down" and waiting for epoll to
// complete
let tx_cooldown = cooldown("TX");

for idx in 0usize..tx_socket_count {
let (producer, consumer) = socket::ring::pair(entries, payload_len);
producers.push(producer);

// spawn a task that actually flushes the ring buffer to the socket
if idx + 1 == tx_socket_count {
handle.spawn(task::tx(tx_socket, consumer, gso.clone()));
handle.spawn(task::tx(tx_socket, consumer, gso.clone(), tx_cooldown));
break;
} else {
let tx_socket = tx_socket.try_clone()?;
handle.spawn(task::tx(tx_socket, consumer, gso.clone()));
handle.spawn(task::tx(
tx_socket,
consumer,
gso.clone(),
tx_cooldown.clone(),
));
}
}

Expand All @@ -241,6 +255,7 @@ impl Io {
clock,
rx,
tx,
cooldown: cooldown("ENDPOINT"),
}
.start(),
);
Expand All @@ -259,3 +274,9 @@ fn convert_addr_to_std(addr: socket2::SockAddr) -> io::Result<std::net::SocketAd
fn parse_env<T: core::str::FromStr>(name: &str) -> Option<T> {
std::env::var(name).ok().and_then(|v| v.parse().ok())
}

pub fn cooldown(direction: &str) -> Cooldown {
let name = format!("S2N_QUIC_UNSTABLE_COOLDOWN_{direction}");
let limit = parse_env(&name).unwrap_or(0);
Cooldown::new(limit)
}
7 changes: 5 additions & 2 deletions quic/s2n-quic-platform/src/io/tokio/task.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,20 +24,23 @@ macro_rules! libc_msg {
mod $message {
use super::unix;
use crate::{features::Gso, message::$message::Message, socket::ring};
use s2n_quic_core::task::cooldown::Cooldown;

pub async fn rx<S: Into<std::net::UdpSocket>>(
socket: S,
producer: ring::Producer<Message>,
cooldown: Cooldown,
) -> std::io::Result<()> {
unix::rx(socket, producer).await
unix::rx(socket, producer, cooldown).await
}

pub async fn tx<S: Into<std::net::UdpSocket>>(
socket: S,
consumer: ring::Consumer<Message>,
gso: Gso,
cooldown: Cooldown,
) -> std::io::Result<()> {
unix::tx(socket, consumer, gso).await
unix::tx(socket, consumer, gso, cooldown).await
}
}
};
Expand Down
7 changes: 5 additions & 2 deletions quic/s2n-quic-platform/src/io/tokio/task/simple.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,19 @@ use crate::{
syscall::SocketEvents,
};
use core::task::{Context, Poll};
use s2n_quic_core::task::cooldown::Cooldown;
use tokio::{io, net::UdpSocket};

pub async fn rx<S: Into<std::net::UdpSocket>>(
socket: S,
producer: ring::Producer<Message>,
cooldown: Cooldown,
) -> io::Result<()> {
let socket = socket.into();
socket.set_nonblocking(true).unwrap();

let socket = UdpSocket::from_std(socket).unwrap();
let result = task::Receiver::new(producer, socket).await;
let result = task::Receiver::new(producer, socket, cooldown).await;
if let Some(err) = result {
Err(err)
} else {
Expand All @@ -33,12 +35,13 @@ pub async fn tx<S: Into<std::net::UdpSocket>>(
socket: S,
consumer: ring::Consumer<Message>,
gso: Gso,
cooldown: Cooldown,
) -> io::Result<()> {
let socket = socket.into();
socket.set_nonblocking(true).unwrap();

let socket = UdpSocket::from_std(socket).unwrap();
let result = task::Sender::new(consumer, socket, gso).await;
let result = task::Sender::new(consumer, socket, gso, cooldown).await;
if let Some(err) = result {
Err(err)
} else {
Expand Down
Loading

0 comments on commit 2b5cb23

Please sign in to comment.