Skip to content

Commit

Permalink
feat(external_services): adds encrypt function for KMS (#3111)
Browse files Browse the repository at this point in the history
Co-authored-by: hyperswitch-bot[bot] <148525504+hyperswitch-bot[bot]@users.noreply.github.com>
  • Loading branch information
prajjwalkumar17 and hyperswitch-bot[bot] authored Dec 13, 2023
1 parent 6e82b0b commit bca7cdb
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 3 deletions.
97 changes: 96 additions & 1 deletion crates/external_services/src/kms.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ impl KmsClient {
// Logging using `Debug` representation of the error as the `Display`
// representation does not hold sufficient information.
logger::error!(kms_sdk_error=?error, "Failed to KMS decrypt data");
metrics::AWS_KMS_FAILURES.add(&metrics::CONTEXT, 1, &[]);
metrics::AWS_KMS_DECRYPTION_FAILURES.add(&metrics::CONTEXT, 1, &[]);
error
})
.into_report()
Expand All @@ -96,11 +96,51 @@ impl KmsClient {

Ok(output)
}

/// Encrypts the provided String data using the AWS KMS SDK. We assume that
/// the SDK has the values required to interact with the AWS KMS APIs (`AWS_ACCESS_KEY_ID` and
/// `AWS_SECRET_ACCESS_KEY`) either set in environment variables, or that the SDK is running in
/// a machine that is able to assume an IAM role.
pub async fn encrypt(&self, data: impl AsRef<[u8]>) -> CustomResult<String, KmsError> {
let start = Instant::now();
let plaintext_blob = Blob::new(data.as_ref());

let encrypted_output = self
.inner_client
.encrypt()
.key_id(&self.key_id)
.plaintext(plaintext_blob)
.send()
.await
.map_err(|error| {
// Logging using `Debug` representation of the error as the `Display`
// representation does not hold sufficient information.
logger::error!(kms_sdk_error=?error, "Failed to KMS encrypt data");
metrics::AWS_KMS_ENCRYPTION_FAILURES.add(&metrics::CONTEXT, 1, &[]);
error
})
.into_report()
.change_context(KmsError::EncryptionFailed)?;

let output = encrypted_output
.ciphertext_blob
.ok_or(KmsError::MissingCiphertextEncryptionOutput)
.into_report()
.map(|blob| consts::BASE64_ENGINE.encode(blob.into_inner()))?;
let time_taken = start.elapsed();
metrics::AWS_KMS_ENCRYPT_TIME.record(&metrics::CONTEXT, time_taken.as_secs_f64(), &[]);

Ok(output)
}
}

/// Errors that could occur during KMS operations.
#[derive(Debug, thiserror::Error)]
pub enum KmsError {
/// An error occurred when base64 encoding input data.
#[error("Failed to base64 encode input data")]
Base64EncodingFailed,

/// An error occurred when base64 decoding input data.
#[error("Failed to base64 decode input data")]
Base64DecodingFailed,
Expand All @@ -109,10 +149,18 @@ pub enum KmsError {
#[error("Failed to KMS decrypt input data")]
DecryptionFailed,

/// An error occurred when KMS encrypting input data.
#[error("Failed to KMS encrypt input data")]
EncryptionFailed,

/// The KMS decrypted output does not include a plaintext output.
#[error("Missing plaintext KMS decryption output")]
MissingPlaintextDecryptionOutput,

/// The KMS encrypted output does not include a ciphertext output.
#[error("Missing ciphertext KMS encryption output")]
MissingCiphertextEncryptionOutput,

/// An error occurred UTF-8 decoding KMS decrypted output.
#[error("Failed to UTF-8 decode decryption output")]
Utf8DecodingFailed,
Expand Down Expand Up @@ -147,3 +195,50 @@ impl common_utils::ext_traits::ConfigExt for KmsValue {
self.0.peek().is_empty_after_trim()
}
}

#[cfg(test)]
mod tests {
#![allow(clippy::expect_used)]
#[tokio::test]
async fn check_kms_encryption() {
std::env::set_var("AWS_SECRET_ACCESS_KEY", "YOUR SECRET ACCESS KEY");
std::env::set_var("AWS_ACCESS_KEY_ID", "YOUR AWS ACCESS KEY ID");
use super::*;
let config = KmsConfig {
key_id: "YOUR KMS KEY ID".to_string(),
region: "AWS REGION".to_string(),
};

let data = "hello".to_string();
let binding = data.as_bytes();
let kms_encrypted_fingerprint = KmsClient::new(&config)
.await
.encrypt(binding)
.await
.expect("kms encryption failed");

println!("{}", kms_encrypted_fingerprint);
}

#[tokio::test]
async fn check_kms_decrypt() {
std::env::set_var("AWS_SECRET_ACCESS_KEY", "YOUR SECRET ACCESS KEY");
std::env::set_var("AWS_ACCESS_KEY_ID", "YOUR AWS ACCESS KEY ID");
use super::*;
let config = KmsConfig {
key_id: "YOUR KMS KEY ID".to_string(),
region: "AWS REGION".to_string(),
};

// Should decrypt to hello
let data = "KMS ENCRYPTED CIPHER".to_string();
let binding = data.as_bytes();
let kms_encrypted_fingerprint = KmsClient::new(&config)
.await
.decrypt(binding)
.await
.expect("kms decryption failed");

println!("{}", kms_encrypted_fingerprint);
}
}
6 changes: 5 additions & 1 deletion crates/external_services/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,12 @@ pub mod metrics {
global_meter!(GLOBAL_METER, "EXTERNAL_SERVICES");

#[cfg(feature = "kms")]
counter_metric!(AWS_KMS_FAILURES, GLOBAL_METER); // No. of AWS KMS API failures
counter_metric!(AWS_KMS_DECRYPTION_FAILURES, GLOBAL_METER); // No. of AWS KMS Decryption failures
#[cfg(feature = "kms")]
counter_metric!(AWS_KMS_ENCRYPTION_FAILURES, GLOBAL_METER); // No. of AWS KMS Encryption failures

#[cfg(feature = "kms")]
histogram_metric!(AWS_KMS_DECRYPT_TIME, GLOBAL_METER); // Histogram for KMS decryption time (in sec)
#[cfg(feature = "kms")]
histogram_metric!(AWS_KMS_ENCRYPT_TIME, GLOBAL_METER); // Histogram for KMS encryption time (in sec)
}
4 changes: 3 additions & 1 deletion crates/router/src/routes/metrics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@ global_meter!(GLOBAL_METER, "ROUTER_API");
counter_metric!(HEALTH_METRIC, GLOBAL_METER); // No. of health API hits
counter_metric!(KV_MISS, GLOBAL_METER); // No. of KV misses
#[cfg(feature = "kms")]
counter_metric!(AWS_KMS_FAILURES, GLOBAL_METER); // No. of AWS KMS API failures
counter_metric!(AWS_KMS_ENCRYPTION_FAILURES, GLOBAL_METER); // No. of AWS KMS Encryption failures
#[cfg(feature = "kms")]
counter_metric!(AWS_KMS_DECRYPTION_FAILURES, GLOBAL_METER); // No. of AWS KMS Decryption failures

// API Level Metrics
counter_metric!(REQUESTS_RECEIVED, GLOBAL_METER);
Expand Down

0 comments on commit bca7cdb

Please sign in to comment.