Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Branch connection argument support #302

Merged
merged 5 commits into from
Mar 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@
/edgeql_python.cpython-*.so
__pycache__
/Cargo.lock
/.idea
113 changes: 108 additions & 5 deletions edgedb-tokio/src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ pub struct Builder {
unix_path: Option<PathBuf>,
user: Option<String>,
database: Option<String>,
branch: Option<String>,
password: Option<String>,
tls_ca_file: Option<PathBuf>,
tls_security: Option<TlsSecurity>,
Expand Down Expand Up @@ -109,6 +110,7 @@ pub(crate) struct ConfigInner {
pub secret_key: Option<String>,
pub cloud_profile: Option<String>,
pub database: String,
pub branch: String,
pub verifier: Verifier,
pub wait: Duration,
pub connect_timeout: Duration,
Expand Down Expand Up @@ -538,6 +540,21 @@ impl<'a> DsnHelper<'a> {
}).await
}

async fn retrieve_branch(&mut self) -> Result<Option<String>, Error> {
let v = self.url.path().strip_prefix("/").and_then(|s| {
if s.is_empty() {
None
} else {
Some(s.to_owned())
}
});
self.retrieve_value("branch", v, |s| {
let s = s.strip_prefix("/").unwrap_or(&s);
validate_branch(&s)?;
Ok(s.to_owned())
}).await
}

async fn retrieve_secret_key(&mut self) -> Result<Option<String>, Error> {
self.retrieve_value("secret_key", None, |s| Ok(s)).await
}
Expand Down Expand Up @@ -678,6 +695,13 @@ impl Builder {
Ok(self)
}

/// Set the branch name.
pub fn branch(&mut self, branch: &str) -> Result<&mut Self, Error> {
validate_branch(branch)?;
self.branch = Some(branch.into());
Ok(self)
}

/// Set certificate authority for TLS from file
///
/// Note: file is not read immediately but is read when configuration is
Expand Down Expand Up @@ -817,6 +841,9 @@ impl Builder {
database: self.database.clone()
.or_else(|| creds.map(|c| c.database.clone()).flatten())
.unwrap_or_else(|| "edgedb".into()),
branch: self.branch.clone()
.or_else(|| creds.map(|c| c.branch.clone()).flatten())
.unwrap_or_else(|| "__default__".into()),
instance_name: None,
wait: self.wait_until_available.unwrap_or(DEFAULT_WAIT),
connect_timeout: self.connect_timeout
Expand Down Expand Up @@ -934,6 +961,13 @@ impl Builder {
let full_path = resolve_unix(unix_path, port, self.admin);
cfg.address = Address::Unix(full_path);
}
if let Some(database) = &self.database {
if let Some(branch) = &self.branch {
errors.push(InvalidArgumentError::with_message(format!(
"database {} conflicts with branch {}", database, branch
)))
}
}
}

