Skip to content

Commit

Permalink
Add tests for user-provided checksum case
Browse files Browse the repository at this point in the history
  • Loading branch information
landonxjames committed Nov 11, 2024
1 parent 37bc5ea commit 6f82920
Showing 1 changed file with 155 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType.Companion.preludeScope
import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel
import software.amazon.smithy.rust.codegen.core.testutil.integrationTest
import software.amazon.smithy.rust.codegen.core.util.dq

internal class HttpChecksumTest {
companion object {
Expand Down Expand Up @@ -70,6 +71,24 @@ internal class HttpChecksumTest {
@httpHeader("x-amz-response-validation-mode")
validationMode: ValidationMode
@httpHeader("x-amz-checksum-crc32")
ChecksumCRC32: String
@httpHeader("x-amz-checksum-crc32c")
ChecksumCRC32C: String
@httpHeader("x-amz-checksum-crc64nvme")
ChecksumCRC64Nvme: String
@httpHeader("x-amz-checksum-sha1")
ChecksumSHA1: String
@httpHeader("x-amz-checksum-sha256")
ChecksumSHA256: String
@httpHeader("x-amz-checksum-foo")
ChecksumFoo: String
@httpPayload
@required
body: Blob
Expand Down Expand Up @@ -140,6 +159,8 @@ internal class HttpChecksumTest {
checksumResponseFailTests.map { createResponseChecksumValidationFailureTest(it, context) }.join("\n")
val checksumStreamingRequestTestWritables =
streamingRequestTests.map { createStreamingRequestChecksumCalculationTest(it, context) }.join("\n")
val userProvidedChecksumTestWritables =
userProvidedChecksumTests.map { createUserProvidedChecksumsTest(it, context) }.join("\n")
val miscTests = createMiscellaneousTests(context)

// Shared imports for all test types
Expand Down Expand Up @@ -192,6 +213,10 @@ internal class HttpChecksumTest {
testBase.plus(checksumStreamingRequestTestWritables)()
}

rustCrate.integrationTest("user_provided_checksums") {
testBase.plus(userProvidedChecksumTestWritables)()
}

rustCrate.integrationTest("misc_tests") {
testBase.plus(miscTests)()
}
Expand Down Expand Up @@ -357,7 +382,7 @@ internal class HttpChecksumTest {
"""
//${testDef.docs}
##[::tokio::test]
async fn ${algoLower}_response_checksums_work() {
async fn ${algoLower}_response_checksums_fails_correctly() {
let (http_client, _rx) = #{capture_request}(Some(
http::Response::builder()
.header("x-amz-checksum-$algoLower", "${testDef.checksumHeaderValue}")
Expand Down Expand Up @@ -450,6 +475,67 @@ internal class HttpChecksumTest {
}
}

/**
* Generate tests for the case where a user provides a checksum
*/
private fun createUserProvidedChecksumsTest(
testDef: UserProvidedChecksumTest,
context: ClientCodegenContext,
): Writable {
val rc = context.runtimeConfig
val moduleName = context.moduleUseName()
val algoLower = testDef.checksumAlgorithm.lowercase()
// We treat the c after crc32c and the nvme after crc64nvme as separate words
// so this quick map helps us find the field to set
val algoFieldNames =
mapOf(
"crc32" to "checksum_crc32",
"crc32c" to "checksum_crc32_c",
"crc64nvme" to "checksum_crc64_nvme",
"foo" to "checksum_foo",
"sha1" to "checksum_sha1",
"sha256" to "checksum_sha256",
)

return writable {
rustTemplate(
"""
//${testDef.docs}
##[#{tokio}::test]
async fn user_provided_${algoLower}_request_checksum_works() {
let (http_client, rx) = #{capture_request}(None);
let config = $moduleName::Config::builder()
.region(Region::from_static("doesntmatter"))
.with_test_defaults()
.http_client(http_client)
.build();
let client = $moduleName::Client::from_conf(config);
let _ = client.http_checksum_operation()
.body(Blob::new(b"${testDef.requestPayload}"))
.${algoFieldNames.get(algoLower)}(${testDef.checksumValue.dq()})
.send()
.await;
let request = rx.expect_request();
let ${algoLower}_header = request.headers()
.get("x-amz-checksum-$algoLower")
.expect("x-amz-checksum-$algoLower header should exist");
assert_eq!(${algoLower}_header, "${testDef.expectedHeaderValue}");
let algo_header = request.headers()
.get("x-amz-request-algorithm");
assert!(algo_header.is_none());
}
""",
*preludeScope,
"tokio" to CargoDependency.Tokio.toType(),
"capture_request" to RuntimeType.captureRequest(rc),
)
}
}

/**
* Generate miscellaneous tests, currently mostly focused on the inclusion of the checksum config metrics in the
* user-agent header
Expand Down Expand Up @@ -781,3 +867,71 @@ val checksumResponseFailTests =
"ZOyIygCyaOW6GjVnihtTFtIS9PNmskdyMlNKiuyjfzw=",
),
)

data class UserProvidedChecksumTest(
val docs: String,
val requestPayload: String,
val checksumAlgorithm: String,
val checksumValue: String,
val expectedHeaderName: String,
val expectedHeaderValue: String,
val forbidHeaderName: String,
)

val userProvidedChecksumTests =
listOf(
UserProvidedChecksumTest(
"CRC32 checksum provided by user.",
"Hello world",
"Crc32",
"i9aeUg==",
"x-amz-checksum-crc32",
"i9aeUg==",
"x-amz-request-algorithm",
),
UserProvidedChecksumTest(
"CRC32C checksum provided by user.",
"Hello world",
"Crc32C",
"crUfeA==",
"x-amz-checksum-crc32c",
"crUfeA==",
"x-amz-request-algorithm",
),
UserProvidedChecksumTest(
"CRC64NVME checksum provided by user.",
"Hello world",
"Crc64Nvme",
"OOJZ0D8xKts=",
"x-amz-checksum-crc64nvme",
"OOJZ0D8xKts=",
"x-amz-request-algorithm",
),
UserProvidedChecksumTest(
"SHA1 checksum provided by user.",
"Hello world",
"Sha1",
"e1AsOh9IyGCa4hLN+2Od7jlnP14=",
"x-amz-checksum-sha1",
"e1AsOh9IyGCa4hLN+2Od7jlnP14=",
"x-amz-request-algorithm",
),
UserProvidedChecksumTest(
"SHA256 checksum provided by user.",
"Hello world",
"Sha256",
"ZOyIygCyaOW6GjVnihtTFtIS9PNmskdyMlNKiuyjfzw=",
"x-amz-checksum-sha256",
"ZOyIygCyaOW6GjVnihtTFtIS9PNmskdyMlNKiuyjfzw=",
"x-amz-request-algorithm",
),
UserProvidedChecksumTest(
"Forwards compatibility, unmodeled checksum provided by user.",
"Hello world",
"Foo",
"This-is-not-a-real-checksum",
"x-amz-checksum-foo",
"This-is-not-a-real-checksum",
"x-amz-request-algorithm",
),
)

0 comments on commit 6f82920

Please sign in to comment.