Skip to content

Commit

Permalink
feat: Accept handler (#116)
Browse files Browse the repository at this point in the history
  • Loading branch information
rklaehn authored Nov 14, 2024
2 parents a949899 + 30ce4cf commit 32d5bc1
Show file tree
Hide file tree
Showing 8 changed files with 127 additions and 107 deletions.
9 changes: 5 additions & 4 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ quinn = { package = "iroh-quinn", version = "0.12", optional = true }
serde = { version = "1.0.183", features = ["derive"] }
tokio = { version = "1", default-features = false, features = ["macros", "sync"] }
tokio-serde = { version = "0.8", features = ["bincode"], optional = true }
tokio-util = { version = "0.7", features = ["codec"], optional = true }
tokio-util = { version = "0.7", features = ["rt"] }
tracing = "0.1"
hex = "0.4.3"
futures = { version = "0.3.30", optional = true }
Expand All @@ -52,12 +52,13 @@ proc-macro2 = "1.0.66"
futures-buffered = "0.2.4"
testresult = "0.4.1"
nested_enum_utils = "0.1.0"
tokio-util = { version = "0.7", features = ["rt"] }

[features]
hyper-transport = ["dep:flume", "dep:hyper", "dep:bincode", "dep:bytes", "dep:tokio-serde", "dep:tokio-util"]
quinn-transport = ["dep:flume", "dep:quinn", "dep:bincode", "dep:tokio-serde", "dep:tokio-util"]
hyper-transport = ["dep:flume", "dep:hyper", "dep:bincode", "dep:bytes", "dep:tokio-serde", "tokio-util/codec"]
quinn-transport = ["dep:flume", "dep:quinn", "dep:bincode", "dep:tokio-serde", "tokio-util/codec"]
flume-transport = ["dep:flume"]
iroh-net-transport = ["dep:iroh-net", "dep:flume", "dep:quinn", "dep:bincode", "dep:tokio-serde", "dep:tokio-util"]
iroh-net-transport = ["dep:iroh-net", "dep:flume", "dep:quinn", "dep:bincode", "dep:tokio-serde", "tokio-util/codec"]
macros = []
default = ["flume-transport"]

Expand Down
21 changes: 4 additions & 17 deletions examples/modularize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ use app::AppService;
use futures_lite::StreamExt;
use futures_util::SinkExt;
use quic_rpc::{client::BoxedConnector, transport::flume, Listener, RpcClient, RpcServer};
use tracing::warn;

#[tokio::main]
async fn main() -> Result<()> {
Expand All @@ -32,23 +31,11 @@ async fn main() -> Result<()> {

async fn run_server<C: Listener<AppService>>(server_conn: C, handler: app::Handler) {
let server = RpcServer::<AppService, _>::new(server_conn);
loop {
let Ok(accepting) = server.accept().await else {
continue;
};
match accepting.read_first().await {
Err(err) => warn!(?err, "server accept failed"),
Ok((req, chan)) => {
let handler = handler.clone();
tokio::task::spawn(async move {
if let Err(err) = handler.handle_rpc_request(req, chan).await {
warn!(?err, "internal rpc error");
}
});
}
}
}
server
.accept_loop(move |req, chan| handler.clone().handle_rpc_request(req, chan))
.await
}

pub async fn client_demo(conn: BoxedConnector<AppService>) -> Result<()> {
let rpc_client = RpcClient::<AppService>::new(conn);
let client = app::Client::new(rpc_client.clone());
Expand Down
67 changes: 66 additions & 1 deletion src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,16 @@ use std::{
marker::PhantomData,
pin::Pin,
result,
sync::Arc,
task::{self, Poll},
};

use futures_lite::{Future, Stream, StreamExt};
use futures_util::{SinkExt, TryStreamExt};
use pin_project::pin_project;
use tokio::sync::oneshot;
use tokio::{sync::oneshot, task::JoinSet};
use tokio_util::task::AbortOnDropHandle;
use tracing::{error, warn};

use crate::{
transport::{
Expand Down Expand Up @@ -211,6 +214,68 @@ impl<S: Service, C: Listener<S>> RpcServer<S, C> {
pub fn into_inner(self) -> C {
self.source
}

/// Run an accept loop for this server.
///
/// Each request will be handled in a separate task.
///
/// It is the caller's responsibility to poll the returned future to drive the server.
pub async fn accept_loop<Fun, Fut, E>(self, handler: Fun)
where
S: Service,
C: Listener<S>,
Fun: Fn(S::Req, RpcChannel<S, C>) -> Fut + Send + Sync + 'static,
Fut: Future<Output = Result<(), E>> + Send + 'static,
E: Into<anyhow::Error> + 'static,
{
let handler = Arc::new(handler);
let mut tasks = JoinSet::new();
loop {
tokio::select! {
Some(res) = tasks.join_next(), if !tasks.is_empty() => {
if let Err(e) = res {
if e.is_panic() {
error!("Panic handling RPC request: {e}");
}
}
}
req = self.accept() => {
let req = match req {
Ok(req) => req,
Err(e) => {
warn!("Error accepting RPC request: {e}");
continue;
}
};
let handler = handler.clone();
tasks.spawn(async move {
let (req, chan) = match req.read_first().await {
Ok((req, chan)) => (req, chan),
Err(e) => {
warn!("Error reading first message: {e}");
return;
}
};
if let Err(cause) = handler(req, chan).await {
warn!("Error handling RPC request: {}", cause.into());
}
});
}
}
}
}

/// Spawn an accept loop and return a handle to the task.
pub fn spawn_accept_loop<Fun, Fut, E>(self, handler: Fun) -> AbortOnDropHandle<()>
where
S: Service,
C: Listener<S>,
Fun: Fn(S::Req, RpcChannel<S, C>) -> Fut + Send + Sync + 'static,
Fut: Future<Output = Result<(), E>> + Send + 'static,
E: Into<anyhow::Error> + 'static,
{
AbortOnDropHandle::new(tokio::spawn(self.accept_loop(handler)))
}
}

impl<S: Service, C: Listener<S>> AsRef<C> for RpcServer<S, C> {
Expand Down
16 changes: 3 additions & 13 deletions tests/flume.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,17 @@ use quic_rpc::{
transport::flume,
RpcClient, RpcServer, Service,
};
use tokio_util::task::AbortOnDropHandle;

#[tokio::test]
async fn flume_channel_bench() -> anyhow::Result<()> {
tracing_subscriber::fmt::try_init().ok();
let (server, client) = flume::channel(1);

let server = RpcServer::<ComputeService, _>::new(server);
let server_handle = tokio::task::spawn(ComputeService::server(server));
let _server_handle = AbortOnDropHandle::new(tokio::spawn(ComputeService::server(server)));
let client = RpcClient::<ComputeService, _>::new(client);
bench(client, 1000000).await?;
// dropping the client will cause the server to terminate
match server_handle.await? {
Err(RpcServerError::Accept(_)) => {}
e => panic!("unexpected termination result {e:?}"),
}
Ok(())
}

Expand Down Expand Up @@ -101,13 +97,7 @@ async fn flume_channel_smoke() -> anyhow::Result<()> {
let (server, client) = flume::channel(1);

let server = RpcServer::<ComputeService, _>::new(server);
let server_handle = tokio::task::spawn(ComputeService::server(server));
let _server_handle = AbortOnDropHandle::new(tokio::spawn(ComputeService::server(server)));
smoke_test(client).await?;

// dropping the client will cause the server to terminate
match server_handle.await? {
Err(RpcServerError::Accept(_)) => {}
e => panic!("unexpected termination result {e:?}"),
}
Ok(())
}
21 changes: 5 additions & 16 deletions tests/hyper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,13 @@ use tokio::task::JoinHandle;

mod math;
use math::*;
use tokio_util::task::AbortOnDropHandle;
mod util;

fn run_server(addr: &SocketAddr) -> JoinHandle<anyhow::Result<()>> {
fn run_server(addr: &SocketAddr) -> AbortOnDropHandle<()> {
let channel = HyperListener::serve(addr).unwrap();
let server = RpcServer::new(channel);
tokio::spawn(async move {
loop {
let server = server.clone();
ComputeService::server(server).await?;
}
#[allow(unreachable_code)]
anyhow::Ok(())
})
ComputeService::server(server)
}

#[derive(Debug, Serialize, Deserialize, From, TryInto)]
Expand Down Expand Up @@ -133,25 +127,21 @@ impl TestService {
async fn hyper_channel_bench() -> anyhow::Result<()> {
let addr: SocketAddr = "127.0.0.1:3000".parse()?;
let uri: Uri = "http://127.0.0.1:3000".parse()?;
let server_handle = run_server(&addr);
let _server_handle = run_server(&addr);
let client = HyperConnector::new(uri);
let client = RpcClient::new(client);
bench(client, 50000).await?;
println!("terminating server");
server_handle.abort();
let _ = server_handle.await;
Ok(())
}

#[tokio::test]
async fn hyper_channel_smoke() -> anyhow::Result<()> {
let addr: SocketAddr = "127.0.0.1:3001".parse()?;
let uri: Uri = "http://127.0.0.1:3001".parse()?;
let server_handle = run_server(&addr);
let _server_handle = run_server(&addr);
let client = HyperConnector::new(uri);
smoke_test(client).await?;
server_handle.abort();
let _ = server_handle.await;
Ok(())
}

Expand Down Expand Up @@ -302,6 +292,5 @@ async fn hyper_channel_errors() -> anyhow::Result<()> {

println!("terminating server");
server_handle.abort();
let _ = server_handle.await;
Ok(())
}
38 changes: 15 additions & 23 deletions tests/iroh-net.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,13 @@

use iroh_net::{key::SecretKey, NodeAddr};
use quic_rpc::{transport, RpcClient, RpcServer};
use tokio::task::JoinHandle;
use testresult::TestResult;

use crate::transport::iroh_net::{IrohNetConnector, IrohNetListener};

mod math;
use math::*;
use tokio_util::task::AbortOnDropHandle;
mod util;

const ALPN: &[u8] = b"quic-rpc/iroh-net/test";
Expand Down Expand Up @@ -44,13 +47,10 @@ impl Endpoints {
}
}

fn run_server(server: iroh_net::Endpoint) -> JoinHandle<anyhow::Result<()>> {
tokio::task::spawn(async move {
let connection = transport::iroh_net::IrohNetListener::new(server)?;
let server = RpcServer::new(connection);
ComputeService::server(server).await?;
anyhow::Ok(())
})
fn run_server(server: iroh_net::Endpoint) -> AbortOnDropHandle<()> {
let connection = IrohNetListener::new(server).unwrap();
let server = RpcServer::new(connection);
ComputeService::server(server)
}

// #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
Expand All @@ -64,17 +64,12 @@ async fn iroh_net_channel_bench() -> anyhow::Result<()> {
server_node_addr,
} = Endpoints::new().await?;
tracing::debug!("Starting server");
let server_handle = run_server(server);
let _server_handle = run_server(server);
tracing::debug!("Starting client");

let client = RpcClient::new(transport::iroh_net::IrohNetConnector::new(
client,
server_node_addr,
ALPN.into(),
));
let client = RpcClient::new(IrohNetConnector::new(client, server_node_addr, ALPN.into()));
tracing::debug!("Starting benchmark");
bench(client, 50000).await?;
server_handle.abort();
Ok(())
}

Expand All @@ -86,11 +81,9 @@ async fn iroh_net_channel_smoke() -> anyhow::Result<()> {
server,
server_node_addr,
} = Endpoints::new().await?;
let server_handle = run_server(server);
let client_connection =
transport::iroh_net::IrohNetConnector::new(client, server_node_addr, ALPN.into());
let _server_handle = run_server(server);
let client_connection = IrohNetConnector::new(client, server_node_addr, ALPN.into());
smoke_test(client_connection).await?;
server_handle.abort();
Ok(())
}

Expand All @@ -99,7 +92,7 @@ async fn iroh_net_channel_smoke() -> anyhow::Result<()> {
///
/// This is a regression test.
#[tokio::test]
async fn server_away_and_back() -> anyhow::Result<()> {
async fn server_away_and_back() -> TestResult<()> {
tracing_subscriber::fmt::try_init().ok();
tracing::info!("Creating endpoints");

Expand Down Expand Up @@ -128,7 +121,7 @@ async fn server_away_and_back() -> anyhow::Result<()> {
// create the RPC Server
let connection = transport::iroh_net::IrohNetListener::new(server_endpoint.clone())?;
let server = RpcServer::new(connection);
let server_handle = tokio::task::spawn(ComputeService::server_bounded(server, 1));
let server_handle = tokio::spawn(ComputeService::server_bounded(server, 1));

// wait a bit for connection due to Windows test failing on CI
tokio::time::sleep(tokio::time::Duration::from_millis(300)).await;
Expand All @@ -151,7 +144,7 @@ async fn server_away_and_back() -> anyhow::Result<()> {
// make the server run again
let connection = transport::iroh_net::IrohNetListener::new(server_endpoint.clone())?;
let server = RpcServer::new(connection);
let server_handle = tokio::task::spawn(ComputeService::server_bounded(server, 5));
let server_handle = tokio::spawn(ComputeService::server_bounded(server, 5));

// wait a bit for connection due to Windows test failing on CI
tokio::time::sleep(tokio::time::Duration::from_millis(300)).await;
Expand All @@ -163,7 +156,6 @@ async fn server_away_and_back() -> anyhow::Result<()> {
// server is running, this should work
let SqrResponse(response) = client.rpc(Sqr(3)).await?;
assert_eq!(response, 9);

server_handle.abort();
Ok(())
}
23 changes: 9 additions & 14 deletions tests/math.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ use quic_rpc::{
};
use serde::{Deserialize, Serialize};
use thousands::Separable;
use tokio_util::task::AbortOnDropHandle;

/// compute the square of a number
#[derive(Debug, Serialize, Deserialize)]
Expand Down Expand Up @@ -163,20 +164,14 @@ impl ComputeService {
}
}

pub async fn server<C: Listener<ComputeService>>(
pub fn server<C: Listener<ComputeService>>(
server: RpcServer<ComputeService, C>,
) -> result::Result<(), RpcServerError<C>> {
let s = server;
let service = ComputeService;
loop {
let (req, chan) = s.accept().await?.read_first().await?;
let service = service.clone();
tokio::spawn(async move { Self::handle_rpc_request(service, req, chan).await });
}
) -> AbortOnDropHandle<()> {
server.spawn_accept_loop(|req, chan| Self::handle_rpc_request(ComputeService, req, chan))
}

pub async fn handle_rpc_request<E>(
service: ComputeService,
self,
req: ComputeRequest,
chan: RpcChannel<ComputeService, E>,
) -> Result<(), RpcServerError<E>>
Expand All @@ -186,10 +181,10 @@ impl ComputeService {
use ComputeRequest::*;
#[rustfmt::skip]
match req {
Sqr(msg) => chan.rpc(msg, service, ComputeService::sqr).await,
Sum(msg) => chan.client_streaming(msg, service, ComputeService::sum).await,
Fibonacci(msg) => chan.server_streaming(msg, service, ComputeService::fibonacci).await,
Multiply(msg) => chan.bidi_streaming(msg, service, ComputeService::multiply).await,
Sqr(msg) => chan.rpc(msg, self, Self::sqr).await,
Sum(msg) => chan.client_streaming(msg, self, Self::sum).await,
Fibonacci(msg) => chan.server_streaming(msg, self, Self::fibonacci).await,
Multiply(msg) => chan.bidi_streaming(msg, self, Self::multiply).await,
MultiplyUpdate(_) => Err(RpcServerError::UnexpectedStartMessage)?,
SumUpdate(_) => Err(RpcServerError::UnexpectedStartMessage)?,
}?;
Expand Down
Loading

0 comments on commit 32d5bc1

Please sign in to comment.