diff --git a/.changelog/1738663351.md b/.changelog/1738663351.md new file mode 100644 index 0000000000..f496e9f74a --- /dev/null +++ b/.changelog/1738663351.md @@ -0,0 +1,11 @@ +--- +applies_to: +- server +authors: +- drganjoo +references: [] +breaking: false +new_feature: true +bug_fix: false +--- +Enhanced UTF-8 handling: When `replaceInvalidUtf8` codegen flag is enabled, invalid UTF-8 sequences are now automatically replaced with the Unicode replacement character (U+FFFD) instead of causing errors diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/JsonParserGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/JsonParserGenerator.kt index 8b8245287b..6433aee090 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/JsonParserGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/JsonParserGenerator.kt @@ -23,7 +23,6 @@ import software.amazon.smithy.model.traits.EnumTrait import software.amazon.smithy.model.traits.SparseTrait import software.amazon.smithy.model.traits.TimestampFormatTrait import software.amazon.smithy.rust.codegen.core.rustlang.Attribute -import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter import software.amazon.smithy.rust.codegen.core.rustlang.Writable import software.amazon.smithy.rust.codegen.core.rustlang.escape @@ -102,30 +101,30 @@ class JsonParserGenerator( ReturnSymbolToParse(codegenContext.symbolProvider.toSymbol(shape), false) }, private val customizations: List = listOf(), + smithyJsonWithFeatureFlag: RuntimeType = RuntimeType.smithyJson(codegenContext.runtimeConfig), ) : StructuredDataParserGenerator { private val model = codegenContext.model private val symbolProvider = codegenContext.symbolProvider private val runtimeConfig = codegenContext.runtimeConfig private val codegenTarget = codegenContext.target - private val smithyJson = CargoDependency.smithyJson(runtimeConfig).toType() private val protocolFunctions = ProtocolFunctions(codegenContext) private val builderInstantiator = codegenContext.builderInstantiator() private val codegenScope = arrayOf( - "Error" to smithyJson.resolve("deserialize::error::DeserializeError"), - "expect_blob_or_null" to smithyJson.resolve("deserialize::token::expect_blob_or_null"), - "expect_bool_or_null" to smithyJson.resolve("deserialize::token::expect_bool_or_null"), - "expect_document" to smithyJson.resolve("deserialize::token::expect_document"), - "expect_number_or_null" to smithyJson.resolve("deserialize::token::expect_number_or_null"), - "expect_start_array" to smithyJson.resolve("deserialize::token::expect_start_array"), - "expect_start_object" to smithyJson.resolve("deserialize::token::expect_start_object"), - "expect_string_or_null" to smithyJson.resolve("deserialize::token::expect_string_or_null"), - "expect_timestamp_or_null" to smithyJson.resolve("deserialize::token::expect_timestamp_or_null"), - "json_token_iter" to smithyJson.resolve("deserialize::json_token_iter"), + "Error" to smithyJsonWithFeatureFlag.resolve("deserialize::error::DeserializeError"), + "expect_blob_or_null" to smithyJsonWithFeatureFlag.resolve("deserialize::token::expect_blob_or_null"), + "expect_bool_or_null" to smithyJsonWithFeatureFlag.resolve("deserialize::token::expect_bool_or_null"), + "expect_document" to smithyJsonWithFeatureFlag.resolve("deserialize::token::expect_document"), + "expect_number_or_null" to smithyJsonWithFeatureFlag.resolve("deserialize::token::expect_number_or_null"), + "expect_start_array" to smithyJsonWithFeatureFlag.resolve("deserialize::token::expect_start_array"), + "expect_start_object" to smithyJsonWithFeatureFlag.resolve("deserialize::token::expect_start_object"), + "expect_string_or_null" to smithyJsonWithFeatureFlag.resolve("deserialize::token::expect_string_or_null"), + "expect_timestamp_or_null" to smithyJsonWithFeatureFlag.resolve("deserialize::token::expect_timestamp_or_null"), + "json_token_iter" to smithyJsonWithFeatureFlag.resolve("deserialize::json_token_iter"), "Peekable" to RuntimeType.std.resolve("iter::Peekable"), - "skip_value" to smithyJson.resolve("deserialize::token::skip_value"), - "skip_to_end" to smithyJson.resolve("deserialize::token::skip_to_end"), - "Token" to smithyJson.resolve("deserialize::Token"), + "skip_value" to smithyJsonWithFeatureFlag.resolve("deserialize::token::skip_value"), + "skip_to_end" to smithyJsonWithFeatureFlag.resolve("deserialize::token::skip_to_end"), + "Token" to smithyJsonWithFeatureFlag.resolve("deserialize::Token"), "or_empty" to orEmptyJson(), *preludeScope, ) diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/CodegenIntegrationTest.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/CodegenIntegrationTest.kt index 817af1e09a..282c663482 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/CodegenIntegrationTest.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/CodegenIntegrationTest.kt @@ -5,9 +5,19 @@ package software.amazon.smithy.rust.codegen.core.testutil +import software.amazon.smithy.aws.traits.protocols.AwsJson1_0Trait +import software.amazon.smithy.aws.traits.protocols.AwsJson1_1Trait +import software.amazon.smithy.aws.traits.protocols.RestJson1Trait +import software.amazon.smithy.aws.traits.protocols.RestXmlTrait import software.amazon.smithy.build.PluginContext import software.amazon.smithy.model.Model import software.amazon.smithy.model.node.ObjectNode +import software.amazon.smithy.model.node.ToNode +import software.amazon.smithy.model.shapes.ServiceShape +import software.amazon.smithy.model.shapes.ShapeId +import software.amazon.smithy.model.traits.AbstractTrait +import software.amazon.smithy.model.transform.ModelTransformer +import software.amazon.smithy.protocol.traits.Rpcv2CborTrait import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig import software.amazon.smithy.rust.codegen.core.util.runCommand import java.io.File @@ -50,87 +60,78 @@ data class IntegrationTestParams( sealed class AdditionalSettings { abstract fun toObjectNode(): ObjectNode - abstract class CoreAdditionalSettings protected constructor(val settings: List) : AdditionalSettings() { - override fun toObjectNode(): ObjectNode { - val merged = - settings.map { it.toObjectNode() } - .reduce { acc, next -> acc.merge(next) } - - return ObjectNode.builder() - .withMember("codegen", merged) + companion object { + private fun Map.toCodegenObjectNode(): ObjectNode = + ObjectNode.builder() + .withMember( + "codegen", + ObjectNode.builder().apply { + forEach { (key, value) -> + when (value) { + is Boolean -> withMember(key, value) + is Number -> withMember(key, value) + is String -> withMember(key, value) + is ToNode -> withMember(key, value) + else -> throw IllegalArgumentException("Unsupported type for key $key: ${value::class}") + } + } + }.build(), + ) .build() - } + } - abstract class Builder : AdditionalSettings() { - protected val settings = mutableListOf() + abstract class CoreAdditionalSettings protected constructor( + private val settings: Map, + ) : AdditionalSettings() { + override fun toObjectNode(): ObjectNode = settings.toCodegenObjectNode() - fun generateCodegenComments(debugMode: Boolean = true): Builder { - settings.add(GenerateCodegenComments(debugMode)) - return this - } - - abstract fun build(): T + abstract class Builder : AdditionalSettings() { + protected val settings = mutableMapOf() - override fun toObjectNode(): ObjectNode = build().toObjectNode() - } + fun generateCodegenComments(debugMode: Boolean = true) = + apply { + settings["debugMode"] = debugMode + } - // Core settings that are common to both Servers and Clients should be defined here. - data class GenerateCodegenComments(val debugMode: Boolean) : AdditionalSettings() { - override fun toObjectNode(): ObjectNode = - ObjectNode.builder() - .withMember("debugMode", debugMode) - .build() + override fun toObjectNode(): ObjectNode = settings.toCodegenObjectNode() } } } -class ClientAdditionalSettings private constructor(settings: List) : - AdditionalSettings.CoreAdditionalSettings(settings) { - class Builder : CoreAdditionalSettings.Builder() { - override fun build(): ClientAdditionalSettings = ClientAdditionalSettings(settings) - } - - // Additional settings that are specific to client generation should be defined here. - - companion object { - fun builder() = Builder() - } - } - -class ServerAdditionalSettings private constructor(settings: List) : - AdditionalSettings.CoreAdditionalSettings(settings) { - class Builder : CoreAdditionalSettings.Builder() { - fun publicConstrainedTypes(enabled: Boolean = true): Builder { - settings.add(PublicConstrainedTypes(enabled)) - return this +class ServerAdditionalSettings private constructor( + settings: Map, +) : AdditionalSettings.CoreAdditionalSettings(settings) { + class Builder : CoreAdditionalSettings.Builder() { + fun publicConstrainedTypes(enabled: Boolean = true) = + apply { + settings["publicConstrainedTypes"] = enabled } - fun addValidationExceptionToConstrainedOperations(enabled: Boolean = true): Builder { - settings.add(AddValidationExceptionToConstrainedOperations(enabled)) - return this + fun addValidationExceptionToConstrainedOperations(enabled: Boolean = true) = + apply { + settings["addValidationExceptionToConstrainedOperations"] = enabled } - override fun build(): ServerAdditionalSettings = ServerAdditionalSettings(settings) - } + fun replaceInvalidUtf8(enabled: Boolean = true) = + apply { + settings["replaceInvalidUtf8"] = enabled + } + } - private data class PublicConstrainedTypes(val enabled: Boolean) : AdditionalSettings() { - override fun toObjectNode(): ObjectNode = - ObjectNode.builder() - .withMember("publicConstrainedTypes", enabled) - .build() - } + companion object { + fun builder() = Builder() + } +} - private data class AddValidationExceptionToConstrainedOperations(val enabled: Boolean) : AdditionalSettings() { - override fun toObjectNode(): ObjectNode = - ObjectNode.builder() - .withMember("addValidationExceptionToConstrainedOperations", enabled) - .build() - } +class ClientAdditionalSettings private constructor( + settings: Map, +) : AdditionalSettings.CoreAdditionalSettings(settings) { + class Builder : CoreAdditionalSettings.Builder() - companion object { - fun builder() = Builder() - } + companion object { + fun builder() = Builder() } +} /** * Run cargo test on a true, end-to-end, codegen product of a given model. @@ -161,3 +162,128 @@ fun codegenIntegrationTest( logger.fine(out.toString()) return testDir } + +/** + * Metadata associated with a protocol that provides additional information needed for testing. + * + * @property protocol The protocol enum value this metadata is associated with + * @property contentType The HTTP Content-Type header value associated with this protocol. + */ +data class ProtocolMetadata( + val protocol: ModelProtocol, + val contentType: String, +) + +/** + * Represents the supported protocol traits in Smithy models. + * + * @property trait The Smithy trait instance with which the service shape must be annotated. + */ +enum class ModelProtocol(val trait: AbstractTrait) { + AwsJson10(AwsJson1_0Trait.builder().build()), + AwsJson11(AwsJson1_1Trait.builder().build()), + RestJson(RestJson1Trait.builder().build()), + RestXml(RestXmlTrait.builder().build()), + Rpcv2Cbor(Rpcv2CborTrait.builder().build()), + ; + + // Create metadata after enum is initialized + val metadata: ProtocolMetadata by lazy { + when (this) { + AwsJson10 -> ProtocolMetadata(this, "application/x-amz-json-1.0") + AwsJson11 -> ProtocolMetadata(this, "application/x-amz-json-1.1") + RestJson -> ProtocolMetadata(this, "application/json") + RestXml -> ProtocolMetadata(this, "application/xml") + Rpcv2Cbor -> ProtocolMetadata(this, "application/cbor") + } + } + + companion object { + private val TRAIT_IDS = values().map { it.trait.toShapeId() }.toSet() + val ALL: Set = values().toSet() + + fun getTraitIds() = TRAIT_IDS + } +} + +/** + * Removes all existing protocol traits annotated on the given service, + * then sets the provided `protocol` as the sole protocol trait for the service. + */ +fun Model.replaceProtocolTraitOnServerShapeId( + serviceShapeId: ShapeId, + modelProtocol: ModelProtocol, +): Model { + val serviceShape = this.expectShape(serviceShapeId, ServiceShape::class.java) + return replaceProtocolTraitOnServiceShape(serviceShape, modelProtocol) +} + +/** + * Removes all existing protocol traits annotated on the given service shape, + * then sets the provided `protocol` as the sole protocol trait for the service. + */ +fun Model.replaceProtocolTraitOnServiceShape( + serviceShape: ServiceShape, + modelProtocol: ModelProtocol, +): Model { + val serviceBuilder = serviceShape.toBuilder() + ModelProtocol.getTraitIds().forEach { traitId -> + serviceBuilder.removeTrait(traitId) + } + val service = serviceBuilder.addTrait(modelProtocol.trait).build() + return ModelTransformer.create().replaceShapes(this, listOf(service)) +} + +/** + * Processes a Smithy model string by applying different protocol traits and invoking the tests block on the model. + * For each protocol, this function: + * 1. Parses the Smithy model string + * 2. Replaces any existing protocol traits on service shapes with the specified protocol + * 3. Runs the provided test with the transformed model and protocol metadata + * + * @param protocolTraitIds Set of protocols to test against + * @param test Function that receives the transformed model and protocol metadata for testing + */ +fun String.forProtocols( + protocolTraitIds: Set, + test: (Model, ProtocolMetadata) -> Unit, +) { + val baseModel = this.asSmithyModel(smithyVersion = "2") + val serviceShapes = baseModel.serviceShapes.toList() + + protocolTraitIds.forEach { protocol -> + val transformedModel = + serviceShapes.fold(baseModel) { acc, shape -> + acc.replaceProtocolTraitOnServiceShape(shape, protocol) + } + test(transformedModel, protocol.metadata) + } +} + +/** + * Convenience overload that accepts vararg protocols instead of a Set. + * + * @param protocols Variable number of protocols to test against + * @param test Function that receives the transformed model and protocol metadata for testing + * @see forProtocols + */ +fun String.forProtocols( + vararg protocols: ModelProtocol, + test: (Model, ProtocolMetadata) -> Unit, +) { + forProtocols(protocols.toSet(), test) +} + +/** + * Tests a Smithy model string against all supported protocols, with optional exclusions. + * + * @param exclude Set of protocols to exclude from testing (default is empty) + * @param test Function that receives the transformed model and protocol metadata for testing + * @see forProtocols + */ +fun String.forAllProtocols( + exclude: Set = emptySet(), + test: (Model, ProtocolMetadata) -> Unit, +) { + forProtocols(ModelProtocol.ALL - exclude, test) +} diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerRustSettings.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerRustSettings.kt index b5949ad09c..a4a7189a5f 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerRustSettings.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerRustSettings.kt @@ -97,6 +97,7 @@ data class ServerCodegenConfig( */ val experimentalCustomValidationExceptionWithReasonPleaseDoNotUse: String? = defaultExperimentalCustomValidationExceptionWithReasonPleaseDoNotUse, val addValidationExceptionToConstrainedOperations: Boolean = DEFAULT_ADD_VALIDATION_EXCEPTION_TO_CONSTRAINED_OPERATIONS, + val replaceInvalidUtf8: Boolean = DEFAULT_REPLACE_INVALID_UTF8, ) : CoreCodegenConfig( formatTimeoutSeconds, debugMode, ) { @@ -105,6 +106,7 @@ data class ServerCodegenConfig( private const val DEFAULT_IGNORE_UNSUPPORTED_CONSTRAINTS = false private val defaultExperimentalCustomValidationExceptionWithReasonPleaseDoNotUse = null private const val DEFAULT_ADD_VALIDATION_EXCEPTION_TO_CONSTRAINED_OPERATIONS = false + private const val DEFAULT_REPLACE_INVALID_UTF8 = false fun fromCodegenConfigAndNode( coreCodegenConfig: CoreCodegenConfig, @@ -117,6 +119,7 @@ data class ServerCodegenConfig( ignoreUnsupportedConstraints = node.get().getBooleanMemberOrDefault("ignoreUnsupportedConstraints", DEFAULT_IGNORE_UNSUPPORTED_CONSTRAINTS), experimentalCustomValidationExceptionWithReasonPleaseDoNotUse = node.get().getStringMemberOrDefault("experimentalCustomValidationExceptionWithReasonPleaseDoNotUse", defaultExperimentalCustomValidationExceptionWithReasonPleaseDoNotUse), addValidationExceptionToConstrainedOperations = node.get().getBooleanMemberOrDefault("addValidationExceptionToConstrainedOperations", DEFAULT_ADD_VALIDATION_EXCEPTION_TO_CONSTRAINED_OPERATIONS), + replaceInvalidUtf8 = node.get().getBooleanMemberOrDefault("replaceInvalidUtf8", DEFAULT_REPLACE_INVALID_UTF8), ) } else { ServerCodegenConfig( diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocol.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocol.kt index 43982b9b3e..ba6b855a9d 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocol.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocol.kt @@ -144,6 +144,14 @@ fun jsonParserGenerator( listOf( ServerRequestBeforeBoxingDeserializedMemberConvertToMaybeConstrainedJsonParserCustomization(codegenContext), ) + additionalParserCustomizations, + smithyJsonWithFeatureFlag = + if (codegenContext.settings.codegenConfig.replaceInvalidUtf8) { + CargoDependency.smithyJson(codegenContext.runtimeConfig) + .copy(features = setOf("replace-invalid-utf8")) + .toType() + } else { + RuntimeType.smithyJson(codegenContext.runtimeConfig) + }, ) class ServerAwsJsonProtocol( @@ -244,7 +252,7 @@ class ServerRestJsonProtocol( serverCodegenContext, httpBindingResolver, ::restJsonFieldName, - additionalParserCustomizations, + additionalParserCustomizations ) override fun structuredDataSerializer(): StructuredDataSerializerGenerator = diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstraintsTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstraintsTest.kt index 36e30230c6..ff0c990a1e 100644 --- a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstraintsTest.kt +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstraintsTest.kt @@ -9,10 +9,6 @@ import io.kotest.inspectors.forAll import io.kotest.matchers.ints.shouldBeGreaterThan import io.kotest.matchers.shouldBe import org.junit.jupiter.api.Test -import software.amazon.smithy.aws.traits.protocols.AwsJson1_0Trait -import software.amazon.smithy.aws.traits.protocols.AwsJson1_1Trait -import software.amazon.smithy.aws.traits.protocols.RestJson1Trait -import software.amazon.smithy.aws.traits.protocols.RestXmlTrait import software.amazon.smithy.model.Model import software.amazon.smithy.model.shapes.BooleanShape import software.amazon.smithy.model.shapes.ListShape @@ -22,35 +18,27 @@ import software.amazon.smithy.model.shapes.ServiceShape import software.amazon.smithy.model.shapes.ShapeId import software.amazon.smithy.model.shapes.StringShape import software.amazon.smithy.model.shapes.StructureShape -import software.amazon.smithy.model.traits.AbstractTrait import software.amazon.smithy.model.transform.ModelTransformer -import software.amazon.smithy.protocol.traits.Rpcv2CborTrait import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.smithy.RustCrate import software.amazon.smithy.rust.codegen.core.testutil.IntegrationTestParams +import software.amazon.smithy.rust.codegen.core.testutil.ModelProtocol import software.amazon.smithy.rust.codegen.core.testutil.ServerAdditionalSettings import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel +import software.amazon.smithy.rust.codegen.core.testutil.replaceProtocolTraitOnServerShapeId import software.amazon.smithy.rust.codegen.core.testutil.unitTest import software.amazon.smithy.rust.codegen.core.util.lookup import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverIntegrationTest import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverTestSymbolProvider import java.io.File -enum class ModelProtocol(val trait: AbstractTrait) { - AwsJson10(AwsJson1_0Trait.builder().build()), - AwsJson11(AwsJson1_1Trait.builder().build()), - RestJson(RestJson1Trait.builder().build()), - RestXml(RestXmlTrait.builder().build()), - Rpcv2Cbor(Rpcv2CborTrait.builder().build()), -} - /** * Returns the Smithy constraints model from the common repository, with the specified protocol * applied to the service. */ fun loadSmithyConstraintsModelForProtocol(modelProtocol: ModelProtocol): Pair { val (model, serviceShapeId) = loadSmithyConstraintsModel() - return Pair(model.replaceProtocolTrait(serviceShapeId, modelProtocol), serviceShapeId) + return Pair(model.replaceProtocolTraitOnServerShapeId(serviceShapeId, modelProtocol), serviceShapeId) } /** @@ -65,23 +53,6 @@ fun loadSmithyConstraintsModel(): Pair { return Pair(model, serviceShapeId) } -/** - * Removes all existing protocol traits annotated on the given service, - * then sets the provided `protocol` as the sole protocol trait for the service. - */ -fun Model.replaceProtocolTrait( - serviceShapeId: ShapeId, - modelProtocol: ModelProtocol, -): Model { - val serviceBuilder = - this.expectShape(serviceShapeId, ServiceShape::class.java).toBuilder() - for (p in ModelProtocol.values()) { - serviceBuilder.removeTrait(p.trait.toShapeId()) - } - val service = serviceBuilder.addTrait(modelProtocol.trait).build() - return ModelTransformer.create().replaceShapes(this, listOf(service)) -} - fun List.containsAnyShapeId(ids: Collection): Boolean { return ids.any { id -> this.any { shape -> shape == id } } } diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ReplaceInvalidUtf8Test.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ReplaceInvalidUtf8Test.kt new file mode 100644 index 0000000000..ac6adb46aa --- /dev/null +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ReplaceInvalidUtf8Test.kt @@ -0,0 +1,110 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy + +import org.junit.jupiter.api.Test +import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.core.testutil.IntegrationTestParams +import software.amazon.smithy.rust.codegen.core.testutil.ModelProtocol +import software.amazon.smithy.rust.codegen.core.testutil.ServerAdditionalSettings +import software.amazon.smithy.rust.codegen.core.testutil.forAllProtocols +import software.amazon.smithy.rust.codegen.core.testutil.forProtocols +import software.amazon.smithy.rust.codegen.core.testutil.testModule +import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverIntegrationTest + +internal class ReplaceInvalidUtf8Test { + val model = + """ + namespace test + + service SampleService { + operations: [SampleOperation] + } + + @http(uri: "/operation", method: "PUT") + operation SampleOperation { + input := { + x : String + } + } + """ + + @Test + fun `invalid utf8 should be replaced if the codegen flag is set`() { + model.forProtocols(ModelProtocol.AwsJson10, ModelProtocol.AwsJson11, ModelProtocol.RestJson) { model, metadata -> + serverIntegrationTest( + model, + IntegrationTestParams( + additionalSettings = + ServerAdditionalSettings + .builder() + .replaceInvalidUtf8(true) + .toObjectNode(), + ), + ) { _, rustCrate -> + rustCrate.testModule { + rustTemplate( + """ + ##[tokio::test] + async fn test_utf8_replaced() { + let body = r##"{ "x" : "\ud800" }"##; + let request = http::Request::builder() + .method("POST") + .uri("/operation") + .header("content-type", "${metadata.contentType}") + .body(hyper::Body::from(body)) + .expect("failed to build request"); + let result = crate::protocol_serde::shape_sample_operation::de_sample_operation_http_request(request).await; + assert!( + result.is_ok(), + "Invalid utf8 should have been replaced. {result:?}" + ); + assert_eq!( + result.unwrap().x.unwrap(), + "�", + "payload should have been replaced with �." + ); + } + """, + ) + } + } + } + } + + @Test + fun `invalid utf8 should be rejected if the codegen flag is not set`() { + model.forAllProtocols(exclude = setOf(ModelProtocol.RestXml, ModelProtocol.Rpcv2Cbor)) { model, metadata -> + serverIntegrationTest( + model, + ) { _, rustCrate -> + rustCrate.testModule { + rustTemplate( + """ + ##[tokio::test] + async fn test_invalid_utf8_raises_an_error() { + let body = r##"{ "x" : "\ud800" }"##; + let request = http::Request::builder() + .method("POST") + .uri("/operation") + .header("content-type", "${metadata.contentType}") + .body(hyper::Body::from(body)) + .expect("failed to build request"); + let result = crate::protocol_serde::shape_sample_operation::de_sample_operation_http_request(request).await; + assert!( + result.is_err(), + "invalid utf8 characters should not be allowed by default {result:?}" + ); + let error_msg = result.err().unwrap().to_string(); + assert!(error_msg.contains("failed to unescape JSON string")); + } + """, + ) + } + } + } + } +} diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/serialize/CborConstraintsIntegrationTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/serialize/CborConstraintsIntegrationTest.kt index 0a80a125d6..03f5ce71c7 100644 --- a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/serialize/CborConstraintsIntegrationTest.kt +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/serialize/CborConstraintsIntegrationTest.kt @@ -6,8 +6,8 @@ package software.amazon.smithy.rust.codegen.server.smithy.protocols.serialize import org.junit.jupiter.api.Test import software.amazon.smithy.rust.codegen.core.testutil.IntegrationTestParams +import software.amazon.smithy.rust.codegen.core.testutil.ModelProtocol import software.amazon.smithy.rust.codegen.core.testutil.ServerAdditionalSettings -import software.amazon.smithy.rust.codegen.server.smithy.ModelProtocol import software.amazon.smithy.rust.codegen.server.smithy.loadSmithyConstraintsModelForProtocol import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverIntegrationTest diff --git a/rust-runtime/aws-smithy-json/Cargo.toml b/rust-runtime/aws-smithy-json/Cargo.toml index 92105e7448..04e800114c 100644 --- a/rust-runtime/aws-smithy-json/Cargo.toml +++ b/rust-runtime/aws-smithy-json/Cargo.toml @@ -1,7 +1,10 @@ [package] name = "aws-smithy-json" -version = "0.61.2" -authors = ["AWS Rust SDK Team ", "John DiSanti "] +version = "0.62.0" +authors = [ + "AWS Rust SDK Team ", + "John DiSanti ", +] description = "Token streaming JSON parser for smithy-rs." edition = "2021" license = "Apache-2.0" @@ -20,3 +23,6 @@ targets = ["x86_64-unknown-linux-gnu"] cargo-args = ["-Zunstable-options", "-Zrustdoc-scrape-examples"] rustdoc-args = ["--cfg", "docsrs"] # End of docs.rs metadata + +[features] +replace-invalid-utf8 = [] diff --git a/rust-runtime/aws-smithy-json/src/escape.rs b/rust-runtime/aws-smithy-json/src/escape.rs index d87e57c6da..006eea89aa 100644 --- a/rust-runtime/aws-smithy-json/src/escape.rs +++ b/rust-runtime/aws-smithy-json/src/escape.rs @@ -10,6 +10,7 @@ use std::fmt; enum EscapeErrorKind { ExpectedSurrogatePair(String), InvalidEscapeCharacter(char), + #[cfg(not(feature = "replace-invalid-utf8"))] InvalidSurrogatePair(u16, u16), InvalidUnicodeEscape(String), InvalidUtf8, @@ -36,6 +37,7 @@ impl fmt::Display for EscapeError { ) } InvalidEscapeCharacter(chr) => write!(f, "invalid JSON escape: \\{}", chr), + #[cfg(not(feature = "replace-invalid-utf8"))] InvalidSurrogatePair(high, low) => { write!(f, "invalid surrogate pair: \\u{:04X}\\u{:04X}", high, low) } @@ -182,6 +184,7 @@ fn read_codepoint(rest: &[u8]) -> Result { Ok(u16::from_str_radix(codepoint_str, 16).expect("hex string is valid 16-bit value")) } +#[cfg(not(feature = "replace-invalid-utf8"))] /// Reads JSON Unicode escape sequences (i.e., "\u1234"). Will also read /// an additional codepoint if the first codepoint is the start of a surrogate pair. fn read_unicode_escapes(bytes: &[u8], into: &mut Vec) -> Result { @@ -210,6 +213,35 @@ fn read_unicode_escapes(bytes: &[u8], into: &mut Vec) -> Result) -> Result { + let high = read_codepoint(bytes)?; + let (bytes_read, chr) = if is_utf16_high_surrogate(high) { + match read_codepoint(&bytes[6..]) { + Ok(low) if is_utf16_low_surrogate(low) => { + let codepoint = 0x10000 + (high - 0xD800) as u32 * 0x400 + (low - 0xDC00) as u32; + (12, std::char::from_u32(codepoint)) + } + _ => (6, None), + } + } else { + (6, std::char::from_u32(high as u32)) + }; + + match chr { + Some(chr) => match chr.len_utf8() { + 1 => into.push(chr as u8), + _ => into.extend_from_slice(chr.encode_utf8(&mut [0; 4]).as_bytes()), + }, + None => { + const REPLACEMENT_BYTES: &[u8] = "\u{FFFD}".as_bytes(); + into.extend_from_slice(REPLACEMENT_BYTES) + } + } + + Ok(bytes_read) +} + #[cfg(test)] mod test { use super::escape_string; @@ -240,6 +272,7 @@ mod test { assert!(matches!(unescaped, Cow::Borrowed(_))); } + #[cfg(not(feature = "replace-invalid-utf8"))] #[test] fn unescape() { assert_eq!( @@ -290,6 +323,74 @@ mod test { ); } + #[cfg(feature = "replace-invalid-utf8")] + #[test] + fn unescape() { + assert_eq!( + "\x08f\x0Co\to\r\n", + unescape_string(r"\bf\fo\to\r\n").unwrap() + ); + assert_eq!("\"test\"", unescape_string(r#"\"test\""#).unwrap()); + assert_eq!("\x00", unescape_string("\\u0000").unwrap()); + assert_eq!("\x1f", unescape_string("\\u001f").unwrap()); + assert_eq!("foo\r\nbar", unescape_string("foo\\r\\nbar").unwrap()); + assert_eq!("foo\r\n", unescape_string("foo\\r\\n").unwrap()); + assert_eq!("\r\nbar", unescape_string("\\r\\nbar").unwrap()); + assert_eq!("\u{10437}", unescape_string("\\uD801\\uDC37").unwrap()); + + // New tests for invalid Unicode replacement + assert_eq!("�", unescape_string("\\uD800").unwrap()); // High surrogate without low surrogate + assert_eq!("�", unescape_string("\\uDC00").unwrap()); // Low surrogate without high surrogate + assert_eq!("��", unescape_string("\\uD800\\uD800").unwrap()); // Two high surrogates + assert_eq!("��", unescape_string("\\uDC00\\uDC00").unwrap()); // Two low surrogates + assert_eq!("test�test", unescape_string("test\\uD800test").unwrap()); // Orphaned surrogate in middle of string + assert_eq!( + "�\u{10437}", + unescape_string("\\uD800\\uD801\\uDC37").unwrap() + ); // Invalid then valid surrogate pair + + // These error cases should still work as before + assert_eq!( + Err(EscapeErrorKind::UnexpectedEndOfString.into()), + unescape_string("\\") + ); + assert_eq!( + Err(EscapeErrorKind::UnexpectedEndOfString.into()), + unescape_string("\\u") + ); + assert_eq!( + Err(EscapeErrorKind::UnexpectedEndOfString.into()), + unescape_string("\\u00") + ); + assert_eq!( + Err(EscapeErrorKind::InvalidEscapeCharacter('z').into()), + unescape_string("\\z") + ); + assert_eq!( + Err(EscapeErrorKind::InvalidUnicodeEscape("+04D".into()).into()), + unescape_string("\\u+04D") + ); + + // Regular character. + assert_eq!("A", unescape_string("\\u0041").unwrap()); + + // Single surrogates (should each become �). + assert_eq!("�", unescape_string("\\uD800").unwrap()); // High surrogate + assert_eq!("�", unescape_string("\\uDC00").unwrap()); // Low surrogate + + // Valid surrogate pair. + assert_eq!("🦀", unescape_string("\\uD83E\\uDD80").unwrap()); + + // Invalid pairs (should each become ��). + assert_eq!("��", unescape_string("\\uD800\\uD801").unwrap()); // High + High + assert_eq!("��", unescape_string("\\uDC00\\uDC01").unwrap()); // Low + Low + assert_eq!("��", unescape_string("\\uDC00\\uD800").unwrap()); // Low + High + + // Surrogate + non-surrogate. + assert_eq!("�A", unescape_string("\\uD800\\u0041").unwrap()); // High + ASCII + assert_eq!("�A", unescape_string("\\uDC00\\u0041").unwrap()); // Low + ASCII + } + use proptest::proptest; proptest! { #[test]