Skip to content

Commit

Permalink
Add support for checksum algorithms in AWS (#3873)
Browse files Browse the repository at this point in the history
* Add support for checksum algorithms in aws

* Remove other algorithms

* Only set when checksum algorithm is sha256

* Fix
  • Loading branch information
trueleo authored Mar 21, 2023
1 parent 5d3307a commit 90cb00d
Show file tree
Hide file tree
Showing 4 changed files with 124 additions and 10 deletions.
51 changes: 51 additions & 0 deletions object_store/src/aws/checksum.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use ring::digest::{self, digest as ring_digest};

#[allow(non_camel_case_types)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
/// Enum representing checksum algorithm supported by S3.
pub enum Checksum {
/// SHA-256 algorithm.
SHA256,
}

impl Checksum {
pub(super) fn digest(&self, bytes: &[u8]) -> Vec<u8> {
match self {
Self::SHA256 => ring_digest(&digest::SHA256, bytes).as_ref().to_owned(),
}
}

pub(super) fn header_name(&self) -> &'static str {
match self {
Self::SHA256 => "x-amz-checksum-sha256",
}
}
}

impl TryFrom<&String> for Checksum {
type Error = ();

fn try_from(value: &String) -> Result<Self, Self::Error> {
match value.as_str() {
"sha256" => Ok(Self::SHA256),
_ => Err(()),
}
}
}
24 changes: 22 additions & 2 deletions object_store/src/aws/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
// specific language governing permissions and limitations
// under the License.

use crate::aws::checksum::Checksum;
use crate::aws::credential::{AwsCredential, CredentialExt, CredentialProvider};
use crate::aws::STRICT_PATH_ENCODE_SET;
use crate::client::pagination::stream_paginated;
Expand All @@ -26,6 +27,8 @@ use crate::{
BoxStream, ClientOptions, ListResult, MultipartId, ObjectMeta, Path, Result,
RetryConfig, StreamExt,
};
use base64::prelude::BASE64_STANDARD;
use base64::Engine;
use bytes::{Buf, Bytes};
use chrono::{DateTime, Utc};
use percent_encoding::{utf8_percent_encode, PercentEncode};
Expand Down Expand Up @@ -205,6 +208,7 @@ pub struct S3Config {
pub retry_config: RetryConfig,
pub client_options: ClientOptions,
pub sign_payload: bool,
pub checksum: Option<Checksum>,
}