async fn granular_owned(&self, cfg: &mut ConfigInner,
Expand All @@ -943,6 +977,10 @@ impl Builder {
cfg.database = database.clone();
}

if let Some(branch) = &self.branch {
cfg.branch = branch.clone();
}

if let Some(user) = &self.user {
cfg.user = user.clone();
}
Expand Down Expand Up @@ -1093,6 +1131,16 @@ impl Builder {
cfg.database = database;
}

let branch = self.branch.clone().or_else(|| {
get_env("EDGEDB_BRANCH")
.and_then(|v| v.map(validate_branch).transpose())
.map_err(|e| errors.push(e)).ok().flatten()
});

if let Some(branch) = branch {
cfg.branch = branch;
}

let user = self.user.clone().or_else(|| {
get_env("EDGEDB_USER")
.and_then(|v| v.map(validate_user).transpose())
Expand Down Expand Up @@ -1202,15 +1250,46 @@ impl Builder {
} else {
dsn.ignore_value("password");
}
if self.database.is_none() {

let has_branch_option = dsn.query.contains_key("branch") || dsn.query.contains_key("branch_env") || dsn.query.contains_key("branch_file");
let has_database_option = dsn.query.contains_key("database") || dsn.query.contains_key("database_env") || dsn.query.contains_key("database_file");

if has_branch_option {
if has_database_option {
errors.push(InvalidArgumentError::with_message(
"Invalid DSN: `database` and `branch` cannot be present at the same time"
));
} else if self.database.is_some() {
errors.push(InvalidArgumentError::with_message(
"`branch` in DSN and `database` are mutually exclusive"
));
} else {
match dsn.retrieve_branch().await {
Ok(Some(value)) => cfg.branch = value,
Ok(None) => {},
Err(e) => errors.push(e)
}
}
} else if self.branch.is_some() {
if has_database_option {
errors.push(InvalidArgumentError::with_message(
"`database` in DSN and `branch` are mutually exclusive"
));
} else {
match dsn.retrieve_branch().await {
Ok(Some(value)) => cfg.branch = value,
Ok(None) => {},
Err(e) => errors.push(e)
}
}
} else {
match dsn.retrieve_database().await {
Ok(Some(value)) => cfg.database = value,
Ok(None) => {},
Err(e) => errors.push(e),
Err(e) => errors.push(e)
}
} else {
dsn.ignore_value("database");
}

match dsn.retrieve_secret_key().await {
Ok(Some(value)) => cfg.secret_key = Some(value),
Ok(None) => {},
Expand Down Expand Up @@ -1351,6 +1430,7 @@ impl Builder {
cloud_profile: None,
cloud_certs: None,
database: "edgedb".into(),
branch: "__default__".into(),
instance_name: None,
wait: self.wait_until_available.unwrap_or(DEFAULT_WAIT),
connect_timeout: self.connect_timeout
Expand Down Expand Up @@ -1566,6 +1646,7 @@ fn set_credentials(cfg: &mut ConfigInner, creds: &Credentials)
cfg.user = creds.user.clone();
cfg.password = creds.password.clone();
cfg.database = creds.database.clone().unwrap_or_else(|| "edgedb".into());
cfg.branch = creds.database.clone().unwrap_or_else(|| "__default__".into());
cfg.tls_security = creds.tls_security;
cfg.creds_file_outdated = creds.file_outdated;
Ok(())
Expand Down Expand Up @@ -1602,6 +1683,15 @@ fn validate_port(port: u16) -> Result<u16, Error> {
Ok(port)
}

fn validate_branch<T: AsRef<str>>(branch: T) -> Result<T, Error> {
if branch.as_ref().is_empty() {
return Err(InvalidArgumentError::with_message(
"invalid branch: empty string"
));
}
Ok(branch)
}

fn validate_database<T: AsRef<str>>(database: T) -> Result<T, Error> {
if database.as_ref().is_empty() {
return Err(InvalidArgumentError::with_message(
Expand Down Expand Up @@ -1658,7 +1748,8 @@ impl Config {
port: *port,
user: self.0.user.clone(),
password: self.0.password.clone(),
database: Some( self.0.database.clone()),
database: if self.0.branch == "__default__" { Some(self.0.database.clone()) } else { None },
branch: if self.0.branch == "__default__" { None } else { Some(self.0.branch.clone()) },
tls_ca: self.0.pem_certificates.clone(),
tls_security: self.0.tls_security,
file_outdated: false,
Expand All @@ -1674,6 +1765,7 @@ impl Config {
Address::Unix(path) => serde_json::json!(path.to_str().unwrap()),
},
"database": self.0.database,
"branch": self.0.branch,
"user": self.0.user,
"password": self.0.password,
"secretKey": self.0.secret_key,
Expand Down Expand Up @@ -1756,6 +1848,16 @@ impl Config {
Ok(self)
}

pub fn with_branch(mut self, branch: &str) -> Result<Config, Error> {
if branch.is_empty() {
return Err(InvalidArgumentError::with_message(
"invalid branch: empty string"
));
}
Arc::make_mut(&mut self.0).branch = branch.to_owned();
Ok(self)
}

/// Return the same config with changed wait until available timeout
#[cfg(any(feature="unstable", feature="test"))]
pub fn with_wait_until_available(mut self, wait: Duration) -> Config {
Expand Down Expand Up @@ -1969,6 +2071,7 @@ async fn from_dsn() {
));
assert_eq!(&cfg.0.user, "user1");
assert_eq!(&cfg.0.database, "db2");
assert_eq!(&cfg.0.branch, "__default__");
assert_eq!(cfg.0.password, Some("EiPhohl7".into()));

let cfg = Builder::new()
Expand Down
6 changes: 6 additions & 0 deletions edgedb-tokio/src/credentials.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ pub struct Credentials {
pub user: String,
pub password: Option<String>,
pub database: Option<String>,
pub branch: Option<String>,
pub tls_ca: Option<String>,
pub tls_security: TlsSecurity,
pub(crate) file_outdated: bool,
Expand All @@ -56,6 +57,8 @@ struct CredentialsCompat {
#[serde(default, skip_serializing_if="Option::is_none")]
database: Option<String>,
#[serde(default, skip_serializing_if="Option::is_none")]
branch: Option<String>,
#[serde(default, skip_serializing_if="Option::is_none")]
tls_cert_data: Option<String>, // deprecated
#[serde(default, skip_serializing_if="Option::is_none")]
tls_ca: Option<String>,
Expand Down Expand Up @@ -114,6 +117,7 @@ impl Default for Credentials {
user: "edgedb".into(),
password: None,
database: None,
branch: None,
tls_ca: None,
tls_security: TlsSecurity::Default,
file_outdated: false,
Expand All @@ -133,6 +137,7 @@ impl Serialize for Credentials {
user: self.user.clone(),
password: self.password.clone(),
database: self.database.clone(),
branch: self.branch.clone(),
tls_ca: self.tls_ca.clone(),
tls_cert_data: self.tls_ca.clone(),
tls_security: Some(self.tls_security),
Expand Down Expand Up @@ -192,6 +197,7 @@ impl<'de> Deserialize<'de> for Credentials {
user: creds.user,
password: creds.password,
database: creds.database,
branch: creds.branch,
tls_ca: creds.tls_ca.or(creds.tls_cert_data.clone()),
tls_security: creds.tls_security.unwrap_or(
match creds.tls_verify_hostname {
Expand Down