Skip to content

Commit

Permalink
Merge pull request #21 from exyi/ssl-root-cert
Browse files Browse the repository at this point in the history
Add --ssl-root-cert option
  • Loading branch information
exyi authored May 28, 2024
2 parents 22088f3 + 3390904 commit 0ce724e
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 6 deletions.
5 changes: 4 additions & 1 deletion cli/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,12 +96,15 @@ pub struct PostgresConnArgs {
/// Controls whether to use SSL/TLS to connect to the server.
#[arg(long="sslmode", alias="tlsmode", alias="ssl-mode", alias="tls-mode")]
sslmode: Option<SslMode>,
/// File with a TLS root certificate in PEM or DER (.crt) format. When specified, the default CA certificates are considered untrusted. The option can be specified multiple times. Using this options implies --sslmode=require.
#[arg(long="ssl-root-cert", alias="tls-root-cert")]
ssl_root_cert: Option<Vec<PathBuf>>
}

impl std::fmt::Debug for PostgresConnArgs {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let password = self.password.as_ref().map(|_| "********");
f.debug_struct("PostgresConnArgs").field("host", &self.host).field("user", &self.user).field("dbname", &self.dbname).field("port", &self.port).field("password", &password).field("sslmode", &self.sslmode).finish()
f.debug_struct("PostgresConnArgs").field("host", &self.host).field("user", &self.user).field("dbname", &self.dbname).field("port", &self.port).field("password", &password).field("sslmode", &self.sslmode).field("ssl_root_cert", &self.ssl_root_cert).finish()
}
}

Expand Down
40 changes: 35 additions & 5 deletions cli/src/postgres_cloner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,14 +111,38 @@ fn read_password(user: &str) -> Result<String, String> {
}

#[cfg(any(target_os = "macos", target_os="windows", all(target_os="linux", not(target_env="musl"), any(target_arch="x86_64", target_arch="aarch64"))))]
fn build_tls_connector() -> Result<postgres_native_tls::MakeTlsConnector, String> {
let connector = native_tls::TlsConnector::new().map_err(|e| format!("Creating TLS connector failed: {}", e.to_string()))?;
fn build_tls_connector(certificates: &Option<Vec<PathBuf>>) -> Result<postgres_native_tls::MakeTlsConnector, String> {
fn load_cert(f: &PathBuf) -> Result<native_tls::Certificate, String> {
let bytes = std::fs::read(f).map_err(|e| format!("Failed to read certificate file {:?}: {}", f, e))?;
if let Ok(pem) = native_tls::Certificate::from_pem(&bytes) {
return Ok(pem);
}
if let Ok(der) = native_tls::Certificate::from_der(&bytes) {
return Ok(der);
}

Err(format!("Failed to load certificate from file {:?}", f))
}
let mut builder = native_tls::TlsConnector::builder();
match certificates {
None => {},
Some(certificates) => {
builder.disable_built_in_roots(true);
for cert in certificates {
builder.add_root_certificate(load_cert(cert)?);
}
}
}
let connector = builder.build().map_err(|e| format!("Creating TLS connector failed: {}", e.to_string()))?;
let pg_connector = postgres_native_tls::MakeTlsConnector::new(connector);
Ok(pg_connector)
}

#[cfg(not(any(target_os = "macos", target_os="windows", all(target_os="linux", not(target_env="musl"), any(target_arch="x86_64", target_arch="aarch64")))))]
fn build_tls_connector() -> Result<NoTls, String> {
fn build_tls_connector(certificates: &Option<Vec<PathBuf>>) -> Result<NoTls, String> {
if certificates.is_some() {
return Err("SSL/TLS is not supported in this build of pg2parquet".to_string());
}
Ok(NoTls)
}

Expand Down Expand Up @@ -146,7 +170,13 @@ fn pg_connect(args: &PostgresConnArgs) -> Result<Client, String> {
Some(x) => return Err(format!("SSL/TLS is disabled in this build of pg2parquet, so ssl mode {:?} cannot be used. Only 'disable' option is allowed.", x)),
}
match &args.sslmode {
None => {},
None => {
if args.ssl_root_cert.is_some() {
pg_config.ssl_mode(postgres::config::SslMode::Require);
} else {
pg_config.ssl_mode(postgres::config::SslMode::Prefer);
}
},
Some(crate::SslMode::Disable) => {
pg_config.ssl_mode(postgres::config::SslMode::Disable);
},
Expand All @@ -158,7 +188,7 @@ fn pg_connect(args: &PostgresConnArgs) -> Result<Client, String> {
},
}

let connector = build_tls_connector()?;
let connector = build_tls_connector(&args.ssl_root_cert)?;

let client = pg_config.connect(connector).map_err(|e| format!("DB connection failed: {}", e.to_string()))?;

Expand Down

0 comments on commit 0ce724e

Please sign in to comment.