Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow server SDKs to replace invalid UTF-8 character with '�' #3996

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions .changelog/1738663351.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
---
applies_to:
- server
authors:
- drganjoo
references: []
breaking: false
new_feature: true
bug_fix: false
---
Enhanced UTF-8 handling: When `replaceInvalidUtf8` codegen flag is enabled, invalid UTF-8 sequences are now automatically replaced with the Unicode replacement character (U+FFFD) instead of causing errors
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ import software.amazon.smithy.model.traits.EnumTrait
import software.amazon.smithy.model.traits.SparseTrait
import software.amazon.smithy.model.traits.TimestampFormatTrait
import software.amazon.smithy.rust.codegen.core.rustlang.Attribute
import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency
import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.core.rustlang.Writable
import software.amazon.smithy.rust.codegen.core.rustlang.escape
Expand Down Expand Up @@ -102,30 +101,30 @@ class JsonParserGenerator(
ReturnSymbolToParse(codegenContext.symbolProvider.toSymbol(shape), false)
},
private val customizations: List<JsonParserCustomization> = listOf(),
smithyJsonWithFeatureFlag: RuntimeType = RuntimeType.smithyJson(codegenContext.runtimeConfig),
) : StructuredDataParserGenerator {
private val model = codegenContext.model
private val symbolProvider = codegenContext.symbolProvider
private val runtimeConfig = codegenContext.runtimeConfig
private val codegenTarget = codegenContext.target
private val smithyJson = CargoDependency.smithyJson(runtimeConfig).toType()
private val protocolFunctions = ProtocolFunctions(codegenContext)
private val builderInstantiator = codegenContext.builderInstantiator()
private val codegenScope =
arrayOf(
"Error" to smithyJson.resolve("deserialize::error::DeserializeError"),
"expect_blob_or_null" to smithyJson.resolve("deserialize::token::expect_blob_or_null"),
"expect_bool_or_null" to smithyJson.resolve("deserialize::token::expect_bool_or_null"),
"expect_document" to smithyJson.resolve("deserialize::token::expect_document"),
"expect_number_or_null" to smithyJson.resolve("deserialize::token::expect_number_or_null"),
"expect_start_array" to smithyJson.resolve("deserialize::token::expect_start_array"),
"expect_start_object" to smithyJson.resolve("deserialize::token::expect_start_object"),
"expect_string_or_null" to smithyJson.resolve("deserialize::token::expect_string_or_null"),
"expect_timestamp_or_null" to smithyJson.resolve("deserialize::token::expect_timestamp_or_null"),
"json_token_iter" to smithyJson.resolve("deserialize::json_token_iter"),
"Error" to smithyJsonWithFeatureFlag.resolve("deserialize::error::DeserializeError"),
"expect_blob_or_null" to smithyJsonWithFeatureFlag.resolve("deserialize::token::expect_blob_or_null"),
"expect_bool_or_null" to smithyJsonWithFeatureFlag.resolve("deserialize::token::expect_bool_or_null"),
"expect_document" to smithyJsonWithFeatureFlag.resolve("deserialize::token::expect_document"),
"expect_number_or_null" to smithyJsonWithFeatureFlag.resolve("deserialize::token::expect_number_or_null"),
"expect_start_array" to smithyJsonWithFeatureFlag.resolve("deserialize::token::expect_start_array"),
"expect_start_object" to smithyJsonWithFeatureFlag.resolve("deserialize::token::expect_start_object"),
"expect_string_or_null" to smithyJsonWithFeatureFlag.resolve("deserialize::token::expect_string_or_null"),
"expect_timestamp_or_null" to smithyJsonWithFeatureFlag.resolve("deserialize::token::expect_timestamp_or_null"),
"json_token_iter" to smithyJsonWithFeatureFlag.resolve("deserialize::json_token_iter"),
"Peekable" to RuntimeType.std.resolve("iter::Peekable"),
"skip_value" to smithyJson.resolve("deserialize::token::skip_value"),
"skip_to_end" to smithyJson.resolve("deserialize::token::skip_to_end"),
"Token" to smithyJson.resolve("deserialize::Token"),
"skip_value" to smithyJsonWithFeatureFlag.resolve("deserialize::token::skip_value"),
"skip_to_end" to smithyJsonWithFeatureFlag.resolve("deserialize::token::skip_to_end"),
"Token" to smithyJsonWithFeatureFlag.resolve("deserialize::Token"),
"or_empty" to orEmptyJson(),
*preludeScope,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,19 @@

package software.amazon.smithy.rust.codegen.core.testutil

import software.amazon.smithy.aws.traits.protocols.AwsJson1_0Trait
import software.amazon.smithy.aws.traits.protocols.AwsJson1_1Trait
import software.amazon.smithy.aws.traits.protocols.RestJson1Trait
import software.amazon.smithy.aws.traits.protocols.RestXmlTrait
import software.amazon.smithy.build.PluginContext
import software.amazon.smithy.model.Model
import software.amazon.smithy.model.node.ObjectNode
import software.amazon.smithy.model.node.ToNode
import software.amazon.smithy.model.shapes.ServiceShape
import software.amazon.smithy.model.shapes.ShapeId
import software.amazon.smithy.model.traits.AbstractTrait
import software.amazon.smithy.model.transform.ModelTransformer
import software.amazon.smithy.protocol.traits.Rpcv2CborTrait
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig
import software.amazon.smithy.rust.codegen.core.util.runCommand
import java.io.File
Expand Down Expand Up @@ -50,87 +60,78 @@ data class IntegrationTestParams(
sealed class AdditionalSettings {
abstract fun toObjectNode(): ObjectNode

abstract class CoreAdditionalSettings protected constructor(val settings: List<AdditionalSettings>) : AdditionalSettings() {
override fun toObjectNode(): ObjectNode {
val merged =
settings.map { it.toObjectNode() }
.reduce { acc, next -> acc.merge(next) }

return ObjectNode.builder()
.withMember("codegen", merged)
companion object {
private fun Map<String, Any>.toCodegenObjectNode(): ObjectNode =
ObjectNode.builder()
.withMember(
"codegen",
ObjectNode.builder().apply {
forEach { (key, value) ->
when (value) {
is Boolean -> withMember(key, value)
is Number -> withMember(key, value)
is String -> withMember(key, value)
is ToNode -> withMember(key, value)
else -> throw IllegalArgumentException("Unsupported type for key $key: ${value::class}")
}
}
}.build(),
)
.build()
}
}

abstract class Builder<T : CoreAdditionalSettings> : AdditionalSettings() {
protected val settings = mutableListOf<AdditionalSettings>()
abstract class CoreAdditionalSettings protected constructor(
private val settings: Map<String, Any>,
) : AdditionalSettings() {
override fun toObjectNode(): ObjectNode = settings.toCodegenObjectNode()

fun generateCodegenComments(debugMode: Boolean = true): Builder<T> {
settings.add(GenerateCodegenComments(debugMode))
return this
}

abstract fun build(): T
abstract class Builder<T : CoreAdditionalSettings> : AdditionalSettings() {
protected val settings = mutableMapOf<String, Any>()

override fun toObjectNode(): ObjectNode = build().toObjectNode()
}
fun generateCodegenComments(debugMode: Boolean = true) =
apply {
settings["debugMode"] = debugMode
}

// Core settings that are common to both Servers and Clients should be defined here.
data class GenerateCodegenComments(val debugMode: Boolean) : AdditionalSettings() {
override fun toObjectNode(): ObjectNode =
ObjectNode.builder()
.withMember("debugMode", debugMode)
.build()
override fun toObjectNode(): ObjectNode = settings.toCodegenObjectNode()
}
}
}

class ClientAdditionalSettings private constructor(settings: List<AdditionalSettings>) :
AdditionalSettings.CoreAdditionalSettings(settings) {
class Builder : CoreAdditionalSettings.Builder<ClientAdditionalSettings>() {
override fun build(): ClientAdditionalSettings = ClientAdditionalSettings(settings)
}

// Additional settings that are specific to client generation should be defined here.

companion object {
fun builder() = Builder()
}
}

class ServerAdditionalSettings private constructor(settings: List<AdditionalSettings>) :
AdditionalSettings.CoreAdditionalSettings(settings) {
class Builder : CoreAdditionalSettings.Builder<ServerAdditionalSettings>() {
fun publicConstrainedTypes(enabled: Boolean = true): Builder {
settings.add(PublicConstrainedTypes(enabled))
return this
class ServerAdditionalSettings private constructor(
settings: Map<String, Any>,
) : AdditionalSettings.CoreAdditionalSettings(settings) {
class Builder : CoreAdditionalSettings.Builder<ServerAdditionalSettings>() {
fun publicConstrainedTypes(enabled: Boolean = true) =
apply {
settings["publicConstrainedTypes"] = enabled
}

fun addValidationExceptionToConstrainedOperations(enabled: Boolean = true): Builder {
settings.add(AddValidationExceptionToConstrainedOperations(enabled))
return this
fun addValidationExceptionToConstrainedOperations(enabled: Boolean = true) =
apply {
settings["addValidationExceptionToConstrainedOperations"] = enabled
}

override fun build(): ServerAdditionalSettings = ServerAdditionalSettings(settings)
}
fun replaceInvalidUtf8(enabled: Boolean = true) =
apply {
settings["replaceInvalidUtf8"] = enabled
}
}

private data class PublicConstrainedTypes(val enabled: Boolean) : AdditionalSettings() {
override fun toObjectNode(): ObjectNode =
ObjectNode.builder()
.withMember("publicConstrainedTypes", enabled)
.build()
}
companion object {
fun builder() = Builder()
}
}

private data class AddValidationExceptionToConstrainedOperations(val enabled: Boolean) : AdditionalSettings() {
override fun toObjectNode(): ObjectNode =
ObjectNode.builder()
.withMember("addValidationExceptionToConstrainedOperations", enabled)
.build()
}
class ClientAdditionalSettings private constructor(
settings: Map<String, Any>,
) : AdditionalSettings.CoreAdditionalSettings(settings) {
class Builder : CoreAdditionalSettings.Builder<ClientAdditionalSettings>()

companion object {
fun builder() = Builder()
}
companion object {
fun builder() = Builder()
}
}

/**
* Run cargo test on a true, end-to-end, codegen product of a given model.
Expand Down Expand Up @@ -161,3 +162,128 @@ fun codegenIntegrationTest(
logger.fine(out.toString())
return testDir
}

/**
* Metadata associated with a protocol that provides additional information needed for testing.
*
* @property protocol The protocol enum value this metadata is associated with
* @property contentType The HTTP Content-Type header value associated with this protocol.
*/
data class ProtocolMetadata(
val protocol: ModelProtocol,
val contentType: String,
)

/**
* Represents the supported protocol traits in Smithy models.
*
* @property trait The Smithy trait instance with which the service shape must be annotated.
*/
enum class ModelProtocol(val trait: AbstractTrait) {
AwsJson10(AwsJson1_0Trait.builder().build()),
AwsJson11(AwsJson1_1Trait.builder().build()),
RestJson(RestJson1Trait.builder().build()),
RestXml(RestXmlTrait.builder().build()),
Rpcv2Cbor(Rpcv2CborTrait.builder().build()),
;

// Create metadata after enum is initialized
val metadata: ProtocolMetadata by lazy {
when (this) {
AwsJson10 -> ProtocolMetadata(this, "application/x-amz-json-1.0")
AwsJson11 -> ProtocolMetadata(this, "application/x-amz-json-1.1")
RestJson -> ProtocolMetadata(this, "application/json")
RestXml -> ProtocolMetadata(this, "application/xml")
Rpcv2Cbor -> ProtocolMetadata(this, "application/cbor")
}
}

companion object {
private val TRAIT_IDS = values().map { it.trait.toShapeId() }.toSet()
val ALL: Set<ModelProtocol> = values().toSet()

fun getTraitIds() = TRAIT_IDS
}
}

/**
* Removes all existing protocol traits annotated on the given service,
* then sets the provided `protocol` as the sole protocol trait for the service.
*/
fun Model.replaceProtocolTraitOnServerShapeId(
serviceShapeId: ShapeId,
modelProtocol: ModelProtocol,
): Model {
val serviceShape = this.expectShape(serviceShapeId, ServiceShape::class.java)
return replaceProtocolTraitOnServiceShape(serviceShape, modelProtocol)
}

/**
* Removes all existing protocol traits annotated on the given service shape,
* then sets the provided `protocol` as the sole protocol trait for the service.
*/
fun Model.replaceProtocolTraitOnServiceShape(
serviceShape: ServiceShape,
modelProtocol: ModelProtocol,
): Model {
val serviceBuilder = serviceShape.toBuilder()
ModelProtocol.getTraitIds().forEach { traitId ->
serviceBuilder.removeTrait(traitId)
}
val service = serviceBuilder.addTrait(modelProtocol.trait).build()
return ModelTransformer.create().replaceShapes(this, listOf(service))
}

/**
* Processes a Smithy model string by applying different protocol traits and invoking the tests block on the model.
* For each protocol, this function:
* 1. Parses the Smithy model string
* 2. Replaces any existing protocol traits on service shapes with the specified protocol
* 3. Runs the provided test with the transformed model and protocol metadata
*
* @param protocolTraitIds Set of protocols to test against
* @param test Function that receives the transformed model and protocol metadata for testing
*/
fun String.forProtocols(
protocolTraitIds: Set<ModelProtocol>,
test: (Model, ProtocolMetadata) -> Unit,
) {
val baseModel = this.asSmithyModel(smithyVersion = "2")
val serviceShapes = baseModel.serviceShapes.toList()

protocolTraitIds.forEach { protocol ->
val transformedModel =
serviceShapes.fold(baseModel) { acc, shape ->
acc.replaceProtocolTraitOnServiceShape(shape, protocol)
}
test(transformedModel, protocol.metadata)
}
}

/**
* Convenience overload that accepts vararg protocols instead of a Set.
*
* @param protocols Variable number of protocols to test against
* @param test Function that receives the transformed model and protocol metadata for testing
* @see forProtocols
*/
fun String.forProtocols(
vararg protocols: ModelProtocol,
test: (Model, ProtocolMetadata) -> Unit,
) {
forProtocols(protocols.toSet(), test)
}

/**
* Tests a Smithy model string against all supported protocols, with optional exclusions.
*
* @param exclude Set of protocols to exclude from testing (default is empty)
* @param test Function that receives the transformed model and protocol metadata for testing
* @see forProtocols
*/
fun String.forAllProtocols(
exclude: Set<ModelProtocol> = emptySet(),
test: (Model, ProtocolMetadata) -> Unit,
) {
forProtocols(ModelProtocol.ALL - exclude, test)
}
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ data class ServerCodegenConfig(
*/
val experimentalCustomValidationExceptionWithReasonPleaseDoNotUse: String? = defaultExperimentalCustomValidationExceptionWithReasonPleaseDoNotUse,
val addValidationExceptionToConstrainedOperations: Boolean = DEFAULT_ADD_VALIDATION_EXCEPTION_TO_CONSTRAINED_OPERATIONS,
val replaceInvalidUtf8: Boolean = DEFAULT_REPLACE_INVALID_UTF8,
) : CoreCodegenConfig(
formatTimeoutSeconds, debugMode,
) {
Expand All @@ -105,6 +106,7 @@ data class ServerCodegenConfig(
private const val DEFAULT_IGNORE_UNSUPPORTED_CONSTRAINTS = false
private val defaultExperimentalCustomValidationExceptionWithReasonPleaseDoNotUse = null
private const val DEFAULT_ADD_VALIDATION_EXCEPTION_TO_CONSTRAINED_OPERATIONS = false
private const val DEFAULT_REPLACE_INVALID_UTF8 = false

fun fromCodegenConfigAndNode(
coreCodegenConfig: CoreCodegenConfig,
Expand All @@ -117,6 +119,7 @@ data class ServerCodegenConfig(
ignoreUnsupportedConstraints = node.get().getBooleanMemberOrDefault("ignoreUnsupportedConstraints", DEFAULT_IGNORE_UNSUPPORTED_CONSTRAINTS),
experimentalCustomValidationExceptionWithReasonPleaseDoNotUse = node.get().getStringMemberOrDefault("experimentalCustomValidationExceptionWithReasonPleaseDoNotUse", defaultExperimentalCustomValidationExceptionWithReasonPleaseDoNotUse),
addValidationExceptionToConstrainedOperations = node.get().getBooleanMemberOrDefault("addValidationExceptionToConstrainedOperations", DEFAULT_ADD_VALIDATION_EXCEPTION_TO_CONSTRAINED_OPERATIONS),
replaceInvalidUtf8 = node.get().getBooleanMemberOrDefault("replaceInvalidUtf8", DEFAULT_REPLACE_INVALID_UTF8),
)
} else {
ServerCodegenConfig(
Expand Down
Loading
Loading