Skip to content

Commit

Permalink
Merge pull request #1151 from sunng87/feature/ssl-negotiation
Browse files Browse the repository at this point in the history
feat: sslnegotiation and direct ssl for postgres 17
  • Loading branch information
sfackler authored Feb 2, 2025
2 parents 07b6878 + 720ffe8 commit c104b23
Show file tree
Hide file tree
Showing 27 changed files with 203 additions and 56 deletions.
2 changes: 1 addition & 1 deletion docker-compose.yml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
version: '2'
services:
postgres:
image: postgres:14
image: docker.io/postgres:17
ports:
- 5433:5433
volumes:
Expand Down
2 changes: 1 addition & 1 deletion postgres-native-tls/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ default = ["runtime"]
runtime = ["tokio-postgres/runtime"]

[dependencies]
native-tls = "0.2"
native-tls = { version = "0.2", features = ["alpn"] }
tokio = "1.0"
tokio-native-tls = "0.3"
tokio-postgres = { version = "0.7.11", path = "../tokio-postgres", default-features = false }
Expand Down
8 changes: 8 additions & 0 deletions postgres-native-tls/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
//! ```
#![warn(rust_2018_idioms, clippy::all, missing_docs)]

use native_tls::TlsConnectorBuilder;
use std::future::Future;
use std::io;
use std::pin::Pin;
Expand Down Expand Up @@ -180,3 +181,10 @@ where
}
}
}

/// Set ALPN for `TlsConnectorBuilder`
///
/// This is required when using `sslnegotiation=direct`
pub fn set_postgresql_alpn(builder: &mut TlsConnectorBuilder) {
builder.request_alpns(&["postgresql"]);
}
17 changes: 16 additions & 1 deletion postgres-native-tls/src/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use tokio_postgres::tls::TlsConnect;

#[cfg(feature = "runtime")]
use crate::MakeTlsConnector;
use crate::TlsConnector;
use crate::{set_postgresql_alpn, TlsConnector};

async fn smoke_test<T>(s: &str, tls: T)
where
Expand Down Expand Up @@ -42,6 +42,21 @@ async fn require() {
.await;
}

#[tokio::test]
async fn direct() {
let mut builder = native_tls::TlsConnector::builder();
builder.add_root_certificate(
Certificate::from_pem(include_bytes!("../../test/server.crt")).unwrap(),
);
set_postgresql_alpn(&mut builder);
let connector = builder.build().unwrap();
smoke_test(
"user=ssl_user dbname=postgres sslmode=require sslnegotiation=direct",
TlsConnector::new(connector, "localhost"),
)
.await;
}

#[tokio::test]
async fn prefer() {
let connector = native_tls::TlsConnector::builder()
Expand Down
9 changes: 8 additions & 1 deletion postgres-openssl/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ use openssl::hash::MessageDigest;
use openssl::nid::Nid;
#[cfg(feature = "runtime")]
use openssl::ssl::SslConnector;
use openssl::ssl::{self, ConnectConfiguration, SslRef};
use openssl::ssl::{self, ConnectConfiguration, SslConnectorBuilder, SslRef};
use openssl::x509::X509VerifyResult;
use std::error::Error;
use std::fmt::{self, Debug};
Expand Down Expand Up @@ -250,3 +250,10 @@ fn tls_server_end_point(ssl: &SslRef) -> Option<Vec<u8>> {
};
cert.digest(md).ok().map(|b| b.to_vec())
}

/// Set ALPN for `SslConnectorBuilder`
///
/// This is required when using `sslnegotiation=direct`
pub fn set_postgresql_alpn(builder: &mut SslConnectorBuilder) -> Result<(), ErrorStack> {
builder.set_alpn_protos(b"\x0apostgresql")
}
13 changes: 13 additions & 0 deletions postgres-openssl/src/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,19 @@ async fn require() {
.await;
}

#[tokio::test]
async fn direct() {
let mut builder = SslConnector::builder(SslMethod::tls()).unwrap();
builder.set_ca_file("../test/server.crt").unwrap();
set_postgresql_alpn(&mut builder).unwrap();
let ctx = builder.build();
smoke_test(
"user=ssl_user dbname=postgres sslmode=require sslnegotiation=direct",
TlsConnector::new(ctx.configure().unwrap(), "localhost"),
)
.await;
}

