diff --git a/hook-worker/src/config.rs b/hook-worker/src/config.rs index ceb690f..51b23b7 100644 --- a/hook-worker/src/config.rs +++ b/hook-worker/src/config.rs @@ -37,6 +37,9 @@ pub struct Config { #[envconfig(default = "1")] pub dequeue_batch_size: u32, + + #[envconfig(default = "false")] + pub allow_internal_ips: bool, } impl Config { diff --git a/hook-worker/src/dns.rs b/hook-worker/src/dns.rs index e69de29..36fd7a0 100644 --- a/hook-worker/src/dns.rs +++ b/hook-worker/src/dns.rs @@ -0,0 +1,140 @@ +use std::error::Error as StdError; +use std::net::{IpAddr, SocketAddr, ToSocketAddrs}; +use std::{fmt, io}; + +use futures::FutureExt; +use reqwest::dns::{Addrs, Name, Resolve, Resolving}; +use tokio::task::spawn_blocking; + +pub struct NoPublicIPv4Error; + +impl std::error::Error for NoPublicIPv4Error {} +impl fmt::Display for NoPublicIPv4Error { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "No public IPv4 found for specified host") + } +} +impl fmt::Debug for NoPublicIPv4Error { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "No public IPv4 found for specified host") + } +} + +/// Internal reqwest type, copied here as part of Resolving +pub(crate) type BoxError = Box; + +/// Returns [`true`] if the address appears to be a globally reachable IPv4. +/// +/// Trimmed down version of the unstable IpAddr::is_global, move to it when it's stable. +fn is_global_ipv4(addr: &SocketAddr) -> bool { + match addr.ip() { + IpAddr::V4(ip) => { + !(ip.octets()[0] == 0 // "This network" + || ip.is_private() + || ip.is_loopback() + || ip.is_link_local() + || ip.is_broadcast()) + } + IpAddr::V6(_) => false, // Our network does not currently support ipv6, let's ignore for now + } +} + +/// DNS resolver using the stdlib resolver, but filtering results to only pass public IPv4 results. +/// +/// Private and broadcast addresses are filtered out, so are IPv6 results for now (as our infra +/// does not currently support IPv6 routing anyway). +/// This is adapted from the GaiResolver in hyper and reqwest. +pub struct PublicIPv4Resolver {} + +impl Resolve for PublicIPv4Resolver { + fn resolve(&self, name: Name) -> Resolving { + // Closure to call the system's resolver (blocking call) through the ToSocketAddrs trait. + let resolve_host = move || (name.as_str(), 0).to_socket_addrs(); + + // Execute the blocking call in a separate worker thread then process its result asynchronously. + // spawn_blocking returns a JoinHandle that implements Future>. + let future_result = spawn_blocking(resolve_host).map(|result| match result { + Ok(Ok(all_addrs)) => { + // Resolution succeeded, filter the results + let filtered_addr: Vec = all_addrs.filter(is_global_ipv4).collect(); + if filtered_addr.is_empty() { + // No public IPs found, error out with PermissionDenied + let err: BoxError = Box::new(NoPublicIPv4Error); + Err(err) + } else { + // Pass remaining IPs in a boxed iterator for request to use. + let addrs: Addrs = Box::new(filtered_addr.into_iter()); + Ok(addrs) + } + } + Ok(Err(err)) => { + // Resolution failed, pass error through in a Box + let err: BoxError = Box::new(err); + Err(err) + } + Err(join_err) => { + // The tokio task failed, pass as io::Error in a Box + let err: BoxError = Box::new(io::Error::from(join_err)); + Err(err) + } + }); + + // Box the Future to satisfy the Resolving interface. + Box::pin(future_result) + } +} + +#[cfg(test)] +mod tests { + use crate::dns::{NoPublicIPv4Error, PublicIPv4Resolver}; + use reqwest::dns::{Name, Resolve}; + use std::str::FromStr; + + #[tokio::test] + async fn it_resolves_google_com() { + let resolver: PublicIPv4Resolver = PublicIPv4Resolver {}; + let addrs = resolver + .resolve(Name::from_str("google.com").unwrap()) + .await + .expect("lookup has failed"); + assert!(addrs.count() > 0, "empty address list") + } + + #[tokio::test] + async fn it_denies_ipv6_google_com() { + let resolver: PublicIPv4Resolver = PublicIPv4Resolver {}; + match resolver + .resolve(Name::from_str("ipv6.google.com").unwrap()) + .await + { + Ok(_) => panic!("should have failed"), + Err(err) => assert!(err.is::()), + } + } + + #[tokio::test] + async fn it_denies_localhost() { + let resolver: PublicIPv4Resolver = PublicIPv4Resolver {}; + match resolver.resolve(Name::from_str("localhost").unwrap()).await { + Ok(_) => panic!("should have failed"), + Err(err) => assert!(err.is::()), + } + } + + #[tokio::test] + async fn it_bubbles_up_resolution_error() { + let resolver: PublicIPv4Resolver = PublicIPv4Resolver {}; + match resolver + .resolve(Name::from_str("invalid.domain.unknown").unwrap()) + .await + { + Ok(_) => panic!("should have failed"), + Err(err) => { + assert!(!err.is::()); + assert!(err + .to_string() + .contains("failed to lookup address information")) + } + } + } +} diff --git a/hook-worker/src/error.rs b/hook-worker/src/error.rs index 48468bc..764e8d9 100644 --- a/hook-worker/src/error.rs +++ b/hook-worker/src/error.rs @@ -1,6 +1,8 @@ +use std::error::Error; use std::fmt; use std::time; +use crate::dns::NoPublicIPv4Error; use hook_common::{pgqueue, webhook::WebhookJobError}; use thiserror::Error; @@ -64,7 +66,11 @@ impl fmt::Display for WebhookRequestError { Some(m) => m.to_string(), None => "No response from the server".to_string(), }; - writeln!(f, "{}", error)?; + if is_error_source::(error) { + writeln!(f, "{}: {}", error, NoPublicIPv4Error)?; + } else { + writeln!(f, "{}", error)?; + } write!(f, "{}", response_message)?; Ok(()) @@ -132,3 +138,15 @@ pub enum WorkerError { #[error("timed out while waiting for jobs to be available")] TimeoutError, } + +/// Check the error and it's sources (recursively) to return true if an error of the given type is found. +/// TODO: use Error::sources() when stable +pub fn is_error_source(err: &(dyn std::error::Error + 'static)) -> bool { + if err.is::() { + return true; + } + match err.source() { + None => false, + Some(source) => is_error_source::(source), + } +} diff --git a/hook-worker/src/lib.rs b/hook-worker/src/lib.rs index 8488d15..94a0758 100644 --- a/hook-worker/src/lib.rs +++ b/hook-worker/src/lib.rs @@ -1,4 +1,5 @@ pub mod config; +pub mod dns; pub mod error; pub mod util; pub mod worker; diff --git a/hook-worker/src/main.rs b/hook-worker/src/main.rs index 8a6eeb3..050e2b9 100644 --- a/hook-worker/src/main.rs +++ b/hook-worker/src/main.rs @@ -52,6 +52,7 @@ async fn main() -> Result<(), WorkerError> { config.request_timeout.0, config.max_concurrent_jobs, retry_policy_builder.provide(), + config.allow_internal_ips, worker_liveness, ); diff --git a/hook-worker/src/worker.rs b/hook-worker/src/worker.rs index 824f1e2..fdb405a 100644 --- a/hook-worker/src/worker.rs +++ b/hook-worker/src/worker.rs @@ -14,11 +14,14 @@ use hook_common::{ webhook::{HttpMethod, WebhookJobError, WebhookJobMetadata, WebhookJobParameters}, }; use http::StatusCode; -use reqwest::header; +use reqwest::{header, Client}; use tokio::sync; use tracing::error; -use crate::error::{WebhookError, WebhookParseError, WebhookRequestError, WorkerError}; +use crate::dns::{NoPublicIPv4Error, PublicIPv4Resolver}; +use crate::error::{ + is_error_source, WebhookError, WebhookParseError, WebhookRequestError, WorkerError, +}; use crate::util::first_n_bytes_of_response; /// A WebhookJob is any `PgQueueJob` with `WebhookJobParameters` and `WebhookJobMetadata`. @@ -74,6 +77,25 @@ pub struct WebhookWorker<'p> { liveness: HealthHandle, } +pub fn build_http_client( + request_timeout: time::Duration, + allow_internal_ips: bool, +) -> reqwest::Result { + let mut headers = header::HeaderMap::new(); + headers.insert( + header::CONTENT_TYPE, + header::HeaderValue::from_static("application/json"), + ); + let mut client_builder = reqwest::Client::builder() + .default_headers(headers) + .user_agent("PostHog Webhook Worker") + .timeout(request_timeout); + if !allow_internal_ips { + client_builder = client_builder.dns_resolver(Arc::new(PublicIPv4Resolver {})) + } + client_builder.build() +} + impl<'p> WebhookWorker<'p> { #[allow(clippy::too_many_arguments)] pub fn new( @@ -84,19 +106,10 @@ impl<'p> WebhookWorker<'p> { request_timeout: time::Duration, max_concurrent_jobs: usize, retry_policy: RetryPolicy, + allow_internal_ips: bool, liveness: HealthHandle, ) -> Self { - let mut headers = header::HeaderMap::new(); - headers.insert( - header::CONTENT_TYPE, - header::HeaderValue::from_static("application/json"), - ); - - let client = reqwest::Client::builder() - .default_headers(headers) - .user_agent("PostHog Webhook Worker") - .timeout(request_timeout) - .build() + let client = build_http_client(request_timeout, allow_internal_ips) .expect("failed to construct reqwest client for webhook worker"); Self { @@ -390,10 +403,19 @@ async fn send_webhook( .body(body) .send() .await - .map_err(|e| WebhookRequestError::RetryableRequestError { - error: e, - response: None, - retry_after: None, + .map_err(|e| { + if is_error_source::(&e) { + WebhookRequestError::NonRetryableRetryableRequestError { + error: e, + response: None, + } + } else { + WebhookRequestError::RetryableRequestError { + error: e, + response: None, + retry_after: None, + } + } })?; let retry_after = parse_retry_after_header(response.headers()); @@ -469,6 +491,7 @@ fn parse_retry_after_header(header_map: &reqwest::header::HeaderMap) -> Option Client { + build_http_client(Duration::from_secs(1), true).expect("failed to create client") + } + #[allow(dead_code)] async fn enqueue_job( queue: &PgQueue, @@ -569,6 +598,7 @@ mod tests { time::Duration::from_millis(5000), 10, RetryPolicy::default(), + false, liveness, ); @@ -594,15 +624,14 @@ mod tests { assert!(registry.get_status().healthy) } - #[sqlx::test(migrations = "../migrations")] - async fn test_send_webhook(_pg: PgPool) { + #[tokio::test] + async fn test_send_webhook() { let method = HttpMethod::POST; let url = "http://localhost:18081/echo"; let headers = collections::HashMap::new(); let body = "a very relevant request body"; - let client = reqwest::Client::new(); - let response = send_webhook(client, &method, url, &headers, body.to_owned()) + let response = send_webhook(localhost_client(), &method, url, &headers, body.to_owned()) .await .expect("send_webhook failed"); @@ -613,15 +642,14 @@ mod tests { ); } - #[sqlx::test(migrations = "../migrations")] - async fn test_error_message_contains_response_body(_pg: PgPool) { + #[tokio::test] + async fn test_error_message_contains_response_body() { let method = HttpMethod::POST; let url = "http://localhost:18081/fail"; let headers = collections::HashMap::new(); let body = "this is an error message"; - let client = reqwest::Client::new(); - let err = send_webhook(client, &method, url, &headers, body.to_owned()) + let err = send_webhook(localhost_client(), &method, url, &headers, body.to_owned()) .await .err() .expect("request didn't fail when it should have failed"); @@ -638,17 +666,16 @@ mod tests { } } - #[sqlx::test(migrations = "../migrations")] - async fn test_error_message_contains_up_to_n_bytes_of_response_body(_pg: PgPool) { + #[tokio::test] + async fn test_error_message_contains_up_to_n_bytes_of_response_body() { let method = HttpMethod::POST; let url = "http://localhost:18081/fail"; let headers = collections::HashMap::new(); // This is double the current hardcoded amount of bytes. // TODO: Make this configurable and change it here too. let body = (0..20 * 1024).map(|_| "a").collect::>().concat(); - let client = reqwest::Client::new(); - let err = send_webhook(client, &method, url, &headers, body.to_owned()) + let err = send_webhook(localhost_client(), &method, url, &headers, body.to_owned()) .await .err() .expect("request didn't fail when it should have failed"); @@ -666,4 +693,32 @@ mod tests { )); } } + + #[tokio::test] + async fn test_private_ips_denied() { + let method = HttpMethod::POST; + let url = "http://localhost:18081/echo"; + let headers = collections::HashMap::new(); + let body = "a very relevant request body"; + let filtering_client = + build_http_client(Duration::from_secs(1), false).expect("failed to create client"); + + let err = send_webhook(filtering_client, &method, url, &headers, body.to_owned()) + .await + .err() + .expect("request didn't fail when it should have failed"); + + assert!(matches!(err, WebhookError::Request(..))); + if let WebhookError::Request(request_error) = err { + assert_eq!(request_error.status(), None); + assert!(request_error + .to_string() + .contains("No public IPv4 found for specified host")); + if let WebhookRequestError::RetryableRequestError { .. } = request_error { + panic!("error should not be retryable") + } + } else { + panic!("unexpected error type {:?}", err) + } + } }