impl S3Config {
Expand Down Expand Up @@ -262,6 +266,7 @@ impl S3Client {
&self.config.region,
"s3",
self.config.sign_payload,
None,
)
.send_retry(&self.config.retry_config)
.await
Expand All @@ -281,10 +286,19 @@ impl S3Client {
) -> Result<Response> {
let credential = self.get_credential().await?;
let url = self.config.path_url(path);

let mut builder = self.client.request(Method::PUT, url);
let mut payload_sha256 = None;

if let Some(bytes) = bytes {
builder = builder.body(bytes)
if let Some(checksum) = self.config().checksum {
let digest = checksum.digest(&bytes);
builder = builder
.header(checksum.header_name(), BASE64_STANDARD.encode(&digest));
if checksum == Checksum::SHA256 {
payload_sha256 = Some(digest);
}
}
builder = builder.body(bytes);
}

if let Some(value) = self.config().client_options.get_content_type(path) {
Expand All @@ -298,6 +312,7 @@ impl S3Client {
&self.config.region,
"s3",
self.config.sign_payload,
payload_sha256,
)
.send_retry(&self.config.retry_config)
.await
Expand Down Expand Up @@ -325,6 +340,7 @@ impl S3Client {
&self.config.region,
"s3",
self.config.sign_payload,
None,
)
.send_retry(&self.config.retry_config)
.await
Expand All @@ -349,6 +365,7 @@ impl S3Client {
&self.config.region,
"s3",
self.config.sign_payload,
None,
)
.send_retry(&self.config.retry_config)
.await
Expand Down Expand Up @@ -395,6 +412,7 @@ impl S3Client {
&self.config.region,
"s3",
self.config.sign_payload,
None,
)
.send_retry(&self.config.retry_config)
.await
Expand Down Expand Up @@ -438,6 +456,7 @@ impl S3Client {
&self.config.region,
"s3",
self.config.sign_payload,
None,
)
.send_retry(&self.config.retry_config)
.await
Expand Down Expand Up @@ -482,6 +501,7 @@ impl S3Client {
&self.config.region,
"s3",
self.config.sign_payload,
None,
)
.send_retry(&self.config.retry_config)
.await
Expand Down
22 changes: 14 additions & 8 deletions object_store/src/aws/credential.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ const AUTH_HEADER: &str = "authorization";
const ALL_HEADERS: &[&str; 4] = &[DATE_HEADER, HASH_HEADER, TOKEN_HEADER, AUTH_HEADER];

impl<'a> RequestSigner<'a> {
fn sign(&self, request: &mut Request) {
fn sign(&self, request: &mut Request, pre_calculated_digest: Option<Vec<u8>>) {
if let Some(ref token) = self.credential.token {
let token_val = HeaderValue::from_str(token).unwrap();
request.headers_mut().insert(TOKEN_HEADER, token_val);
Expand All @@ -101,9 +101,13 @@ impl<'a> RequestSigner<'a> {
request.headers_mut().insert(DATE_HEADER, date_val);

let digest = if self.sign_payload {
match request.body() {
None => EMPTY_SHA256_HASH.to_string(),
Some(body) => hex_digest(body.as_bytes().unwrap()),
if let Some(digest) = pre_calculated_digest {
hex_encode(&digest)
} else {
match request.body() {
None => EMPTY_SHA256_HASH.to_string(),
Some(body) => hex_digest(body.as_bytes().unwrap()),
}
}
} else {
UNSIGNED_PAYLOAD_LITERAL.to_string()
Expand Down Expand Up @@ -165,6 +169,7 @@ pub trait CredentialExt {
region: &str,
service: &str,
sign_payload: bool,
payload_sha256: Option<Vec<u8>>,
) -> Self;
}

Expand All @@ -175,6 +180,7 @@ impl CredentialExt for RequestBuilder {
region: &str,
service: &str,
sign_payload: bool,
payload_sha256: Option<Vec<u8>>,
) -> Self {
// Hack around lack of access to underlying request
// https://github.com/seanmonstar/reqwest/issues/1212
Expand All @@ -193,7 +199,7 @@ impl CredentialExt for RequestBuilder {
sign_payload,
};

signer.sign(&mut request);
signer.sign(&mut request, payload_sha256);

for header in ALL_HEADERS {
if let Some(val) = request.headers_mut().remove(*header) {
Expand Down Expand Up @@ -627,7 +633,7 @@ mod tests {
sign_payload: true,
};

signer.sign(&mut request);
signer.sign(&mut request, None);
assert_eq!(request.headers().get(AUTH_HEADER).unwrap(), "AWS4-HMAC-SHA256 Credential=AKIAIOSFODNN7EXAMPLE/20220806/us-east-1/ec2/aws4_request, SignedHeaders=host;x-amz-content-sha256;x-amz-date, Signature=a3c787a7ed37f7fdfbfd2d7056a3d7c9d85e6d52a2bfbec73793c0be6e7862d4")
}

Expand Down Expand Up @@ -665,7 +671,7 @@ mod tests {
sign_payload: false,
};

signer.sign(&mut request);
signer.sign(&mut request, None);
assert_eq!(request.headers().get(AUTH_HEADER).unwrap(), "AWS4-HMAC-SHA256 Credential=AKIAIOSFODNN7EXAMPLE/20220806/us-east-1/ec2/aws4_request, SignedHeaders=host;x-amz-content-sha256;x-amz-date, Signature=653c3d8ea261fd826207df58bc2bb69fbb5003e9eb3c0ef06e4a51f2a81d8699")
}

Expand Down Expand Up @@ -702,7 +708,7 @@ mod tests {
sign_payload: true,
};

signer.sign(&mut request);
signer.sign(&mut request, None);
assert_eq!(request.headers().get(AUTH_HEADER).unwrap(), "AWS4-HMAC-SHA256 Credential=H20ABqCkLZID4rLe/20220809/us-east-1/s3/aws4_request, SignedHeaders=host;x-amz-content-sha256;x-amz-date, Signature=9ebf2f92872066c99ac94e573b4e1b80f4dbb8a32b1e8e23178318746e7d1b4d")
}

Expand Down
37 changes: 37 additions & 0 deletions object_store/src/aws/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ use tokio::io::AsyncWrite;
use tracing::info;
use url::Url;

pub use crate::aws::checksum::Checksum;
use crate::aws::client::{S3Client, S3Config};
use crate::aws::credential::{
AwsCredential, CredentialProvider, InstanceCredentialProvider,
Expand All @@ -59,6 +60,7 @@ use crate::{
Result, RetryConfig, StreamExt,
};

mod checksum;
mod client;
mod credential;

Expand Down Expand Up @@ -101,6 +103,9 @@ enum Error {
source: std::num::ParseIntError,
},

#[snafu(display("Invalid Checksum algorithm"))]
InvalidChecksumAlgorithm,

#[snafu(display("Missing region"))]
MissingRegion,

Expand Down Expand Up @@ -386,6 +391,7 @@ pub struct AmazonS3Builder {
imdsv1_fallback: bool,
virtual_hosted_style_request: bool,
unsigned_payload: bool,
checksum_algorithm: Option<Checksum>,
metadata_endpoint: Option<String>,
profile: Option<String>,
client_options: ClientOptions,
Expand Down Expand Up @@ -514,6 +520,11 @@ pub enum AmazonS3ConfigKey {
/// - `unsigned_payload`
UnsignedPayload,

/// Set the checksum algorithm for this client
///
/// See [`AmazonS3Builder::with_checksum_algorithm`]
Checksum,

/// Set the instance metadata endpoint
///
/// See [`AmazonS3Builder::with_metadata_endpoint`] for details.
Expand Down Expand Up @@ -546,6 +557,7 @@ impl AsRef<str> for AmazonS3ConfigKey {
Self::MetadataEndpoint => "aws_metadata_endpoint",
Self::Profile => "aws_profile",
Self::UnsignedPayload => "aws_unsigned_payload",
Self::Checksum => "aws_checksum_algorithm",
}
}
}
Expand Down Expand Up @@ -575,6 +587,7 @@ impl FromStr for AmazonS3ConfigKey {
"aws_imdsv1_fallback" | "imdsv1_fallback" => Ok(Self::ImdsV1Fallback),
"aws_metadata_endpoint" | "metadata_endpoint" => Ok(Self::MetadataEndpoint),
"aws_unsigned_payload" | "unsigned_payload" => Ok(Self::UnsignedPayload),
"aws_checksum_algorithm" | "checksum_algorithm" => Ok(Self::Checksum),
_ => Err(Error::UnknownConfigurationKey { key: s.into() }.into()),
}
}
Expand Down Expand Up @@ -694,6 +707,11 @@ impl AmazonS3Builder {
AmazonS3ConfigKey::UnsignedPayload => {
self.unsigned_payload = str_is_truthy(&value.into())
}
AmazonS3ConfigKey::Checksum => {
let algorithm = Checksum::try_from(&value.into())
.map_err(|_| Error::InvalidChecksumAlgorithm)?;
self.checksum_algorithm = Some(algorithm)
}
};
Ok(self)
}
Expand Down Expand Up @@ -846,6 +864,14 @@ impl AmazonS3Builder {
self
}

/// Sets the [checksum algorithm] which has to be used for object integrity check during upload.
///
/// [checksum algorithm]: https://docs.aws.amazon.com/AmazonS3/latest/userguide/checking-object-integrity.html
pub fn with_checksum_algorithm(mut self, checksum_algorithm: Checksum) -> Self {
self.checksum_algorithm = Some(checksum_algorithm);
self
}

/// Set the [instance metadata endpoint](https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/ec2-instance-metadata.html),
/// used primarily within AWS EC2.
///
Expand Down Expand Up @@ -992,6 +1018,7 @@ impl AmazonS3Builder {
retry_config: self.retry_config,
client_options: self.client_options,
sign_payload: !self.unsigned_payload,
checksum: self.checksum_algorithm,
};

let client = Arc::new(S3Client::new(config)?);
Expand Down Expand Up @@ -1151,6 +1178,7 @@ mod tests {
&container_creds_relative_uri,
);
env::set_var("AWS_UNSIGNED_PAYLOAD", "true");
env::set_var("AWS_CHECKSUM_ALGORITHM", "sha256");

let builder = AmazonS3Builder::from_env();
assert_eq!(builder.access_key_id.unwrap(), aws_access_key_id.as_str());
Expand All @@ -1164,6 +1192,7 @@ mod tests {
assert_eq!(builder.token.unwrap(), aws_session_token);
let metadata_uri = format!("{METADATA_ENDPOINT}{container_creds_relative_uri}");
assert_eq!(builder.metadata_endpoint.unwrap(), metadata_uri);
assert_eq!(builder.checksum_algorithm.unwrap(), Checksum::SHA256);
assert!(builder.unsigned_payload);
}

Expand All @@ -1181,6 +1210,7 @@ mod tests {
("aws_endpoint", aws_endpoint.clone()),
("aws_session_token", aws_session_token.clone()),
("aws_unsigned_payload", "true".to_string()),
("aws_checksum_algorithm", "sha256".to_string()),
]);

let builder = AmazonS3Builder::new()
Expand All @@ -1193,6 +1223,7 @@ mod tests {
assert_eq!(builder.region.unwrap(), aws_default_region);
assert_eq!(builder.endpoint.unwrap(), aws_endpoint);
assert_eq!(builder.token.unwrap(), aws_session_token);
assert_eq!(builder.checksum_algorithm.unwrap(), Checksum::SHA256);
assert!(builder.unsigned_payload);
}

Expand Down Expand Up @@ -1256,6 +1287,12 @@ mod tests {
let is_local = matches!(&config.endpoint, Some(e) if e.starts_with("http://"));
let integration = config.build().unwrap();
put_get_delete_list_opts(&integration, is_local).await;

// run integration test with checksum set to sha256
let config = maybe_skip_integration!().with_checksum_algorithm(Checksum::SHA256);
let is_local = matches!(&config.endpoint, Some(e) if e.starts_with("http://"));
let integration = config.build().unwrap();
put_get_delete_list_opts(&integration, is_local).await;
}

#[tokio::test]
Expand Down

0 comments on commit 90cb00d

Please sign in to comment.