#[tokio::test]
async fn prefer() {
let mut builder = SslConnector::builder(SslMethod::tls()).unwrap();
Expand Down
8 changes: 4 additions & 4 deletions postgres-protocol/src/message/backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -475,7 +475,7 @@ pub struct ColumnFormats<'a> {
remaining: u16,
}

impl<'a> FallibleIterator for ColumnFormats<'a> {
impl FallibleIterator for ColumnFormats<'_> {
type Item = u16;
type Error = io::Error;

Expand Down Expand Up @@ -557,7 +557,7 @@ pub struct DataRowRanges<'a> {
remaining: u16,
}

impl<'a> FallibleIterator for DataRowRanges<'a> {
impl FallibleIterator for DataRowRanges<'_> {
type Item = Option<Range<usize>>;
type Error = io::Error;

Expand Down Expand Up @@ -645,7 +645,7 @@ pub struct ErrorField<'a> {
value: &'a [u8],
}

impl<'a> ErrorField<'a> {
impl ErrorField<'_> {
#[inline]
pub fn type_(&self) -> u8 {
self.type_
Expand Down Expand Up @@ -717,7 +717,7 @@ pub struct Parameters<'a> {
remaining: u16,
}

impl<'a> FallibleIterator for Parameters<'a> {
impl FallibleIterator for Parameters<'_> {
type Item = Oid;
type Error = io::Error;

Expand Down
4 changes: 2 additions & 2 deletions postgres-protocol/src/types/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -582,7 +582,7 @@ impl<'a> Array<'a> {
/// An iterator over the dimensions of an array.
pub struct ArrayDimensions<'a>(&'a [u8]);

impl<'a> FallibleIterator for ArrayDimensions<'a> {
impl FallibleIterator for ArrayDimensions<'_> {
type Item = ArrayDimension;
type Error = StdBox<dyn Error + Sync + Send>;

Expand Down Expand Up @@ -950,7 +950,7 @@ pub struct PathPoints<'a> {
buf: &'a [u8],
}

impl<'a> FallibleIterator for PathPoints<'a> {
impl FallibleIterator for PathPoints<'_> {
type Item = Point;
type Error = StdBox<dyn Error + Sync + Send>;

Expand Down
20 changes: 10 additions & 10 deletions postgres-types/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -914,7 +914,7 @@ pub enum Format {
Binary,
}

impl<'a, T> ToSql for &'a T
impl<T> ToSql for &T
where
T: ToSql,
{
Expand Down Expand Up @@ -963,7 +963,7 @@ impl<T: ToSql> ToSql for Option<T> {
to_sql_checked!();
}

impl<'a, T: ToSql> ToSql for &'a [T] {
impl<T: ToSql> ToSql for &[T] {
fn to_sql(&self, ty: &Type, w: &mut BytesMut) -> Result<IsNull, Box<dyn Error + Sync + Send>> {
let member_type = match *ty.kind() {
Kind::Array(ref member) => member,
Expand Down Expand Up @@ -1004,7 +1004,7 @@ impl<'a, T: ToSql> ToSql for &'a [T] {
to_sql_checked!();
}

impl<'a> ToSql for &'a [u8] {
impl ToSql for &[u8] {
fn to_sql(&self, _: &Type, w: &mut BytesMut) -> Result<IsNull, Box<dyn Error + Sync + Send>> {
types::bytea_to_sql(self, w);
Ok(IsNull::No)
Expand Down Expand Up @@ -1064,7 +1064,7 @@ impl<T: ToSql> ToSql for Box<[T]> {
to_sql_checked!();
}

impl<'a> ToSql for Cow<'a, [u8]> {
impl ToSql for Cow<'_, [u8]> {
fn to_sql(&self, ty: &Type, w: &mut BytesMut) -> Result<IsNull, Box<dyn Error + Sync + Send>> {
<&[u8] as ToSql>::to_sql(&self.as_ref(), ty, w)
}
Expand All @@ -1088,7 +1088,7 @@ impl ToSql for Vec<u8> {
to_sql_checked!();
}

impl<'a> ToSql for &'a str {
impl ToSql for &str {
fn to_sql(&self, ty: &Type, w: &mut BytesMut) -> Result<IsNull, Box<dyn Error + Sync + Send>> {
match ty.name() {
"ltree" => types::ltree_to_sql(self, w),
Expand All @@ -1109,7 +1109,7 @@ impl<'a> ToSql for &'a str {
to_sql_checked!();
}

impl<'a> ToSql for Cow<'a, str> {
impl ToSql for Cow<'_, str> {
fn to_sql(&self, ty: &Type, w: &mut BytesMut) -> Result<IsNull, Box<dyn Error + Sync + Send>> {
<&str as ToSql>::to_sql(&self.as_ref(), ty, w)
}
Expand Down Expand Up @@ -1256,17 +1256,17 @@ impl BorrowToSql for &dyn ToSql {
}
}

impl<'a> sealed::Sealed for Box<dyn ToSql + Sync + 'a> {}
impl sealed::Sealed for Box<dyn ToSql + Sync + '_> {}

impl<'a> BorrowToSql for Box<dyn ToSql + Sync + 'a> {
impl BorrowToSql for Box<dyn ToSql + Sync + '_> {
#[inline]
fn borrow_to_sql(&self) -> &dyn ToSql {
self.as_ref()
}
}

impl<'a> sealed::Sealed for Box<dyn ToSql + Sync + Send + 'a> {}
impl<'a> BorrowToSql for Box<dyn ToSql + Sync + Send + 'a> {
impl sealed::Sealed for Box<dyn ToSql + Sync + Send + '_> {}
impl BorrowToSql for Box<dyn ToSql + Sync + Send + '_> {
#[inline]
fn borrow_to_sql(&self) -> &dyn ToSql {
self.as_ref()
Expand Down
16 changes: 15 additions & 1 deletion postgres/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use std::time::Duration;
use tokio::runtime;
#[doc(inline)]
pub use tokio_postgres::config::{
ChannelBinding, Host, LoadBalanceHosts, SslMode, TargetSessionAttrs,
ChannelBinding, Host, LoadBalanceHosts, SslMode, SslNegotiation, TargetSessionAttrs,
};
use tokio_postgres::error::DbError;
use tokio_postgres::tls::{MakeTlsConnect, TlsConnect};
Expand Down Expand Up @@ -40,6 +40,9 @@ use tokio_postgres::{Error, Socket};
/// path to the directory containing Unix domain sockets. Otherwise, it is treated as a hostname. Multiple hosts
/// can be specified, separated by commas. Each host will be tried in turn when connecting. Required if connecting
/// with the `connect` method.
/// * `sslnegotiation` - TLS negotiation method. If set to `direct`, the client will perform direct TLS handshake, this only works for PostgreSQL 17 and newer.
/// Note that you will need to setup ALPN of TLS client configuration to `postgresql` when using direct TLS.
/// If set to `postgres`, the default value, it follows original postgres wire protocol to perform the negotiation.
/// * `hostaddr` - Numeric IP address of host to connect to. This should be in the standard IPv4 address format,
/// e.g., 172.28.40.9. If your machine supports IPv6, you can also use those addresses.
/// If this parameter is not specified, the value of `host` will be looked up to find the corresponding IP address,
Expand Down Expand Up @@ -230,6 +233,17 @@ impl Config {
self.config.get_ssl_mode()
}

/// Sets the SSL negotiation method
pub fn ssl_negotiation(&mut self, ssl_negotiation: SslNegotiation) -> &mut Config {
self.config.ssl_negotiation(ssl_negotiation);
self
}

/// Gets the SSL negotiation method
pub fn get_ssl_negotiation(&self) -> SslNegotiation {
self.config.get_ssl_negotiation()
}

/// Adds a host to the configuration.
///
/// Multiple hosts can be specified by calling this method multiple times, and each will be tried in order. On Unix
Expand Down
6 changes: 3 additions & 3 deletions postgres/src/notifications.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ pub struct Iter<'a> {
connection: ConnectionRef<'a>,
}

impl<'a> FallibleIterator for Iter<'a> {
impl FallibleIterator for Iter<'_> {
type Item = Notification;
type Error = Error;

Expand All @@ -100,7 +100,7 @@ pub struct BlockingIter<'a> {
connection: ConnectionRef<'a>,
}

impl<'a> FallibleIterator for BlockingIter<'a> {
impl FallibleIterator for BlockingIter<'_> {
type Item = Notification;
type Error = Error;

Expand Down Expand Up @@ -129,7 +129,7 @@ pub struct TimeoutIter<'a> {
timeout: Duration,
}

impl<'a> FallibleIterator for TimeoutIter<'a> {
impl FallibleIterator for TimeoutIter<'_> {
type Item = Notification;
type Error = Error;

Expand Down
2 changes: 1 addition & 1 deletion postgres/src/transaction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ pub struct Transaction<'a> {
transaction: Option<tokio_postgres::Transaction<'a>>,
}

impl<'a> Drop for Transaction<'a> {
impl Drop for Transaction<'_> {
fn drop(&mut self) {
if let Some(transaction) = self.transaction.take() {
let _ = self.connection.block_on(transaction.rollback());
Expand Down
15 changes: 12 additions & 3 deletions tokio-postgres/src/cancel_query.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
use crate::client::SocketConfig;
use crate::config::SslMode;
use crate::config::{SslMode, SslNegotiation};
use crate::tls::MakeTlsConnect;
use crate::{cancel_query_raw, connect_socket, Error, Socket};
use std::io;

pub(crate) async fn cancel_query<T>(
config: Option<SocketConfig>,
ssl_mode: SslMode,
ssl_negotiation: SslNegotiation,
mut tls: T,
process_id: i32,
secret_key: i32,
Expand Down Expand Up @@ -38,6 +39,14 @@ where
)
.await?;

cancel_query_raw::cancel_query_raw(socket, ssl_mode, tls, has_hostname, process_id, secret_key)
.await
cancel_query_raw::cancel_query_raw(
socket,
ssl_mode,
ssl_negotiation,
tls,
has_hostname,
process_id,
secret_key,
)
.await
}
5 changes: 3 additions & 2 deletions tokio-postgres/src/cancel_query_raw.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::config::SslMode;
use crate::config::{SslMode, SslNegotiation};
use crate::tls::TlsConnect;
use crate::{connect_tls, Error};
use bytes::BytesMut;
Expand All @@ -8,6 +8,7 @@ use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
pub async fn cancel_query_raw<S, T>(
stream: S,
mode: SslMode,
negotiation: SslNegotiation,
tls: T,
has_hostname: bool,
process_id: i32,
Expand All @@ -17,7 +18,7 @@ where
S: AsyncRead + AsyncWrite + Unpin,
T: TlsConnect<S>,
{
let mut stream = connect_tls::connect_tls(stream, mode, tls, has_hostname).await?;
let mut stream = connect_tls::connect_tls(stream, mode, negotiation, tls, has_hostname).await?;

let mut buf = BytesMut::new();
frontend::cancel_request(process_id, secret_key, &mut buf);
Expand Down
5 changes: 4 additions & 1 deletion tokio-postgres/src/cancel_token.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::config::SslMode;
use crate::config::{SslMode, SslNegotiation};
use crate::tls::TlsConnect;
#[cfg(feature = "runtime")]
use crate::{cancel_query, client::SocketConfig, tls::MakeTlsConnect, Socket};
Expand All @@ -12,6 +12,7 @@ pub struct CancelToken {
#[cfg(feature = "runtime")]
pub(crate) socket_config: Option<SocketConfig>,
pub(crate) ssl_mode: SslMode,
pub(crate) ssl_negotiation: SslNegotiation,
pub(crate) process_id: i32,
pub(crate) secret_key: i32,
}
Expand All @@ -37,6 +38,7 @@ impl CancelToken {
cancel_query::cancel_query(
self.socket_config.clone(),
self.ssl_mode,
self.ssl_negotiation,
tls,
self.process_id,
self.secret_key,
Expand All @@ -54,6 +56,7 @@ impl CancelToken {
cancel_query_raw::cancel_query_raw(
stream,
self.ssl_mode,
self.ssl_negotiation,
tls,
true,
self.process_id,
Expand Down
Loading

0 comments on commit c104b23

Please sign in to comment.