Skip to content

Commit

Permalink
feat(rust): Allow setting custom client options
Browse files Browse the repository at this point in the history
  • Loading branch information
PrettyWood committed Jan 30, 2025
1 parent 2eaee18 commit ad25eb8
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 16 deletions.
2 changes: 2 additions & 0 deletions crates/polars-io/src/cloud/object_store_setup.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ fn url_and_creds_to_key(url: &Url, options: Option<&CloudOptions>) -> Vec<u8> {
config,
#[cfg(feature = "cloud")]
credential_provider,
#[cfg(any(feature = "aws", feature = "gcp", feature = "azure", feature = "http"))]
client_options: _,
}| {
CloudOptions2 {
max_retries: *max_retries,
Expand Down
79 changes: 63 additions & 16 deletions crates/polars-io/src/cloud/options.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use std::io::Read;
#[cfg(feature = "aws")]
use std::path::Path;
use std::str::FromStr;
use std::time::Duration;

#[cfg(feature = "aws")]
use object_store::aws::AmazonS3Builder;
Expand Down Expand Up @@ -82,6 +83,51 @@ pub struct CloudOptions {
#[cfg(feature = "cloud")]
#[cfg_attr(feature = "serde", serde(deserialize_with = "deserialize_or_default"))]
pub(crate) credential_provider: Option<PlCredentialProvider>,
#[cfg(any(feature = "aws", feature = "gcp", feature = "azure", feature = "http"))]
#[cfg_attr(feature = "serde", serde(skip))]
client_options: Option<PlClientOptions>,
}

#[derive(Clone, Debug, PartialEq, Hash, Eq)]
#[cfg(any(feature = "aws", feature = "gcp", feature = "azure", feature = "http"))]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct PlClientOptions {
timeout: Option<Duration>,
connect_timeout: Option<Duration>,
allow_http: bool,
}

#[cfg(any(feature = "aws", feature = "gcp", feature = "azure", feature = "http"))]
impl Default for PlClientOptions {
fn default() -> Self {
Self {
// We set request timeout super high as the timeout isn't reset at ACK,
// but starts from the moment we start downloading a body.
// https://docs.rs/reqwest/latest/reqwest/struct.ClientBuilder.html#method.timeout
timeout: None,
// Concurrency can increase connection latency, so set to None, similar to default.
connect_timeout: None,
allow_http: true,
}
}
}

#[cfg(any(feature = "aws", feature = "gcp", feature = "azure", feature = "http"))]
impl From<PlClientOptions> for ClientOptions {
fn from(pl_opts: PlClientOptions) -> Self {
let mut opts = ClientOptions::new();
if let Some(timeout) = pl_opts.timeout {
opts = opts.with_timeout(timeout);
} else {
opts = opts.with_timeout_disabled();
}
if let Some(connect_timeout) = pl_opts.connect_timeout {
opts = opts.with_connect_timeout(connect_timeout);
} else {
opts = opts.with_connect_timeout_disabled();
}
opts.with_allow_http(pl_opts.allow_http)
}
}

#[cfg(all(feature = "serde", feature = "cloud"))]
Expand All @@ -108,6 +154,8 @@ impl CloudOptions {
config: None,
#[cfg(feature = "cloud")]
credential_provider: None,
#[cfg(any(feature = "aws", feature = "gcp", feature = "azure", feature = "http"))]
client_options: None,
});

&DEFAULT
Expand Down Expand Up @@ -221,18 +269,6 @@ fn get_retry_config(max_retries: usize) -> RetryConfig {
}
}

#[cfg(any(feature = "aws", feature = "gcp", feature = "azure", feature = "http"))]
pub(super) fn get_client_options() -> ClientOptions {
ClientOptions::new()
// We set request timeout super high as the timeout isn't reset at ACK,
// but starts from the moment we start downloading a body.
// https://docs.rs/reqwest/latest/reqwest/struct.ClientBuilder.html#method.timeout
.with_timeout_disabled()
// Concurrency can increase connection latency, so set to None, similar to default.
.with_connect_timeout_disabled()
.with_allow_http(true)
}

#[cfg(feature = "aws")]
fn read_config(
builder: &mut AmazonS3Builder,
Expand Down Expand Up @@ -282,6 +318,17 @@ impl CloudOptions {
self
}

#[cfg(any(feature = "aws", feature = "gcp", feature = "azure", feature = "http"))]
pub fn with_client_options(mut self, client_options: PlClientOptions) -> Self {
self.client_options = Some(client_options);
self
}

#[cfg(any(feature = "aws", feature = "gcp", feature = "azure", feature = "http"))]
pub fn client_options(&self) -> PlClientOptions {
self.client_options.clone().unwrap_or_default()
}

/// Set the configuration for AWS connections. This is the preferred API from rust.
#[cfg(feature = "aws")]
pub fn with_aws<I: IntoIterator<Item = (AmazonS3ConfigKey, impl Into<String>)>>(
Expand All @@ -300,7 +347,7 @@ impl CloudOptions {
use super::credential_provider::IntoCredentialProvider;

let mut builder = AmazonS3Builder::from_env()
.with_client_options(get_client_options())
.with_client_options(self.client_options().into())
.with_url(url);

read_config(
Expand Down Expand Up @@ -423,7 +470,7 @@ impl CloudOptions {
// The credential provider `self.credentials` is prioritized if it is set. We also need
// `from_env()` as it may source environment configured storage account name.
let mut builder =
MicrosoftAzureBuilder::from_env().with_client_options(get_client_options());
MicrosoftAzureBuilder::from_env().with_client_options(self.client_options().into());

if let Some(options) = &self.config {
let CloudConfig::Azure(options) = options else {
Expand Down Expand Up @@ -476,7 +523,7 @@ impl CloudOptions {
GoogleCloudStorageBuilder::new()
};

let mut builder = builder.with_client_options(get_client_options());
let mut builder = builder.with_client_options(self.client_options().into());

if let Some(options) = &self.config {
let CloudConfig::Gcp(options) = options else {
Expand Down Expand Up @@ -505,7 +552,7 @@ impl CloudOptions {
object_store::http::HttpBuilder::new()
.with_url(url)
.with_client_options({
let mut opts = super::get_client_options();
let mut opts: ClientOptions = self.client_options().into();
if let Some(CloudConfig::Http { headers }) = &self.config {
opts = opts.with_default_headers(try_build_http_header_map_from_items_slice(
headers.as_slice(),
Expand Down

0 comments on commit ad25eb8

Please sign in to comment.