Skip to content

Commit

Permalink
Fix certificate modes + ignore hostname validation (#399)
Browse files Browse the repository at this point in the history
  • Loading branch information
mmastrac authored Feb 27, 2025
1 parent ad9790b commit 70cf676
Show file tree
Hide file tree
Showing 6 changed files with 73 additions and 15 deletions.
2 changes: 1 addition & 1 deletion gel-stream/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[package]
name = "gel-stream"
license = "MIT/Apache-2.0"
version = "0.1.4"
version = "0.2.0"
authors = ["MagicStack Inc. <[email protected]>"]
edition = "2021"
description = "A library for streaming data between clients and servers."
Expand Down
19 changes: 15 additions & 4 deletions gel-stream/src/common/rustls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -315,17 +315,28 @@ impl ServerCertVerifier for IgnoreHostnameVerifier {
&self,
end_entity: &CertificateDer<'_>,
intermediates: &[CertificateDer<'_>],
_server_name: &ServerName,
server_name: &ServerName,
ocsp_response: &[u8],
now: UnixTime,
) -> Result<ServerCertVerified, rustls::Error> {
self.verifier.verify_server_cert(
match self.verifier.verify_server_cert(
end_entity,
intermediates,
&ServerName::DnsName(DnsName::try_from("").unwrap()),
server_name,
ocsp_response,
now,
)
) {
Ok(res) => Ok(res),
// This works because the name check is the last step in the verify process
Err(e)
if e == rustls::Error::InvalidCertificate(
rustls::CertificateError::NotValidForName,
) =>
{
Ok(ServerCertVerified::assertion())
}
Err(e) => Err(e),
}
}

fn verify_tls12_signature(
Expand Down
34 changes: 34 additions & 0 deletions gel-stream/tests/tls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -472,6 +472,40 @@ tls_test! {
Ok(())
}

/// Test that we can override the SNI.
#[tokio::test]
#[ntest::timeout(30_000)]
async fn test_target_tcp_tls_sni_override_ignore_hostname<C: TlsDriver, S: TlsDriver>() -> Result<(), ConnectionError> {
let (addr, accept_task) = spawn_tls_server::<S>(
Some("www.google.com"),
TlsAlpn::default(),
None,
TlsClientCertVerify::Ignore,
)
.await?;

let connect_task = tokio::spawn(async move {
let target = Target::new_resolved_tls(
addr,
TlsParameters {
root_cert: TlsCert::Custom(vec![load_test_ca()]),
server_cert_verify: TlsServerCertVerify::IgnoreHostname,
sni_override: Some(Cow::Borrowed("www.google.com")),
..Default::default()
},
);
let mut stm = Connector::<C>::new_explicit(target).unwrap().connect().await.unwrap();
stm.write_all(b"Hello, world!").await.unwrap();
stm.shutdown().await?;
Ok::<_, std::io::Error>(())
});

accept_task.await.unwrap().unwrap();
connect_task.await.unwrap().unwrap();

Ok(())
}

/// Test that we can set the ALPN.
#[tokio::test]
async fn test_target_tcp_tls_alpn<C: TlsDriver, S: TlsDriver>() -> Result<(), ConnectionError> {
Expand Down
4 changes: 2 additions & 2 deletions gel-tokio/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[package]
name = "gel-tokio"
license = "MIT/Apache-2.0"
version = "0.9.5"
version = "0.9.6"
authors = ["MagicStack Inc. <[email protected]>"]
edition = "2021"
description = """
Expand All @@ -17,7 +17,7 @@ gel-protocol = { path = "../gel-protocol", version = "0.8", features = [
] }
gel-errors = { path = "../gel-errors", version = "0.5" }
gel-derive = { path = "../gel-derive", version = "0.7", optional = true }
gel-stream = { path = "../gel-stream", version = "0.1.4", features = ["client", "tokio", "rustls", "hickory", "keepalive"] }
gel-stream = { path = "../gel-stream", version = "0.2.0", features = ["client", "tokio", "rustls", "hickory", "keepalive"] }
gel-auth = { path = "../gel-auth", version = "0.1.3" }
tokio = { workspace = true, features = ["net", "time", "sync", "macros"] }
bytes = "1.5.0"
Expand Down
11 changes: 8 additions & 3 deletions gel-tokio/src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ use std::future::Future;

use base64::Engine;
use gel_stream::{Target, TlsAlpn, TlsCert, TlsParameters};
use log::debug;
use serde_json::from_slice;
use sha1::Digest;
use tokio::fs;
Expand Down Expand Up @@ -1945,9 +1946,9 @@ impl Config {
if let Some(cloud_certs) = self.0.cloud_certs {
tls.root_cert = TlsCert::WebpkiPlus(read_root_cert_pem(cloud_certs.root())?);
}
tls.server_cert_verify = self.0.compute_tls_security()?;
}
}
tls.server_cert_verify = self.0.compute_tls_security()?;
tls.alpn = TlsAlpn::new_str(&["edgedb-binary", "gel-binary"]);
tls.sni_override = match &self.0.tls_server_name {
Some(server_name) => Some(Cow::from(server_name.clone())),
Expand Down Expand Up @@ -1980,7 +1981,7 @@ impl ConfigInner {
pub(crate) fn compute_tls_security(&self) -> Result<gel_stream::TlsServerCertVerify, Error> {
use gel_stream::TlsServerCertVerify::*;

match (self.client_security, self.tls_security) {
let res = match (self.client_security, self.tls_security) {
(ClientSecurity::Strict, TlsSecurity::Insecure | TlsSecurity::NoHostVerification) => {
Err(ClientError::with_message(format!(
"client_security=strict and tls_security={} don't comply",
Expand All @@ -1994,7 +1995,11 @@ impl ConfigInner {
(_, TlsSecurity::Insecure) => Ok(Insecure),
(_, TlsSecurity::NoHostVerification) => Ok(IgnoreHostname),
(_, TlsSecurity::Strict) => Ok(VerifyFull),
}
};

debug!("compute_tls_security(client_security={:?}, tls_security={:?}, has_pem={:?}) = {:?}", self.client_security, self.tls_security, self.pem_certificates.is_some(), res);

res
}
}

Expand Down
18 changes: 13 additions & 5 deletions gel-tokio/src/raw/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use std::str;
use std::time::Duration;

use bytes::{Bytes, BytesMut};
use log::warn;
use log::{debug, warn};
use rand::{rng, Rng};
use tokio::io::ReadBuf;
use tokio::io::{AsyncRead, AsyncReadExt};
Expand Down Expand Up @@ -248,19 +248,22 @@ impl Connection {
async fn connect(cfg: &Config) -> Result<Connection, Error> {
let mut target = cfg.0.address.clone();
let tls = cfg.tls()?;
debug!("Connecting to {:?}, TLS: {:?}", target, tls);
target.try_set_tls(tls);

let start = Instant::now();
let wait = cfg.0.wait;
let warned = &mut false;
let mut retry = 0;
let conn = loop {
match connect_timeout(cfg, connect2(cfg, target.clone(), warned)).await {
Err(e) if is_temporary(&e) => {
log::debug!("Temporary connection error: {:#}", e);
if wait > start.elapsed() {
sleep(connect_sleep()).await;
sleep(connect_sleep(retry)).await;
retry += 1;
continue;
} else if wait > Duration::new(0, 0) {
} else if wait > Duration::ZERO {
return Err(e.context(format!("cannot establish connection for {wait:?}")));
} else {
return Err(e);
Expand Down Expand Up @@ -693,8 +696,13 @@ async fn _wait_message<'x>(
Ok(result)
}

fn connect_sleep() -> Duration {
Duration::from_millis(rng().random_range(10u64..200u64))
fn connect_sleep(retry: usize) -> Duration {
let rand = rng().random_range(10u64..200u64);
if retry > 0 {
Duration::from_millis(rand * retry as u64)
} else {
Duration::from_millis(rand)
}
}

async fn connect_timeout<F, T>(cfg: &Config, f: F) -> Result<T, Error>
Expand Down

0 comments on commit 70cf676

Please sign in to comment.