From fd7d8d84f9ba7473eb9dcdb36be1eec35d88be74 Mon Sep 17 00:00:00 2001 From: Tim Geoghegan Date: Mon, 22 Aug 2022 15:56:47 -0700 Subject: [PATCH] Retries in HTTP clients Adds a new module `janus_core::retries` which uses crate `backoff` to allow retrying HTTP requests based on whether an error is transient. We retry if we encounter: - failures to connect (i.e., connection refused or reset); - timeouts; - HTTP errors in the 500 range, except 501 Not Implemented. The error classification depends on methods on `reqwest::Error` and walking the chain of `std::error::Error::source` errors to see if there is an `std::io::Error` under the covers. We retry on HTTP status 500, 502-599, but if we give up after encountering enough such errors, the caller gets `Err(Ok(reqwest::Response))` so that they can examine the response body or headers. Resolves #196 --- Cargo.lock | 5 + janus_client/Cargo.toml | 1 + janus_client/src/lib.rs | 60 +++- janus_core/Cargo.toml | 10 +- janus_core/src/lib.rs | 1 + janus_core/src/retries.rs | 311 +++++++++++++++++++++ monolithic_integration_test/tests/janus.rs | 4 - 7 files changed, 370 insertions(+), 22 deletions(-) create mode 100644 janus_core/src/retries.rs diff --git a/Cargo.lock b/Cargo.lock index 7cf7c5ae6..ef2731164 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1636,6 +1636,7 @@ name = "janus_client" version = "0.1.8" dependencies = [ "assert_matches", + "backoff", "derivative", "http", "janus_core", @@ -1656,6 +1657,7 @@ version = "0.1.8" dependencies = [ "anyhow", "assert_matches", + "backoff", "base64", "bytes", "chrono", @@ -1666,11 +1668,13 @@ dependencies = [ "janus_core", "k8s-openapi", "kube", + "mockito", "num_enum", "postgres-protocol", "postgres-types", "prio", "rand", + "reqwest", "ring", "serde", "serde_json", @@ -1680,6 +1684,7 @@ dependencies = [ "tracing", "tracing-log", "tracing-subscriber", + "url", ] [[package]] diff --git a/janus_client/Cargo.toml b/janus_client/Cargo.toml index 5c11cf066..d9473531f 100644 --- 
a/janus_client/Cargo.toml +++ b/janus_client/Cargo.toml @@ -10,6 +10,7 @@ repository = "https://github.com/divviup/janus" rust-version = "1.63" [dependencies] +backoff = { version = "0.4.0", features = ["tokio"] } derivative = "2.2.0" http = "0.2.8" janus_core = { version = "0.1", path = "../janus_core" } diff --git a/janus_client/src/lib.rs b/janus_client/src/lib.rs index d530afe72..a24d17d89 100644 --- a/janus_client/src/lib.rs +++ b/janus_client/src/lib.rs @@ -1,11 +1,13 @@ //! PPM protocol client +use backoff::ExponentialBackoff; use derivative::Derivative; use http::{header::CONTENT_TYPE, StatusCode}; use janus_core::{ hpke::associated_data_for_report_share, hpke::{self, HpkeApplicationInfo, Label}, message::{Duration, HpkeCiphertext, HpkeConfig, Nonce, Report, Role, TaskId}, + retries::{http_request_exponential_backoff, retry_http_request}, time::Clock, }; use prio::{ @@ -57,14 +59,31 @@ pub struct ClientParameters { /// The minimum batch duration of the task. This value is shared by all /// parties in the protocol, and is used to compute report nonces. min_batch_duration: Duration, + /// Parameters to use when retrying HTTP requests. + http_request_retry_parameters: ExponentialBackoff, } impl ClientParameters { /// Creates a new set of client task parameters. pub fn new( + task_id: TaskId, + aggregator_endpoints: Vec<Url>, + min_batch_duration: Duration, + ) -> Self { + Self::new_with_backoff( + task_id, + aggregator_endpoints, + min_batch_duration, + http_request_exponential_backoff(), + ) + } + + /// Creates a new set of client task parameters with non-default HTTP request retry parameters. 
+ pub fn new_with_backoff( task_id: TaskId, mut aggregator_endpoints: Vec<Url>, min_batch_duration: Duration, + http_request_retry_parameters: ExponentialBackoff, ) -> Self { // Ensure provided aggregator endpoints end with a slash, as we will be joining additional // path segments into these endpoints & the Url::join implementation is persnickety about @@ -79,6 +98,7 @@ impl ClientParameters { task_id, aggregator_endpoints, min_batch_duration, + http_request_retry_parameters, } } @@ -122,7 +142,12 @@ pub async fn aggregator_hpke_config( ) -> Result<HpkeConfig, Error> { let mut request_url = client_parameters.hpke_config_endpoint(aggregator_role)?; request_url.set_query(Some(&format!("task_id={}", task_id))); - let hpke_config_response = http_client.get(request_url).send().await?; + let hpke_config_response = retry_http_request( + client_parameters.http_request_retry_parameters.clone(), + || async { http_client.get(request_url.clone()).send().await }, + ) + .await + .or_else(|e| e)?; let status = hpke_config_response.status(); if !status.is_success() { return Err(Error::Http(status)); @@ -217,21 +242,26 @@ where )) } - /// Upload a [`janus_core::message::Report`] to the leader, per §4.3.2 of - /// draft-gpew-priv-ppm. The provided measurement is sharded into one input - /// share plus one proof share for each aggregator and then uploaded to the - /// leader. + /// Upload a [`janus_core::message::Report`] to the leader, per §4.3.2 of draft-gpew-priv-ppm. + /// The provided measurement is sharded into one input share plus one proof share for each + /// aggregator and then uploaded to the leader. #[tracing::instrument(skip(measurement), err)] pub async fn upload(&self, measurement: &V::Measurement) -> Result<(), Error> { let report = self.prepare_report(measurement)?; - - let upload_response = self - .http_client - .post(self.parameters.upload_endpoint()?) 
- .header(CONTENT_TYPE, Report::MEDIA_TYPE) - .body(report.get_encoded()) - .send() - .await?; + let upload_endpoint = self.parameters.upload_endpoint()?; + let upload_response = retry_http_request( + self.parameters.http_request_retry_parameters.clone(), + || async { + self.http_client + .post(upload_endpoint.clone()) + .header(CONTENT_TYPE, Report::MEDIA_TYPE) + .body(report.get_encoded()) + .send() + .await + }, + ) + .await + .or_else(|e| e)?; let status = upload_response.status(); if !status.is_success() { // TODO(#233): decode an RFC 7807 problem document @@ -249,6 +279,7 @@ mod tests { use janus_core::{ hpke::test_util::generate_test_hpke_config_and_private_key, message::{TaskId, Time}, + retries::test_http_request_exponential_backoff, test_util::install_test_trace_subscriber, time::MockClock, }; @@ -262,10 +293,11 @@ mod tests { { let server_url = Url::parse(&mockito::server_url()).unwrap(); Client::new( - ClientParameters::new( + ClientParameters::new_with_backoff( TaskId::random(), Vec::from([server_url.clone(), server_url]), Duration::from_seconds(1), + test_http_request_exponential_backoff(), ), vdaf_client, MockClock::default(), diff --git a/janus_core/Cargo.toml b/janus_core/Cargo.toml index c9efb5819..677e4386c 100644 --- a/janus_core/Cargo.toml +++ b/janus_core/Cargo.toml @@ -13,12 +13,10 @@ rust-version = "1.63" database = ["dep:bytes", "dep:postgres-protocol", "dep:postgres-types"] test-util = [ "dep:assert_matches", - "dep:futures", "dep:kube", "dep:k8s-openapi", "dep:serde_json", "dep:tempfile", - "dep:tracing", "dep:tracing-log", "dep:tracing-subscriber", "tokio/macros", @@ -27,10 +25,12 @@ test-util = [ [dependencies] anyhow = "1" +backoff = { version = "0.4.0", features = ["tokio"] } base64 = "0.13.0" bytes = { version = "1.2.1", optional = true } chrono = "0.4" derivative = "2.2.0" +futures = "0.3.23" hex = "0.4" hpke-dispatch = "0.3.0" kube = { version = "0.65", optional = true, default-features = false, features = ["client", "rustls-tls"] 
} @@ -40,17 +40,17 @@ postgres-protocol = { version = "0.6.4", optional = true } postgres-types = { version = "0.2.4", optional = true } prio = "0.8.2" rand = "0.8" +reqwest = { version = "0.11.4", default-features = false, features = ["rustls-tls"] } ring = "0.16.20" serde = { version = "1.0.144", features = ["derive"] } thiserror = "1.0" tokio = { version = "^1.20", features = ["macros", "net", "rt"] } +tracing = "0.1.36" # Dependencies required only if feature "test-util" is enabled assert_matches = { version = "1", optional = true } serde_json = { version = "1.0.85", optional = true } -futures = { version = "0.3.23", optional = true } tempfile = { version = "3", optional = true } -tracing = { version = "0.1.36", optional = true } tracing-log = { version = "0.1.3", optional = true } tracing-subscriber = { version = "0.3", features = ["std", "env-filter", "fmt"], optional = true } @@ -62,3 +62,5 @@ janus_core = { path = ".", features = ["test-util"] } # lack of support for connecting to servers by IP addresses, which affects many # Kubernetes clusters. kube = { version = "0.65", features = ["openssl-tls"] } # ensure this remains compatible with the non-dev dependency +mockito = "0.31.0" +url = "2.2.2" diff --git a/janus_core/src/lib.rs b/janus_core/src/lib.rs index eb754e9c5..1d2c7244d 100644 --- a/janus_core/src/lib.rs +++ b/janus_core/src/lib.rs @@ -3,6 +3,7 @@ use tokio::task::JoinHandle; pub mod hpke; pub mod message; +pub mod retries; pub mod task; #[cfg(feature = "test-util")] pub mod test_util; diff --git a/janus_core/src/retries.rs b/janus_core/src/retries.rs new file mode 100644 index 000000000..ceeedf870 --- /dev/null +++ b/janus_core/src/retries.rs @@ -0,0 +1,311 @@ +//! Provides a simple interface for retrying fallible HTTP requests. 
+ +use backoff::{future::retry, ExponentialBackoff}; +use futures::Future; +use reqwest::StatusCode; +use std::{error::Error as StdError, time::Duration}; +use tracing::{debug, warn}; + +/// Traverse chain of source errors looking for an `std::io::Error`. +fn find_io_error(original_error: &reqwest::Error) -> Option<&std::io::Error> { + let mut cause = original_error.source(); + while let Some(err) = cause { + if let Some(typed) = err.downcast_ref() { + return Some(typed); + } + cause = err.source(); + } + + None +} + +/// An [`ExponentialBackoff`] with parameters suitable for most HTTP requests. The parameters are +/// copied from the parameters used in the GCP Go SDK[1]. +/// +/// AWS doesn't give us specific guidance on what intervals to use, but the GCP implementation cites +/// AWS blog posts so the same parameters are probably fine for both, and most HTTP APIs for that +/// matter. +/// +/// [1]: https://github.com/googleapis/gax-go/blob/fbaf9882acf3297573f3a7cb832e54c7d8f40635/v2/call_option.go#L120 +pub fn http_request_exponential_backoff() -> ExponentialBackoff { + ExponentialBackoff { + initial_interval: Duration::from_secs(1), + max_interval: Duration::from_secs(30), + multiplier: 2.0, + max_elapsed_time: Some(Duration::from_secs(600)), + ..Default::default() + } +} +/// An [`ExponentialBackoff`] with parameters tuned for tests where we don't want to be retrying +/// for 10 minutes. 
+#[cfg(feature = "test-util")] +pub fn test_http_request_exponential_backoff() -> ExponentialBackoff { + ExponentialBackoff { + initial_interval: Duration::from_nanos(1), + max_interval: Duration::from_nanos(30), + multiplier: 2.0, + max_elapsed_time: Some(Duration::from_millis(10)), + ..Default::default() + } +} + +/// Executes the provided request function and awaits the returned future, retrying using the +/// parameters in the provided `ExponentialBackoff` if the [`reqwest::Error`] returned by +/// `request_fn` is: +/// +/// - a timeout +/// - a problem establishing a connection +/// - an HTTP status code indicating a server error +/// +/// If the request eventually succeeds, the value returned by `request_fn` is returned. If an +/// unretryable failure occurs or enough transient failures occur, then `Err(ret)` is returned, +/// where `ret` is the `Result<reqwest::Response, reqwest::Error>` returned by the last call to +/// `request_fn`. Retryable failures are logged. +/// +/// # TODOs: +/// +/// This function could take a list of HTTP status codes that should be considered retryable, so +/// that a caller could opt to retry when it sees 408 Request Timeout or 429 Too Many Requests, but +/// since none of the servers this is currently used to communicate with ever return those statuses, +/// we don't yet need that feature. 
+pub async fn retry_http_request<RequestFn, ResultFuture>( + backoff: ExponentialBackoff, + request_fn: RequestFn, +) -> Result<reqwest::Response, Result<reqwest::Response, reqwest::Error>> +where + RequestFn: Fn() -> ResultFuture, + ResultFuture: Future<Output = Result<reqwest::Response, reqwest::Error>>, +{ + retry(backoff, || async { + // In all branches in this match, we wrap the reqwest::Response or reqwest::Error up in a + // Result<reqwest::Response, backoff::Error<Result<reqwest::Response, reqwest::Error>>>, + // which allows us to retry on certain HTTP status codes without discarding the + // reqwest::Response, which the caller may need in order to examine its body or headers. 
+ match request_fn().await { + Ok(response) => { + if response.status().is_server_error() + && response.status() != StatusCode::NOT_IMPLEMENTED + { + warn!(?response, "encountered retryable server error"); + return Err(backoff::Error::transient(Ok(response))); + } + + Ok(response) + } + Err(error) => { + if error.is_timeout() || error.is_connect() { + warn!(?error, "encountered retryable error"); + return Err(backoff::Error::transient(Err(error))); + } + + if let Some(io_error) = find_io_error(&error) { + if let std::io::ErrorKind::ConnectionRefused + | std::io::ErrorKind::ConnectionReset + | std::io::ErrorKind::ConnectionAborted = io_error.kind() + { + warn!(?error, "encountered retryable error"); + return Err(backoff::Error::transient(Err(error))); + } + } + + debug!("encountered non-retryable error"); + Err(backoff::Error::permanent(Err(error))) + } + } + }) + .await +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::test_util::install_test_trace_subscriber; + use mockito::mock; + use tokio::net::TcpListener; + use url::Url; + + #[tokio::test] + async fn http_retry_client_error() { + install_test_trace_subscriber(); + + let mock_404 = mock("GET", "/") + .with_status(StatusCode::NOT_FOUND.as_u16().into()) + .with_header("some-header", "some-value") + .with_body("some-body") + .expect(1) + .create(); + + let http_client = reqwest::Client::builder().build().unwrap(); + + // HTTP 404 should cause the client to give up after a single attempt, and the caller should + // get `Ok(reqwest::Response)`. 
+ let response = retry_http_request(test_http_request_exponential_backoff(), || async { + http_client.get(mockito::server_url()).send().await + }) + .await + .unwrap(); + + assert_eq!(response.status(), StatusCode::NOT_FOUND); + assert_eq!( + response.headers().get("some-header").unwrap(), + &"some-value" + ); + assert_eq!(response.text().await.unwrap(), "some-body".to_string()); + + mock_404.assert(); + } + + #[tokio::test] + async fn http_retry_server_error() { + install_test_trace_subscriber(); + + let mock_500 = mock("GET", "/") + .with_status(StatusCode::INTERNAL_SERVER_ERROR.as_u16().into()) + .with_header("some-header", "some-value") + .with_body("some-body") + .expect_at_least(2) + .create(); + + let http_client = reqwest::Client::builder().build().unwrap(); + + // We expect to eventually give up in the face of repeated HTTP 500, but the caller expects + // a `reqwest::Response` so they can examine the status code, headers and response body, + // which you can't get from a `reqwest::Error`. 
+ let response = retry_http_request(test_http_request_exponential_backoff(), || async { + http_client.get(mockito::server_url()).send().await + }) + .await + .unwrap_err() + .unwrap(); + + assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR); + assert_eq!( + response.headers().get("some-header").unwrap(), + &"some-value" + ); + assert_eq!(response.text().await.unwrap(), "some-body".to_string()); + mock_500.assert(); + } + + #[tokio::test] + async fn http_retry_server_error_unimplemented() { + install_test_trace_subscriber(); + + let mock_501 = mock("GET", "/") + .with_status(StatusCode::NOT_IMPLEMENTED.as_u16().into()) + .expect(1) + .create(); + + let http_client = reqwest::Client::builder().build().unwrap(); + + let response = retry_http_request(test_http_request_exponential_backoff(), || async { + http_client.get(mockito::server_url()).send().await + }) + .await + .unwrap(); + + assert_eq!(response.status(), StatusCode::NOT_IMPLEMENTED); + mock_501.assert(); + } + + #[tokio::test] + async fn http_retry_server_error_eventually_succeeds() { + install_test_trace_subscriber(); + + let mock_500 = mock("GET", "/") + .with_status(500) + .expect_at_least(1) + .create(); + let mock_200 = mock("GET", "/").with_status(200).expect(1).create(); + + let http_client = reqwest::Client::builder().build().unwrap(); + + retry_http_request(test_http_request_exponential_backoff(), || async { + http_client.get(mockito::server_url()).send().await + }) + .await + .unwrap(); + + mock_200.assert(); + mock_500.assert(); + } + + #[tokio::test] + async fn http_retry_timeout() { + install_test_trace_subscriber(); + + let tcp_listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let bound_port = tcp_listener.local_addr().unwrap().port(); + + let listener_task = tokio::spawn(async move { + loop { + let (_socket, _) = tcp_listener.accept().await.unwrap(); + // Deliberately do nothing with the socket to force a timeout in the client + 
tokio::time::sleep(Duration::from_secs(10)).await; + } + }); + + let url = Url::parse(&format!("http://127.0.0.1:{bound_port}")).unwrap(); + + let http_client = reqwest::Client::builder() + // Aggressively short timeout to force a timeout error + .timeout(Duration::from_nanos(1)) + .build() + .unwrap(); + + let err = retry_http_request(test_http_request_exponential_backoff(), || async { + http_client.get(url.clone()).send().await + }) + .await + .unwrap_err() + .unwrap_err(); + assert!(err.is_timeout(), "error = {err}"); + + listener_task.abort(); + assert!(listener_task.await.unwrap_err().is_cancelled()); + } + + #[tokio::test] + async fn http_retry_connection_reset() { + install_test_trace_subscriber(); + + let tcp_listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let bound_port = tcp_listener.local_addr().unwrap().port(); + + let listener_task = tokio::spawn(async move { + // Accept connections on the TCP listener, then wait until we can read one byte from + // them (indicating that the client has sent something). If we read successfully, drop + // the socket so that the client will see a connection reset error. 
+ loop { + let (socket, _) = tcp_listener.accept().await.unwrap(); + loop { + socket.readable().await.unwrap(); + + let mut buf = [0u8; 1]; + match socket.try_read(&mut buf) { + Ok(1) => break, + Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => { + continue; + } + val => panic!("unexpected result from try_read {val:?}"), + } + } + drop(socket); + } + }); + + let url = Url::parse(&format!("http://127.0.0.1:{bound_port}")).unwrap(); + + let http_client = reqwest::Client::builder().build().unwrap(); + + retry_http_request(test_http_request_exponential_backoff(), || async { + http_client.get(url.clone()).send().await + }) + .await + .unwrap_err() + .unwrap_err(); + + listener_task.abort(); + assert!(listener_task.await.unwrap_err().is_cancelled()); + } +} diff --git a/monolithic_integration_test/tests/janus.rs b/monolithic_integration_test/tests/janus.rs index eb98965ff..9bbd3824a 100644 --- a/monolithic_integration_test/tests/janus.rs +++ b/monolithic_integration_test/tests/janus.rs @@ -97,10 +97,6 @@ impl JanusPair { ) .await; - // Wait just a bit to allow kubectl port-forwards to be ready - // TODO(#196): Remove this. - tokio::time::sleep(tokio::time::Duration::from_secs(1)).await; - (leader, helper) } (