From d54312ed8c0d83a70d32febe413287dfc02be82a Mon Sep 17 00:00:00 2001 From: Chuckame Date: Tue, 10 Sep 2024 23:50:35 +0200 Subject: [PATCH 01/13] chore(benchmark): add nullable field in simple benchmark --- benchmark/README.md | 38 +++++++++---------- benchmark/build.gradle.kts | 9 +++++ .../avrokotlin/benchmark/ManualProfiling.kt | 3 +- .../benchmark/internal/SimpleDataClass.kt | 4 +- 4 files changed, 31 insertions(+), 23 deletions(-) diff --git a/benchmark/README.md b/benchmark/README.md index c22f22e1..98548df5 100644 --- a/benchmark/README.md +++ b/benchmark/README.md @@ -25,25 +25,25 @@ Each benchmark is executed with the following configuration: Computer: Macbook air M2 ``` -Benchmark Mode Cnt Score Error Units Relative Difference (%) -c.g.a.b.complex.Avro4kBenchmark.read thrpt 5 27390.075 ± 3170.975 ops/s 0.00% -c.g.a.b.complex.ApacheAvroReflectBenchmark.read thrpt 5 26239.615 ± 15033.290 ops/s -4.20% -c.g.a.b.complex.Avro4kGenericWithApacheAvroBenchmark.read thrpt 5 15412.821 ± 860.250 ops/s -43.71% - -c.g.a.b.complex.Avro4kBenchmark.write thrpt 5 55183.133 ± 1994.669 ops/s 0.00% -c.g.a.b.complex.ApacheAvroReflectBenchmark.write thrpt 5 47510.885 ± 2467.348 ops/s -13.91% -c.g.a.b.complex.JacksonAvroBenchmark.write thrpt 5 33936.765 ± 2139.528 ops/s -38.50% -c.g.a.b.complex.Avro4kGenericWithApacheAvroBenchmark.write thrpt 5 27673.527 ± 1605.319 ops/s -49.84% - -c.g.a.b.simple.Avro4kSimpleBenchmark.read thrpt 5 276484.628 ± 15593.092 ops/s 0.00% -c.g.a.b.simple.ApacheAvroReflectSimpleBenchmark.read thrpt 5 230744.377 ± 30164.628 ops/s -16.53% -c.g.a.b.simple.Avro4kGenericWithApacheAvroSimpleBenchmark.read thrpt 5 167888.837 ± 14439.479 ops/s -39.27% -c.g.a.b.simple.JacksonAvroSimpleBenchmark.read thrpt 5 69615.099 ± 4047.717 ops/s -74.82% - -c.g.a.b.simple.Avro4kSimpleBenchmark.write thrpt 5 422469.311 ± 4816.353 ops/s 0.00% -c.g.a.b.simple.ApacheAvroReflectSimpleBenchmark.write thrpt 5 320367.673 ± 33394.537 ops/s -24.15% -c.g.a.b.simple.Avro4kGenericWithApacheAvroSimpleBenchmark.write thrpt 5 186399.540 ± 8931.966 ops/s -55.88% -c.g.a.b.simple.JacksonAvroSimpleBenchmark.write thrpt 5 138898.312 ± 9156.715 ops/s -67.11% +Benchmark Mode Cnt Score Error Units +c.g.a.b.complex.Avro4kBenchmark.read thrpt 5 27482.418 ± 1162.064 ops/s +c.g.a.b.complex.ApacheAvroReflectBenchmark.read thrpt 5 26239.615 ± 15033.290 ops/s +c.g.a.b.complex.Avro4kGenericWithApacheAvroBenchmark.read thrpt 5 15862.270 ± 1139.036 ops/s + +c.g.a.b.complex.Avro4kBenchmark.write thrpt 5 54335.043 ± 2481.196 ops/s +c.g.a.b.complex.ApacheAvroReflectBenchmark.write thrpt 5 47510.885 ± 2467.348 ops/s +c.g.a.b.complex.JacksonAvroBenchmark.write thrpt 5 33936.765 ± 2139.528 ops/s +c.g.a.b.complex.Avro4kGenericWithApacheAvroBenchmark.write thrpt 5 27124.366 ± 753.406 ops/s + +c.g.a.b.simple.Avro4kSimpleBenchmark.read thrpt 5 215140.198 ± 9182.259 ops/s +c.g.a.b.simple.ApacheAvroReflectSimpleBenchmark.read thrpt 5 230744.377 ± 30164.628 ops/s +c.g.a.b.simple.Avro4kGenericWithApacheAvroSimpleBenchmark.read thrpt 5 136913.851 ± 8302.833 ops/s +c.g.a.b.simple.JacksonAvroSimpleBenchmark.read thrpt 5 69615.099 ± 4047.717 ops/s + +c.g.a.b.simple.Avro4kSimpleBenchmark.write thrpt 5 354497.179 ± 8342.002 ops/s +c.g.a.b.simple.ApacheAvroReflectSimpleBenchmark.write thrpt 5 320367.673 ± 33394.537 ops/s +c.g.a.b.simple.Avro4kGenericWithApacheAvroSimpleBenchmark.write thrpt 5 142525.233 ± 2796.318 ops/s +c.g.a.b.simple.JacksonAvroSimpleBenchmark.write thrpt 5 138898.312 ± 9156.715 ops/s ``` > [!WARNING] diff --git 
a/benchmark/build.gradle.kts b/benchmark/build.gradle.kts index 283f68d0..54839b25 100644 --- a/benchmark/build.gradle.kts +++ b/benchmark/build.gradle.kts @@ -34,6 +34,15 @@ benchmark { register("complex-write") { include("^com.github.avrokotlin.benchmark.complex.+.write$") } + register("avro4k") { + include("Avro4k") + } + register("avro4k-read") { + include("Avro4k.+read") + } + register("avro4k-write") { + include("Avro4k.+write") + } } targets { register("main") { diff --git a/benchmark/src/main/kotlin/com/github/avrokotlin/benchmark/ManualProfiling.kt b/benchmark/src/main/kotlin/com/github/avrokotlin/benchmark/ManualProfiling.kt index 8b155976..6997d8d2 100644 --- a/benchmark/src/main/kotlin/com/github/avrokotlin/benchmark/ManualProfiling.kt +++ b/benchmark/src/main/kotlin/com/github/avrokotlin/benchmark/ManualProfiling.kt @@ -1,12 +1,11 @@ package com.github.avrokotlin.benchmark import com.github.avrokotlin.benchmark.complex.Avro4kBenchmark -import com.github.avrokotlin.benchmark.simple.JacksonAvroSimpleBenchmark internal object ManualProfilingWrite { @JvmStatic fun main(vararg args: String) { - JacksonAvroSimpleBenchmark().apply { + Avro4kBenchmark().apply { initTestData() for (i in 0 until 1_000_000) { if (i % 1_000 == 0) println("Iteration $i") diff --git a/benchmark/src/main/kotlin/com/github/avrokotlin/benchmark/internal/SimpleDataClass.kt b/benchmark/src/main/kotlin/com/github/avrokotlin/benchmark/internal/SimpleDataClass.kt index 1e8f1b4c..3adb7a7e 100644 --- a/benchmark/src/main/kotlin/com/github/avrokotlin/benchmark/internal/SimpleDataClass.kt +++ b/benchmark/src/main/kotlin/com/github/avrokotlin/benchmark/internal/SimpleDataClass.kt @@ -7,11 +7,11 @@ internal data class SimpleDataClass( val bool: Boolean, val byte: Byte, val short: Short, - val int: Int, + val int: Int?, val long: Long, val float: Float, val double: Double, - val string: String, + val string: String?, val bytes: ByteArray, ) { companion object { From bf59ea61d9e3031bf1a5d18b608c907261af24b1 Mon Sep 17 00:00:00 2001 From: Chuckame Date: Fri, 13 Sep 2024 16:28:12 +0200 Subject: [PATCH 02/13] refactor: rework encoding for more clear & compact resolving unions --- api/avro4k-core.api | 21 +- .../com/github/avrokotlin/avro4k/Avro.kt | 2 +- .../github/avrokotlin/avro4k/AvroEncoder.kt | 199 ++++++-- .../avro4k/internal/RecordResolver.kt | 155 +++---- .../decoder/direct/RecordDirectDecoder.kt | 4 +- .../decoder/generic/RecordGenericDecoder.kt | 4 +- .../internal/encoder/AbstractAvroEncoder.kt | 372 +++++++++++++++ .../encoder/ReorderingCompositeEncoder.kt | 297 ++++++++++++ .../direct/AbstractAvroDirectEncoder.kt | 414 ++--------------- .../encoder/direct/RecordDirectEncoder.kt | 218 +++------ .../generic/AbstractAvroGenericEncoder.kt | 427 ++---------------- .../encoder/generic/ArrayGenericEncoder.kt | 16 +- .../generic/AvroValueGenericEncoder.kt | 15 +- .../encoder/generic/BytesGenericEncoder.kt | 26 -- .../encoder/generic/FixedGenericEncoder.kt | 37 -- .../encoder/generic/MapGenericEncoder.kt | 2 +- .../encoder/generic/RecordGenericEncoder.kt | 41 +- .../avrokotlin/avro4k/internal/exceptions.kt | 18 - .../avrokotlin/avro4k/internal/helpers.kt | 9 +- .../avro4k/serializer/AvroDuration.kt | 41 +- .../serializer/JavaStdLibSerializers.kt | 122 ++--- .../avro4k/serializer/JavaTimeSerializers.kt | 182 +++----- 22 files changed, 1178 insertions(+), 1444 deletions(-) create mode 100644 src/main/kotlin/com/github/avrokotlin/avro4k/internal/encoder/AbstractAvroEncoder.kt create mode 100644 
src/main/kotlin/com/github/avrokotlin/avro4k/internal/encoder/ReorderingCompositeEncoder.kt delete mode 100644 src/main/kotlin/com/github/avrokotlin/avro4k/internal/encoder/generic/BytesGenericEncoder.kt delete mode 100644 src/main/kotlin/com/github/avrokotlin/avro4k/internal/encoder/generic/FixedGenericEncoder.kt diff --git a/api/avro4k-core.api b/api/avro4k-core.api index d1c52c61..389560bc 100644 --- a/api/avro4k-core.api +++ b/api/avro4k-core.api @@ -10,7 +10,7 @@ public abstract class com/github/avrokotlin/avro4k/Avro : kotlinx/serialization/ public fun encodeToByteArray (Lkotlinx/serialization/SerializationStrategy;Ljava/lang/Object;)[B public final fun encodeToByteArray (Lorg/apache/avro/Schema;Lkotlinx/serialization/SerializationStrategy;Ljava/lang/Object;)[B public final fun getConfiguration ()Lcom/github/avrokotlin/avro4k/AvroConfiguration; - public fun getSerializersModule ()Lkotlinx/serialization/modules/SerializersModule; + public final fun getSerializersModule ()Lkotlinx/serialization/modules/SerializersModule; public final fun schema (Lkotlinx/serialization/descriptors/SerialDescriptor;)Lorg/apache/avro/Schema; } @@ -113,10 +113,9 @@ public synthetic class com/github/avrokotlin/avro4k/AvroDoc$Impl : com/github/av } public abstract interface class com/github/avrokotlin/avro4k/AvroEncoder : kotlinx/serialization/encoding/Encoder { - public abstract fun encodeBytes (Ljava/nio/ByteBuffer;)V public abstract fun encodeBytes ([B)V - public abstract fun encodeFixed (Lorg/apache/avro/generic/GenericFixed;)V public abstract fun encodeFixed ([B)V + public abstract fun encodeUnionIndex (I)V public abstract fun getCurrentWriterSchema ()Lorg/apache/avro/Schema; } @@ -127,11 +126,6 @@ public final class com/github/avrokotlin/avro4k/AvroEncoder$DefaultImpls { public static fun encodeSerializableValue (Lcom/github/avrokotlin/avro4k/AvroEncoder;Lkotlinx/serialization/SerializationStrategy;Ljava/lang/Object;)V } -public final class com/github/avrokotlin/avro4k/AvroEncoderKt { - public static final fun encodeResolving (Lcom/github/avrokotlin/avro4k/AvroEncoder;Lkotlin/jvm/functions/Function0;Lkotlin/jvm/functions/Function1;)Ljava/lang/Object; - public static final fun resolveUnion (Lcom/github/avrokotlin/avro4k/AvroEncoder;Lorg/apache/avro/Schema;Lkotlin/jvm/functions/Function0;Lkotlin/jvm/functions/Function1;)Ljava/lang/Object; -} - public abstract interface annotation class com/github/avrokotlin/avro4k/AvroEnumDefault : java/lang/annotation/Annotation { } @@ -330,17 +324,6 @@ public final class com/github/avrokotlin/avro4k/UnionDecoder$DefaultImpls { public static fun decodeSerializableValue (Lcom/github/avrokotlin/avro4k/UnionDecoder;Lkotlinx/serialization/DeserializationStrategy;)Ljava/lang/Object; } -public abstract interface class com/github/avrokotlin/avro4k/UnionEncoder : com/github/avrokotlin/avro4k/AvroEncoder { - public abstract fun encodeUnionIndex (I)V -} - -public final class com/github/avrokotlin/avro4k/UnionEncoder$DefaultImpls { - public static fun beginCollection (Lcom/github/avrokotlin/avro4k/UnionEncoder;Lkotlinx/serialization/descriptors/SerialDescriptor;I)Lkotlinx/serialization/encoding/CompositeEncoder; - public static fun encodeNotNullMark (Lcom/github/avrokotlin/avro4k/UnionEncoder;)V - public static fun encodeNullableSerializableValue (Lcom/github/avrokotlin/avro4k/UnionEncoder;Lkotlinx/serialization/SerializationStrategy;Ljava/lang/Object;)V - public static fun encodeSerializableValue 
(Lcom/github/avrokotlin/avro4k/UnionEncoder;Lkotlinx/serialization/SerializationStrategy;Ljava/lang/Object;)V -} - public final class com/github/avrokotlin/avro4k/serializer/AvroDuration { public static final field Companion Lcom/github/avrokotlin/avro4k/serializer/AvroDuration$Companion; public synthetic fun (IIILkotlin/jvm/internal/DefaultConstructorMarker;)V diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/Avro.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/Avro.kt index e09c0dc7..84ddfdaf 100644 --- a/src/main/kotlin/com/github/avrokotlin/avro4k/Avro.kt +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/Avro.kt @@ -29,7 +29,7 @@ import org.apache.avro.util.WeakIdentityHashMap */ public sealed class Avro( public val configuration: AvroConfiguration, - public override val serializersModule: SerializersModule, + public final override val serializersModule: SerializersModule, ) : BinaryFormat { // We use the identity hash map because we could have multiple descriptors with the same name, especially // when having 2 different version of the schema for the same name. kotlinx-serialization is instantiating the descriptors diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/AvroEncoder.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/AvroEncoder.kt index 1e791ec2..7a6dbfbd 100644 --- a/src/main/kotlin/com/github/avrokotlin/avro4k/AvroEncoder.kt +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/AvroEncoder.kt @@ -1,10 +1,13 @@ package com.github.avrokotlin.avro4k +import com.github.avrokotlin.avro4k.internal.aliases +import com.github.avrokotlin.avro4k.internal.isNamedSchema +import com.github.avrokotlin.avro4k.internal.nonNullSerialName import kotlinx.serialization.ExperimentalSerializationApi +import kotlinx.serialization.SerializationException +import kotlinx.serialization.descriptors.SerialDescriptor import kotlinx.serialization.encoding.Encoder import org.apache.avro.Schema -import org.apache.avro.generic.GenericFixed -import java.nio.ByteBuffer /** * Interface to encode Avro values. @@ -22,9 +25,6 @@ import java.nio.ByteBuffer * - [encodeEnum] * - [encodeBytes] * - [encodeFixed] - * - * Use the following methods to allow complex encoding using raw values, mainly for logical types: - * - [encodeResolving] */ public interface AvroEncoder : Encoder { /** @@ -33,12 +33,6 @@ public interface AvroEncoder : Encoder { @ExperimentalSerializationApi public val currentWriterSchema: Schema - /** - * Encodes a [Schema.Type.BYTES] value from a [ByteBuffer]. - */ - @ExperimentalSerializationApi - public fun encodeBytes(value: ByteBuffer) - /** * Encodes a [Schema.Type.BYTES] value from a [ByteArray]. */ @@ -47,61 +41,168 @@ public interface AvroEncoder : Encoder { /** * Encodes a [Schema.Type.FIXED] value from a [ByteArray]. Its size must match the size of the fixed schema in [currentWriterSchema]. + * When many fixed schemas are in a union, the first one that matches the size is selected. To avoid this auto-selection, use [encodeUnionIndex] with the index of the expected fixed schema. */ @ExperimentalSerializationApi public fun encodeFixed(value: ByteArray) /** - * Encodes a [Schema.Type.FIXED] value from a [GenericFixed]. Its size must match the size of the fixed schema in [currentWriterSchema]. + * Selects the index of the union type to encode. Also sets [currentWriterSchema] to the selected type. 
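+     *
+     * Illustrative sketch (not part of this patch): a custom serializer selecting a union branch
+     * explicitly before writing a fixed value, instead of relying on the size-based auto-selection
+     * described on [encodeFixed]. The union layout and the index `1` are assumptions for the example.
+     * ```
+     * // assumed writer schema: ["null", {"type": "fixed", "name": "IdBytes", "size": 16}]
+     * override fun serialize(encoder: Encoder, value: ByteArray) {
+     *     with(encoder as AvroEncoder) {
+     *         encodeUnionIndex(1) // selects the fixed branch; currentWriterSchema is now that fixed schema
+     *         encodeFixed(value)  // value.size must equal the fixed size (16 here)
+     *     }
+     * }
+     * ```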
*/ @ExperimentalSerializationApi - public fun encodeFixed(value: GenericFixed) + public fun encodeUnionIndex(index: Int) } -@PublishedApi -internal interface UnionEncoder : AvroEncoder { - /** - * Encode the selected union schema and set the selected type in [currentWriterSchema]. - */ - fun encodeUnionIndex(index: Int) +internal fun AvroEncoder.namedSchemaNotFoundInUnionError( + expectedName: String, + possibleAliases: Set, + vararg fallbackTypes: Schema.Type, +): Throwable { + val aliasesStr = if (possibleAliases.isNotEmpty()) " (with aliases ${possibleAliases.joinToString()})" else "" + val fallbacksStr = if (fallbackTypes.isNotEmpty()) " Also no compatible type found (one of ${fallbackTypes.joinToString()})." else "" + return SerializationException("Named schema $expectedName$aliasesStr not found in union.$fallbacksStr Actual schema: $currentWriterSchema") +} + +internal fun AvroEncoder.typeNotFoundInUnionError( + mainType: Schema.Type, + vararg fallbackTypes: Schema.Type, +): Throwable { + val fallbacksStr = if (fallbackTypes.isNotEmpty()) " Also no compatible type found (one of ${fallbackTypes.joinToString()})." else "" + return SerializationException("${mainType.getName().replaceFirstChar { it.uppercase() }} type not found in union.$fallbacksStr Actual schema: $currentWriterSchema") +} + +internal fun AvroEncoder.unsupportedWriterTypeError( + mainType: Schema.Type, + vararg fallbackTypes: Schema.Type, +): Throwable { + val fallbacksStr = if (fallbackTypes.isNotEmpty()) ", and also not matching to any compatible type (one of ${fallbackTypes.joinToString()})." else "" + return SerializationException( + "Unsupported schema '${currentWriterSchema.fullName}' for encoded type of ${mainType.getName()}$fallbacksStr. Actual schema: $currentWriterSchema" + ) +} + +internal fun AvroEncoder.ensureFixedSize(byteArray: ByteArray): ByteArray { + if (currentWriterSchema.fixedSize != byteArray.size) { + throw SerializationException("Fixed size mismatch for actual size of ${byteArray.size}. Actual schema: $currentWriterSchema") + } + return byteArray +} + +internal fun AvroEncoder.fullNameOrAliasMismatchError( + fullName: String, + aliases: Set, +): Throwable { + val aliasesStr = if (aliases.isNotEmpty()) " (with aliases ${aliases.joinToString()})" else "" + return SerializationException("The descriptor $fullName$aliasesStr doesn't match the schema $currentWriterSchema") +} + +internal fun AvroEncoder.logicalTypeMismatchError( + logicalType: String, + type: Schema.Type, +): Throwable { + return SerializationException("Expected schema type of ${type.getName()} with logical type $logicalType but had schema $currentWriterSchema") } /** - * Allows you to encode a value differently depending on the schema (generally its name, type, logicalType). - * If the [AvroEncoder.currentWriterSchema] is a union, it takes **the first matching encoder** as the final encoder. - * - * This reduces the need to manually resolve the type in a union **and** not in a union. - * - * For examples, see the [com.github.avrokotlin.avro4k.serializer.BigDecimalSerializer] as it resolves a lot of types and also logical types. - * - * @param resolver A lambda that returns a lambda (the encoding lambda) that contains the logic to encode the value only when the schema matches. The encoding **MUST** be done in the encoder lambda to avoid encoding the value if it is not the right schema. Return null when it is not matching the expected schema. - * @param error A lambda that throws an exception if the encoder cannot be resolved. 
+ * @return true is union is nullable and non-null type was selected, false otherwise */ -@ExperimentalSerializationApi -public inline fun AvroEncoder.encodeResolving( - error: () -> Throwable, - resolver: (Schema) -> (() -> T)?, -): T { - val schema = currentWriterSchema - return if (schema.isUnion) { - resolveUnion(schema, error, resolver) +internal fun AvroEncoder.trySelectSingleNonNullTypeFromUnion(): Boolean { + return if (currentWriterSchema.types.size == 2) { + // optimization: A nullable union is very common + if (currentWriterSchema.types[0].type == Schema.Type.NULL) { + encodeUnionIndex(1) + true + } else if (currentWriterSchema.types[1].type == Schema.Type.NULL) { + encodeUnionIndex(0) + true + } else { + // we are in case of non-nullable union with only 2 types + false + } } else { - resolver(schema)?.invoke() ?: throw error() + false } } -@PublishedApi -internal inline fun AvroEncoder.resolveUnion( - schema: Schema, - error: () -> Throwable, - resolver: (Schema) -> (() -> T)?, -): T { - for (index in schema.types.indices) { - val subSchema = schema.types[index] - resolver(subSchema)?.let { - (this as UnionEncoder).encodeUnionIndex(index) - return it.invoke() +internal fun AvroEncoder.trySelectTypeFromUnion(vararg oneOf: Schema.Type): Boolean { + val index = + currentWriterSchema.getIndexTyped(*oneOf) + ?: return false + encodeUnionIndex(index) + return true +} + +internal fun AvroEncoder.trySelectFixedSchemaForSize(fixedSize: Int): Boolean { + currentWriterSchema.types.forEachIndexed { index, schema -> + if (schema.type == Schema.Type.FIXED && schema.fixedSize == fixedSize) { + encodeUnionIndex(index) + return true + } + } + return false +} + +internal fun AvroEncoder.trySelectEnumSchemaForSymbol(symbol: String): Boolean { + currentWriterSchema.types.forEachIndexed { index, schema -> + if (schema.type == Schema.Type.ENUM && schema.hasEnumSymbol(symbol)) { + encodeUnionIndex(index) + return true + } + } + return false +} + +internal fun AvroEncoder.trySelectNamedSchema(descriptor: SerialDescriptor): Boolean { + return trySelectNamedSchema(descriptor.nonNullSerialName, descriptor::aliases) +} + +internal fun AvroEncoder.trySelectNamedSchema( + name: String, + aliases: () -> Set = ::emptySet, +): Boolean { + val index = + currentWriterSchema.getIndexNamedOrAliased(name) + ?: aliases().firstNotNullOfOrNull { currentWriterSchema.getIndexNamedOrAliased(it) } + if (index != null) { + encodeUnionIndex(index) + return true + } + return false +} + +internal fun AvroEncoder.trySelectLogicalTypeFromUnion( + logicalTypeName: String, + vararg oneOf: Schema.Type, +): Boolean { + val index = + currentWriterSchema.getIndexLogicallyTyped(logicalTypeName, *oneOf) + ?: return false + encodeUnionIndex(index) + return true +} + +internal fun Schema.getIndexLogicallyTyped( + logicalTypeName: String, + vararg oneOf: Schema.Type, +): Int? { + return oneOf.firstNotNullOfOrNull { expectedType -> + when (expectedType) { + Schema.Type.FIXED, Schema.Type.RECORD, Schema.Type.ENUM -> types.indexOfFirst { it.type == expectedType && it.logicalType?.name == logicalTypeName }.takeIf { it >= 0 } + else -> getIndexNamed(expectedType.getName())?.takeIf { types[it].logicalType?.name == logicalTypeName } + } + } +} + +internal fun Schema.getIndexNamedOrAliased(expectedName: String): Int? { + return getIndexNamed(expectedName) + ?: types.indexOfFirst { it.isNamedSchema() && it.aliases.contains(expectedName) }.takeIf { it >= 0 } +} + +internal fun Schema.getIndexTyped(vararg oneOf: Schema.Type): Int? 
{ + return oneOf.firstNotNullOfOrNull { expectedType -> + when (expectedType) { + Schema.Type.FIXED, Schema.Type.RECORD, Schema.Type.ENUM -> types.indexOfFirst { it.type == expectedType }.takeIf { it >= 0 } + else -> getIndexNamed(expectedType.getName()) } } - throw error() } \ No newline at end of file diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/internal/RecordResolver.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/internal/RecordResolver.kt index 1614fb47..95269416 100644 --- a/src/main/kotlin/com/github/avrokotlin/avro4k/internal/RecordResolver.kt +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/internal/RecordResolver.kt @@ -3,6 +3,7 @@ package com.github.avrokotlin.avro4k.internal import com.github.avrokotlin.avro4k.Avro import com.github.avrokotlin.avro4k.AvroAlias import com.github.avrokotlin.avro4k.AvroDefault +import com.github.avrokotlin.avro4k.internal.encoder.ReorderingCompositeEncoder import com.github.avrokotlin.avro4k.internal.schema.CHAR_LOGICAL_TYPE_NAME import kotlinx.serialization.SerializationException import kotlinx.serialization.descriptors.SerialDescriptor @@ -31,7 +32,7 @@ internal class RecordResolver( * * Note: We use the descriptor in the key as we could have multiple descriptors for the same record schema, and multiple record schemas for the same descriptor. */ - private val fieldCache: MutableMap> = WeakIdentityHashMap() + private val fieldCache: MutableMap> = WeakIdentityHashMap() /** * Maps the class fields to the schema fields. @@ -48,9 +49,9 @@ internal class RecordResolver( fun resolveFields( writerSchema: Schema, classDescriptor: SerialDescriptor, - ): ClassDescriptorForWriterSchema { - if (classDescriptor.elementsCount == 0) { - return ClassDescriptorForWriterSchema.EMPTY + ): SerializationWorkflow { + if (classDescriptor.elementsCount == 0 && writerSchema.fields.isEmpty()) { + return SerializationWorkflow.EMPTY } return fieldCache.getOrPut(classDescriptor) { WeakHashMap() }.getOrPut(writerSchema) { loadCache(classDescriptor, writerSchema) @@ -69,45 +70,15 @@ internal class RecordResolver( private fun loadCache( classDescriptor: SerialDescriptor, writerSchema: Schema, - ): ClassDescriptorForWriterSchema { + ): SerializationWorkflow { val readerSchema = avro.schema(classDescriptor) - val encodingSteps = computeEncodingSteps(classDescriptor, writerSchema) - return ClassDescriptorForWriterSchema( - sequentialEncoding = encodingSteps.areWriterFieldsSequentiallyOrdered(), + return SerializationWorkflow( computeDecodingSteps(classDescriptor, writerSchema, readerSchema), - encodingSteps + computeEncodingWorkflow(classDescriptor, writerSchema) ) } - private fun Array.areWriterFieldsSequentiallyOrdered(): Boolean { - var lastWriterFieldIndex = -1 - forEach { step -> - when (step) { - is EncodingStep.SerializeWriterField -> { - if (step.writerFieldIndex > lastWriterFieldIndex) { - lastWriterFieldIndex = step.writerFieldIndex - } else { - return false - } - } - - is EncodingStep.MissingWriterFieldFailure -> { - if (step.writerFieldIndex > lastWriterFieldIndex) { - lastWriterFieldIndex = step.writerFieldIndex - } else { - return false - } - } - - is EncodingStep.IgnoreElement -> { - // nothing to check - } - } - } - return true - } - private fun computeDecodingSteps( classDescriptor: SerialDescriptor, writerSchema: Schema, @@ -175,18 +146,19 @@ internal class RecordResolver( return decodingSteps.toTypedArray() } - private fun Schema.isTypeOf(expectedType: Schema.Type): Boolean { - return asSchemaList().any { it.type === expectedType } - } + 
private fun Schema.isTypeOf(expectedType: Schema.Type): Boolean = asSchemaList().any { it.type === expectedType } - private fun computeEncodingSteps( + private fun computeEncodingWorkflow( classDescriptor: SerialDescriptor, writerSchema: Schema, - ): Array { + ): EncodingWorkflow { // Encoding steps are ordered regarding the class descriptor and not the writer schema. // Because kotlinx-serialization doesn't provide a way to encode non-sequentially elements. - val encodingSteps = mutableListOf() + val missingWriterFieldsIndexes = mutableListOf() val visitedWriterFields = BooleanArray(writerSchema.fields.size) { false } + val descriptorToWriterFieldIndex = IntArray(classDescriptor.elementsCount) { ReorderingCompositeEncoder.SKIP_ELEMENT_INDEX } + + var expectedNextWriterIndex = 0 classDescriptor.elementNames.forEachIndexed { elementIndex, _ -> val avroFieldName = avro.configuration.fieldNamingStrategy.resolve(classDescriptor, elementIndex) @@ -194,24 +166,32 @@ internal class RecordResolver( if (writerField != null) { visitedWriterFields[writerField.pos()] = true - encodingSteps += - EncodingStep.SerializeWriterField( - elementIndex = elementIndex, - writerFieldIndex = writerField.pos(), - schema = writerField.schema() - ) - } else { - encodingSteps += EncodingStep.IgnoreElement(elementIndex) + descriptorToWriterFieldIndex[elementIndex] = writerField.pos() + if (expectedNextWriterIndex != -1) { + if (writerField.pos() != expectedNextWriterIndex) { + expectedNextWriterIndex = -1 + } else { + expectedNextWriterIndex++ + } + } } } visitedWriterFields.forEachIndexed { writerFieldIndex, visited -> if (!visited) { - encodingSteps += EncodingStep.MissingWriterFieldFailure(writerFieldIndex) + missingWriterFieldsIndexes += writerFieldIndex } } - return encodingSteps.toTypedArray() + return if (missingWriterFieldsIndexes.isNotEmpty()) { + EncodingWorkflow.MissingWriterFields(missingWriterFieldsIndexes) + } else if (expectedNextWriterIndex == -1) { + EncodingWorkflow.NonContiguous(descriptorToWriterFieldIndex) + } else if (classDescriptor.elementsCount != writerSchema.fields.size) { + EncodingWorkflow.ContiguousWithSkips(descriptorToWriterFieldIndex.map { it == ReorderingCompositeEncoder.SKIP_ELEMENT_INDEX }.toBooleanArray()) + } else { + EncodingWorkflow.ExactMatch + } } private fun Schema.tryGetField( @@ -228,33 +208,44 @@ internal class RecordResolver( } } -internal class ClassDescriptorForWriterSchema( - /** - * If true, indicates that the encoding steps are ordered the same as the writer schema fields. - * If false, indicates that the encoding steps are **NOT** ordered the same as the writer schema fields. - */ - val sequentialEncoding: Boolean, +internal class SerializationWorkflow( /** * Decoding steps are ordered regarding the writer schema and not the class descriptor. */ - val decodingSteps: Array, + val decoding: Array, /** * Encoding steps are ordered regarding the class descriptor and not the writer schema. */ - val encodingSteps: Array, + val encoding: EncodingWorkflow, ) { - val hasMissingWriterField by lazy { encodingSteps.any { it is EncodingStep.MissingWriterFieldFailure } } - companion object { val EMPTY = - ClassDescriptorForWriterSchema( - sequentialEncoding = true, - decodingSteps = emptyArray(), - encodingSteps = emptyArray() + SerializationWorkflow( + decoding = emptyArray(), + encoding = EncodingWorkflow.ExactMatch ) } } +internal sealed interface EncodingWorkflow { + /** + * The descriptor elements exactly matches the writer schema fields as a 1-to-1 mapping. 
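+     * For example (illustration only, not part of this patch): a `data class Foo(val a: Int, val b: String)`
+     * written with a schema whose fields are `[a, b]` resolves to [ExactMatch]; writer fields `[b, a]`
+     * would resolve to [NonContiguous], and writer fields `[a]` alone to [ContiguousWithSkips].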
+ */ + data object ExactMatch : EncodingWorkflow + + class ContiguousWithSkips( + val fieldsToSkip: BooleanArray, + ) : EncodingWorkflow + + class NonContiguous( + val descriptorToWriterFieldIndex: IntArray, + ) : EncodingWorkflow + + class MissingWriterFields( + val missingWriterFields: List, + ) : EncodingWorkflow +} + internal sealed interface DecodingStep { /** * This is a flag indicating that the element is deserializable. @@ -310,31 +301,6 @@ internal sealed interface DecodingStep { ) : DecodingStep } -internal sealed interface EncodingStep { - /** - * The element is present in the writer schema and the class descriptor. - */ - data class SerializeWriterField( - val elementIndex: Int, - val writerFieldIndex: Int, - val schema: Schema, - ) : EncodingStep - - /** - * The element is present in the class descriptor but not in the writer schema, so the element is ignored as nothing has to be serialized. - */ - data class IgnoreElement( - val elementIndex: Int, - ) : EncodingStep - - /** - * The writer field doesn't have a corresponding element in the class descriptor, so we aren't able to serialize a value. - */ - data class MissingWriterFieldFailure( - val writerFieldIndex: Int, - ) : EncodingStep -} - private fun AvroDefault.parseValueToGenericData(schema: Schema): Any? { if (value.isStartingAsJson()) { return Json.parseToJsonElement(value).convertDefaultToObject(schema) @@ -342,8 +308,8 @@ private fun AvroDefault.parseValueToGenericData(schema: Schema): Any? { return JsonPrimitive(value).convertDefaultToObject(schema) } -private fun JsonElement.convertDefaultToObject(schema: Schema): Any? { - return when (this) { +private fun JsonElement.convertDefaultToObject(schema: Schema): Any? = + when (this) { is JsonArray -> when (schema.type) { Schema.Type.ARRAY -> this.map { it.convertDefaultToObject(schema.elementType) } @@ -405,7 +371,6 @@ private fun JsonElement.convertDefaultToObject(schema: Schema): Any? 
{ else -> throw SerializationException("Not a valid primitive value for schema $schema: $this") } } -} private fun Schema.resolveUnion( value: JsonElement?, diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/internal/decoder/direct/RecordDirectDecoder.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/internal/decoder/direct/RecordDirectDecoder.kt index b5bd2cfe..13bf2662 100644 --- a/src/main/kotlin/com/github/avrokotlin/avro4k/internal/decoder/direct/RecordDirectDecoder.kt +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/internal/decoder/direct/RecordDirectDecoder.kt @@ -30,10 +30,10 @@ internal class RecordDirectDecoder( override fun decodeElementIndex(descriptor: SerialDescriptor): Int { var field: DecodingStep while (true) { - if (nextDecodingStepIndex == classDescriptor.decodingSteps.size) { + if (nextDecodingStepIndex == classDescriptor.decoding.size) { return CompositeDecoder.DECODE_DONE } - field = classDescriptor.decodingSteps[nextDecodingStepIndex++] + field = classDescriptor.decoding[nextDecodingStepIndex++] when (field) { is DecodingStep.IgnoreOptionalElement -> { // loop again to ignore the optional element diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/internal/decoder/generic/RecordGenericDecoder.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/internal/decoder/generic/RecordGenericDecoder.kt index 11d77406..84f06c8b 100644 --- a/src/main/kotlin/com/github/avrokotlin/avro4k/internal/decoder/generic/RecordGenericDecoder.kt +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/internal/decoder/generic/RecordGenericDecoder.kt @@ -37,10 +37,10 @@ internal class RecordGenericDecoder( override fun decodeElementIndex(descriptor: SerialDescriptor): Int { var field: DecodingStep do { - if (nextDecodingStep == classDescriptor.decodingSteps.size) { + if (nextDecodingStep == classDescriptor.decoding.size) { return CompositeDecoder.DECODE_DONE } - field = classDescriptor.decodingSteps[nextDecodingStep++] + field = classDescriptor.decoding[nextDecodingStep++] } while (field !is DecodingStep.ValidatedDecodingStep) currentElement = field return field.elementIndex diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/internal/encoder/AbstractAvroEncoder.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/internal/encoder/AbstractAvroEncoder.kt new file mode 100644 index 00000000..75bb1b8e --- /dev/null +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/internal/encoder/AbstractAvroEncoder.kt @@ -0,0 +1,372 @@ +package com.github.avrokotlin.avro4k.internal.encoder + +import com.github.avrokotlin.avro4k.AvroEncoder +import com.github.avrokotlin.avro4k.ensureFixedSize +import com.github.avrokotlin.avro4k.fullNameOrAliasMismatchError +import com.github.avrokotlin.avro4k.getIndexTyped +import com.github.avrokotlin.avro4k.internal.SerializerLocatorMiddleware +import com.github.avrokotlin.avro4k.internal.aliases +import com.github.avrokotlin.avro4k.internal.isFullNameOrAliasMatch +import com.github.avrokotlin.avro4k.internal.nonNullSerialName +import com.github.avrokotlin.avro4k.namedSchemaNotFoundInUnionError +import com.github.avrokotlin.avro4k.trySelectEnumSchemaForSymbol +import com.github.avrokotlin.avro4k.trySelectFixedSchemaForSize +import com.github.avrokotlin.avro4k.trySelectNamedSchema +import com.github.avrokotlin.avro4k.trySelectSingleNonNullTypeFromUnion +import com.github.avrokotlin.avro4k.trySelectTypeFromUnion +import com.github.avrokotlin.avro4k.typeNotFoundInUnionError +import com.github.avrokotlin.avro4k.unsupportedWriterTypeError +import 
kotlinx.serialization.SerializationException +import kotlinx.serialization.SerializationStrategy +import kotlinx.serialization.descriptors.PolymorphicKind +import kotlinx.serialization.descriptors.SerialDescriptor +import kotlinx.serialization.descriptors.StructureKind +import kotlinx.serialization.encoding.AbstractEncoder +import kotlinx.serialization.encoding.CompositeEncoder +import org.apache.avro.Schema +import org.apache.avro.util.Utf8 + +internal abstract class AbstractAvroEncoder : AbstractEncoder(), AvroEncoder { + private var selectedUnionIndex: Int = -1 + + abstract override var currentWriterSchema: Schema + + abstract fun getRecordEncoder(descriptor: SerialDescriptor): CompositeEncoder + + abstract fun getPolymorphicEncoder(descriptor: SerialDescriptor): CompositeEncoder + + abstract fun getMapEncoder( + descriptor: SerialDescriptor, + collectionSize: Int, + ): CompositeEncoder + + abstract fun getArrayEncoder( + descriptor: SerialDescriptor, + collectionSize: Int, + ): CompositeEncoder + + abstract fun encodeUnionIndexInternal(index: Int) + + abstract fun encodeNullUnchecked() + + abstract fun encodeBooleanUnchecked(value: Boolean) + + abstract fun encodeIntUnchecked(value: Int) + + abstract fun encodeLongUnchecked(value: Long) + + abstract fun encodeFloatUnchecked(value: Float) + + abstract fun encodeDoubleUnchecked(value: Double) + + abstract fun encodeStringUnchecked(value: Utf8) + + abstract fun encodeBytesUnchecked(value: ByteArray) + + abstract fun encodeFixedUnchecked(value: ByteArray) + + abstract fun encodeEnumUnchecked(symbol: String) + + override fun encodeSerializableValue( + serializer: SerializationStrategy, + value: T, + ) { + SerializerLocatorMiddleware.apply(serializer) + .serialize(this, value) + } + + override fun beginStructure(descriptor: SerialDescriptor): CompositeEncoder { + return when (descriptor.kind) { + StructureKind.CLASS, + StructureKind.OBJECT, + -> { + val nameChecked: Boolean + if (currentWriterSchema.isUnion && !trySelectSingleNonNullTypeFromUnion()) { + trySelectNamedSchema(descriptor).also { nameChecked = it } || + throw namedSchemaNotFoundInUnionError(descriptor.nonNullSerialName, descriptor.aliases) + } else { + nameChecked = false + } + when (currentWriterSchema.type) { + Schema.Type.RECORD -> { + if (nameChecked || currentWriterSchema.isFullNameOrAliasMatch(descriptor)) { + getRecordEncoder(descriptor) + } else { + throw fullNameOrAliasMismatchError(descriptor.nonNullSerialName, descriptor.aliases) + } + } + + else -> throw unsupportedWriterTypeError(Schema.Type.RECORD) + } + } + + is PolymorphicKind -> getPolymorphicEncoder(descriptor) + else -> throw SerializationException("Unsupported structure kind: $descriptor") + } + } + + override fun beginCollection( + descriptor: SerialDescriptor, + collectionSize: Int, + ): CompositeEncoder { + return when (descriptor.kind) { + StructureKind.LIST -> { + if (currentWriterSchema.isUnion && !trySelectSingleNonNullTypeFromUnion()) { + trySelectTypeFromUnion(Schema.Type.ARRAY) || throw typeNotFoundInUnionError(Schema.Type.ARRAY) + } + when (currentWriterSchema.type) { + Schema.Type.ARRAY -> getArrayEncoder(descriptor, collectionSize) + else -> throw unsupportedWriterTypeError(Schema.Type.ARRAY) + } + } + + StructureKind.MAP -> { + if (currentWriterSchema.isUnion && !trySelectSingleNonNullTypeFromUnion()) { + trySelectTypeFromUnion(Schema.Type.MAP) || throw typeNotFoundInUnionError(Schema.Type.MAP) + } + when (currentWriterSchema.type) { + Schema.Type.MAP -> getMapEncoder(descriptor, collectionSize) 
+ else -> throw unsupportedWriterTypeError(Schema.Type.MAP) + } + } + + else -> throw SerializationException("Unsupported collection kind: $descriptor") + } + } + + override fun encodeUnionIndex(index: Int) { + if (selectedUnionIndex > -1) { + throw SerializationException("Already selected union index: $selectedUnionIndex, got $index, for selected schema $currentWriterSchema") + } + currentWriterSchema = currentWriterSchema.types[index] + encodeUnionIndexInternal(index) + selectedUnionIndex = index + } + + override fun encodeElement( + descriptor: SerialDescriptor, + index: Int, + ): Boolean { + selectedUnionIndex = -1 + return true + } + + override fun encodeNull() { + if (currentWriterSchema.isUnion) { + // Generally, null types are the first or last in the union + if (currentWriterSchema.types.first().type == Schema.Type.NULL) { + encodeUnionIndex(0) + } else if (currentWriterSchema.types.last().type == Schema.Type.NULL) { + encodeUnionIndex(currentWriterSchema.types.size - 1) + } else { + val nullIndex = + currentWriterSchema.getIndexTyped(Schema.Type.NULL) + ?: throw SerializationException("Cannot encode null value for non-nullable schema: $currentWriterSchema") + encodeUnionIndex(nullIndex) + } + } else if (currentWriterSchema.type != Schema.Type.NULL) { + throw SerializationException("Cannot encode null value for non-null schema: $currentWriterSchema") + } + encodeNullUnchecked() + } + + override fun encodeBytes(value: ByteArray) { + if (currentWriterSchema.isUnion && !trySelectSingleNonNullTypeFromUnion()) { + trySelectTypeFromUnion(Schema.Type.BYTES, Schema.Type.STRING) || + trySelectFixedSchemaForSize(value.size) || + throw typeNotFoundInUnionError(Schema.Type.FIXED, Schema.Type.BYTES, Schema.Type.STRING) + } + when (currentWriterSchema.type) { + Schema.Type.BYTES -> encodeBytesUnchecked(value) + Schema.Type.STRING -> encodeStringUnchecked(Utf8(value)) + Schema.Type.FIXED -> encodeFixedUnchecked(ensureFixedSize(value)) + else -> throw unsupportedWriterTypeError(Schema.Type.BYTES, Schema.Type.STRING, Schema.Type.FIXED) + } + } + + override fun encodeFixed(value: ByteArray) { + if (currentWriterSchema.isUnion && !trySelectSingleNonNullTypeFromUnion()) { + trySelectFixedSchemaForSize(value.size) || + trySelectTypeFromUnion(Schema.Type.BYTES, Schema.Type.STRING) || + throw typeNotFoundInUnionError(Schema.Type.FIXED, Schema.Type.BYTES, Schema.Type.STRING) + } + when (currentWriterSchema.type) { + Schema.Type.FIXED -> encodeFixedUnchecked(ensureFixedSize(value)) + Schema.Type.BYTES -> encodeBytesUnchecked(value) + Schema.Type.STRING -> encodeStringUnchecked(Utf8(value)) + else -> throw unsupportedWriterTypeError(Schema.Type.FIXED, Schema.Type.BYTES, Schema.Type.STRING) + } + } + + override fun encodeEnum( + enumDescriptor: SerialDescriptor, + index: Int, + ) { + val nameChecked: Boolean + if (currentWriterSchema.isUnion && !trySelectSingleNonNullTypeFromUnion()) { + trySelectNamedSchema(enumDescriptor).also { nameChecked = it } || + trySelectTypeFromUnion(Schema.Type.STRING) || + throw namedSchemaNotFoundInUnionError( + enumDescriptor.nonNullSerialName, + enumDescriptor.aliases, + Schema.Type.STRING + ) + } else { + nameChecked = false + } + val enumName = enumDescriptor.getElementName(index) + when (currentWriterSchema.type) { + Schema.Type.ENUM -> + if (nameChecked || currentWriterSchema.isFullNameOrAliasMatch(enumDescriptor)) { + encodeEnumUnchecked(enumName) + } else { + throw fullNameOrAliasMismatchError(enumDescriptor.nonNullSerialName, enumDescriptor.aliases) + } + + 
Schema.Type.STRING -> encodeStringUnchecked(Utf8(enumName)) + else -> throw unsupportedWriterTypeError(Schema.Type.ENUM, Schema.Type.STRING) + } + } + + override fun encodeBoolean(value: Boolean) { + if (currentWriterSchema.isUnion && !trySelectSingleNonNullTypeFromUnion()) { + trySelectTypeFromUnion( + Schema.Type.BOOLEAN, + Schema.Type.STRING + ) || throw typeNotFoundInUnionError(Schema.Type.BOOLEAN, Schema.Type.STRING) + } + when (currentWriterSchema.type) { + Schema.Type.BOOLEAN -> encodeBooleanUnchecked(value) + Schema.Type.STRING -> encodeStringUnchecked(Utf8(value.toString())) + else -> throw unsupportedWriterTypeError(Schema.Type.BOOLEAN, Schema.Type.STRING) + } + } + + override fun encodeByte(value: Byte) { + encodeInt(value.toInt()) + } + + override fun encodeShort(value: Short) { + encodeInt(value.toInt()) + } + + override fun encodeInt(value: Int) { + if (currentWriterSchema.isUnion && !trySelectSingleNonNullTypeFromUnion()) { + trySelectTypeFromUnion( + Schema.Type.INT, + Schema.Type.LONG, + Schema.Type.FLOAT, + Schema.Type.DOUBLE, + Schema.Type.STRING + ) || + throw typeNotFoundInUnionError( + Schema.Type.INT, + Schema.Type.LONG, + Schema.Type.FLOAT, + Schema.Type.DOUBLE, + Schema.Type.STRING + ) + } + when (currentWriterSchema.type) { + Schema.Type.INT -> encodeIntUnchecked(value) + Schema.Type.LONG -> encodeLongUnchecked(value.toLong()) + Schema.Type.FLOAT -> encodeFloatUnchecked(value.toFloat()) + Schema.Type.DOUBLE -> encodeDoubleUnchecked(value.toDouble()) + Schema.Type.STRING -> encodeStringUnchecked(Utf8(value.toString())) + else -> throw unsupportedWriterTypeError( + Schema.Type.INT, + Schema.Type.LONG, + Schema.Type.FLOAT, + Schema.Type.DOUBLE, + Schema.Type.STRING + ) + } + } + + override fun encodeLong(value: Long) { + if (currentWriterSchema.isUnion && !trySelectSingleNonNullTypeFromUnion()) { + trySelectTypeFromUnion(Schema.Type.LONG, Schema.Type.FLOAT, Schema.Type.DOUBLE, Schema.Type.STRING) || + throw typeNotFoundInUnionError( + Schema.Type.LONG, + Schema.Type.FLOAT, + Schema.Type.DOUBLE, + Schema.Type.STRING + ) + } + when (currentWriterSchema.type) { + Schema.Type.LONG -> encodeLongUnchecked(value) + Schema.Type.FLOAT -> encodeFloatUnchecked(value.toFloat()) + Schema.Type.DOUBLE -> encodeDoubleUnchecked(value.toDouble()) + Schema.Type.STRING -> encodeStringUnchecked(Utf8(value.toString())) + else -> throw unsupportedWriterTypeError( + Schema.Type.LONG, + Schema.Type.FLOAT, + Schema.Type.DOUBLE, + Schema.Type.STRING + ) + } + } + + override fun encodeFloat(value: Float) { + if (currentWriterSchema.isUnion && !trySelectSingleNonNullTypeFromUnion()) { + trySelectTypeFromUnion(Schema.Type.FLOAT, Schema.Type.DOUBLE, Schema.Type.STRING) || + throw typeNotFoundInUnionError(Schema.Type.FLOAT, Schema.Type.DOUBLE, Schema.Type.STRING) + } + when (currentWriterSchema.type) { + Schema.Type.FLOAT -> encodeFloatUnchecked(value) + Schema.Type.DOUBLE -> encodeDoubleUnchecked(value.toDouble()) + Schema.Type.STRING -> encodeStringUnchecked(Utf8(value.toString())) + else -> throw unsupportedWriterTypeError(Schema.Type.FLOAT, Schema.Type.DOUBLE, Schema.Type.STRING) + } + } + + override fun encodeDouble(value: Double) { + if (currentWriterSchema.isUnion && !trySelectSingleNonNullTypeFromUnion()) { + trySelectTypeFromUnion(Schema.Type.DOUBLE, Schema.Type.STRING) || + throw typeNotFoundInUnionError(Schema.Type.DOUBLE, Schema.Type.STRING) + } + when (currentWriterSchema.type) { + Schema.Type.DOUBLE -> encodeDoubleUnchecked(value) + Schema.Type.STRING -> 
encodeStringUnchecked(Utf8(value.toString())) + else -> throw unsupportedWriterTypeError(Schema.Type.DOUBLE, Schema.Type.STRING) + } + } + + override fun encodeChar(value: Char) { + if (currentWriterSchema.isUnion && !trySelectSingleNonNullTypeFromUnion()) { + trySelectTypeFromUnion(Schema.Type.INT, Schema.Type.STRING) || + throw typeNotFoundInUnionError(Schema.Type.INT, Schema.Type.STRING) + } + when (currentWriterSchema.type) { + Schema.Type.INT -> encodeIntUnchecked(value.code) + Schema.Type.STRING -> encodeStringUnchecked(Utf8(value.toString())) + else -> throw unsupportedWriterTypeError(Schema.Type.INT, Schema.Type.STRING) + } + } + + override fun encodeString(value: String) { + if (currentWriterSchema.isUnion && !trySelectSingleNonNullTypeFromUnion()) { + trySelectTypeFromUnion(Schema.Type.STRING, Schema.Type.BYTES) || + trySelectFixedSchemaForSize(value.length) || + trySelectEnumSchemaForSymbol(value) || + throw typeNotFoundInUnionError( + Schema.Type.STRING, + Schema.Type.BYTES, + Schema.Type.FIXED, + Schema.Type.ENUM + ) + } + when (currentWriterSchema.type) { + Schema.Type.STRING -> encodeStringUnchecked(Utf8(value)) + Schema.Type.BYTES -> encodeBytesUnchecked(value.encodeToByteArray()) + Schema.Type.FIXED -> encodeFixedUnchecked(ensureFixedSize(value.encodeToByteArray())) + Schema.Type.ENUM -> encodeEnumUnchecked(value) + else -> throw unsupportedWriterTypeError( + Schema.Type.BYTES, + Schema.Type.STRING, + Schema.Type.FIXED, + Schema.Type.ENUM + ) + } + } +} \ No newline at end of file diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/internal/encoder/ReorderingCompositeEncoder.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/internal/encoder/ReorderingCompositeEncoder.kt new file mode 100644 index 00000000..0f84226d --- /dev/null +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/internal/encoder/ReorderingCompositeEncoder.kt @@ -0,0 +1,297 @@ +package com.github.avrokotlin.avro4k.internal.encoder + +import com.github.avrokotlin.avro4k.internal.encoder.ReorderingCompositeEncoder.Companion.SKIP_ELEMENT_INDEX +import kotlinx.serialization.ExperimentalSerializationApi +import kotlinx.serialization.SerializationStrategy +import kotlinx.serialization.descriptors.SerialDescriptor +import kotlinx.serialization.encoding.CompositeEncoder +import kotlinx.serialization.encoding.Encoder +import kotlinx.serialization.modules.EmptySerializersModule +import kotlinx.serialization.modules.SerializersModule + +/** + * Encodes composite elements in a specific order managed by [mapElementIndex]. + * + * This encoder will replicate the behavior of a standard encoding, but calling the `encode*Element` methods in + * the order defined by [mapElementIndex]. It first buffers each `encode*Element` calls in an array following + * the given indexes using [mapElementIndex], then when [endStructure] is called, it encodes the buffered calls + * in the expected order by replaying the previous calls on the given [compositeEncoderDelegate]. + * + * When [mapElementIndex] returns [SKIP_ELEMENT_INDEX], the element will be ignored and not encoded. + * + * This encoder is stateful and not designed to be reused. + * + * @param compositeEncoderDelegate the [CompositeEncoder] to be used to encode the given descriptor's elements in the expected order. + * @param encodedElementsCount The final number of elements to encode. If the mapper provides a smaller number of elements, an error will be thrown indicating the missing index. 
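+ * For instance (illustrative, not part of this patch), `ReorderingCompositeEncoder(3, delegate) { _, index -> 2 - index }`
+ * buffers the three `encode*Element` calls and replays them on `delegate` in reverse element order
+ * when `endStructure` is called.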
+ * @param mapElementIndex maps the element index to a new positional zero-based index. If this mapper provides the same index for multiple elements, only the last one will be encoded as the previous ones will be overridden. The mapped index just helps to reorder the elements, but the reordered `encode*Element` method calls will still pass the original element index. + */ +@ExperimentalSerializationApi +internal class ReorderingCompositeEncoder( + encodedElementsCount: Int, + private val compositeEncoderDelegate: CompositeEncoder, + private val mapElementIndex: (SerialDescriptor, Int) -> Int, +) : CompositeEncoder { + private val bufferedCalls = Array(encodedElementsCount) { null } + + companion object { + @ExperimentalSerializationApi + const val SKIP_ELEMENT_INDEX: Int = -1 + } + + override val serializersModule: SerializersModule + // No need to return a serializers module as it's not used during buffering + get() = EmptySerializersModule() + + private data class BufferedCall( + val originalElementIndex: Int, + val encoder: () -> Unit, + ) + + private fun bufferEncoding( + descriptor: SerialDescriptor, + index: Int, + encoder: () -> Unit, + ) { + val newIndex = mapElementIndex(descriptor, index) + if (newIndex != SKIP_ELEMENT_INDEX) { + bufferedCalls[newIndex] = BufferedCall(index, encoder) + } + } + + override fun endStructure(descriptor: SerialDescriptor) { + bufferedCalls.forEach { fieldToEncode -> + // In case of skipped fields, overridden fields (mapped to same index) or too big [encodedElementsCount], + // the fieldToEncode may be null as no element was encoded for that index + fieldToEncode?.encoder?.invoke() + } + compositeEncoderDelegate.endStructure(descriptor) + } + + override fun encodeBooleanElement( + descriptor: SerialDescriptor, + index: Int, + value: Boolean, + ) { + bufferEncoding(descriptor, index) { + compositeEncoderDelegate.encodeBooleanElement(descriptor, index, value) + } + } + + override fun encodeByteElement( + descriptor: SerialDescriptor, + index: Int, + value: Byte, + ) { + bufferEncoding(descriptor, index) { + compositeEncoderDelegate.encodeByteElement(descriptor, index, value) + } + } + + override fun encodeCharElement( + descriptor: SerialDescriptor, + index: Int, + value: Char, + ) { + bufferEncoding(descriptor, index) { + compositeEncoderDelegate.encodeCharElement(descriptor, index, value) + } + } + + override fun encodeDoubleElement( + descriptor: SerialDescriptor, + index: Int, + value: Double, + ) { + bufferEncoding(descriptor, index) { + compositeEncoderDelegate.encodeDoubleElement(descriptor, index, value) + } + } + + override fun encodeFloatElement( + descriptor: SerialDescriptor, + index: Int, + value: Float, + ) { + bufferEncoding(descriptor, index) { + compositeEncoderDelegate.encodeFloatElement(descriptor, index, value) + } + } + + override fun encodeIntElement( + descriptor: SerialDescriptor, + index: Int, + value: Int, + ) { + bufferEncoding(descriptor, index) { + compositeEncoderDelegate.encodeIntElement(descriptor, index, value) + } + } + + override fun encodeLongElement( + descriptor: SerialDescriptor, + index: Int, + value: Long, + ) { + bufferEncoding(descriptor, index) { + compositeEncoderDelegate.encodeLongElement(descriptor, index, value) + } + } + + override fun encodeNullableSerializableElement( + descriptor: SerialDescriptor, + index: Int, + serializer: SerializationStrategy, + value: T?, + ) { + bufferEncoding(descriptor, index) { + compositeEncoderDelegate.encodeNullableSerializableElement(descriptor, index, serializer, value) + 
} + } + + override fun encodeSerializableElement( + descriptor: SerialDescriptor, + index: Int, + serializer: SerializationStrategy, + value: T, + ) { + bufferEncoding(descriptor, index) { + compositeEncoderDelegate.encodeSerializableElement(descriptor, index, serializer, value) + } + } + + override fun encodeShortElement( + descriptor: SerialDescriptor, + index: Int, + value: Short, + ) { + bufferEncoding(descriptor, index) { + compositeEncoderDelegate.encodeShortElement(descriptor, index, value) + } + } + + override fun encodeStringElement( + descriptor: SerialDescriptor, + index: Int, + value: String, + ) { + bufferEncoding(descriptor, index) { + compositeEncoderDelegate.encodeStringElement(descriptor, index, value) + } + } + + override fun encodeInlineElement( + descriptor: SerialDescriptor, + index: Int, + ): Encoder { + return BufferingInlineEncoder(descriptor, index) + } + + override fun shouldEncodeElementDefault( + descriptor: SerialDescriptor, + index: Int, + ): Boolean { + return compositeEncoderDelegate.shouldEncodeElementDefault(descriptor, index) + } + + private inner class BufferingInlineEncoder( + private val descriptor: SerialDescriptor, + private val elementIndex: Int, + ) : Encoder { + private var encodeNotNullMarkCalled = false + + override val serializersModule: SerializersModule + get() = this@ReorderingCompositeEncoder.serializersModule + + private fun bufferEncoding(encoder: Encoder.() -> Unit) { + bufferEncoding(descriptor, elementIndex) { + compositeEncoderDelegate.encodeInlineElement(descriptor, elementIndex).apply { + if (encodeNotNullMarkCalled) { + encodeNotNullMark() + } + encoder() + } + } + } + + override fun encodeNotNullMark() { + encodeNotNullMarkCalled = true + } + + override fun encodeNullableSerializableValue( + serializer: SerializationStrategy, + value: T?, + ) { + bufferEncoding { encodeNullableSerializableValue(serializer, value) } + } + + override fun encodeSerializableValue( + serializer: SerializationStrategy, + value: T, + ) { + bufferEncoding { encodeSerializableValue(serializer, value) } + } + + override fun encodeBoolean(value: Boolean) { + bufferEncoding { encodeBoolean(value) } + } + + override fun encodeByte(value: Byte) { + bufferEncoding { encodeByte(value) } + } + + override fun encodeChar(value: Char) { + bufferEncoding { encodeChar(value) } + } + + override fun encodeDouble(value: Double) { + bufferEncoding { encodeDouble(value) } + } + + override fun encodeEnum( + enumDescriptor: SerialDescriptor, + index: Int, + ) { + bufferEncoding { encodeEnum(enumDescriptor, index) } + } + + override fun encodeFloat(value: Float) { + bufferEncoding { encodeFloat(value) } + } + + override fun encodeInt(value: Int) { + bufferEncoding { encodeInt(value) } + } + + override fun encodeLong(value: Long) { + bufferEncoding { encodeLong(value) } + } + + @ExperimentalSerializationApi + override fun encodeNull() { + bufferEncoding { encodeNull() } + } + + override fun encodeShort(value: Short) { + bufferEncoding { encodeShort(value) } + } + + override fun encodeString(value: String) { + bufferEncoding { encodeString(value) } + } + + override fun beginStructure(descriptor: SerialDescriptor): CompositeEncoder { + unexpectedCall(::beginStructure.name) + } + + override fun encodeInline(descriptor: SerialDescriptor): Encoder { + unexpectedCall(::encodeInline.name) + } + + private fun unexpectedCall(methodName: String): Nothing { + // This method is normally called from within encodeSerializableValue or encodeNullableSerializableValue which is buffered, so we 
should never go here during buffering as it will be delegated to the concrete CompositeEncoder + throw UnsupportedOperationException( + "Non-standard usage of ${CompositeEncoder::class.simpleName}: $methodName should be called from within encodeSerializableValue or encodeNullableSerializableValue" + ) + } + } +} \ No newline at end of file diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/internal/encoder/direct/AbstractAvroDirectEncoder.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/internal/encoder/direct/AbstractAvroDirectEncoder.kt index 87d36beb..57e42aaa 100644 --- a/src/main/kotlin/com/github/avrokotlin/avro4k/internal/encoder/direct/AbstractAvroDirectEncoder.kt +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/internal/encoder/direct/AbstractAvroDirectEncoder.kt @@ -1,24 +1,15 @@ package com.github.avrokotlin.avro4k.internal.encoder.direct import com.github.avrokotlin.avro4k.Avro -import com.github.avrokotlin.avro4k.AvroEncoder -import com.github.avrokotlin.avro4k.UnionEncoder -import com.github.avrokotlin.avro4k.encodeResolving -import com.github.avrokotlin.avro4k.internal.BadEncodedValueError -import com.github.avrokotlin.avro4k.internal.SerializerLocatorMiddleware -import com.github.avrokotlin.avro4k.internal.isFullNameOrAliasMatch +import com.github.avrokotlin.avro4k.internal.encoder.AbstractAvroEncoder import kotlinx.serialization.SerializationException import kotlinx.serialization.SerializationStrategy -import kotlinx.serialization.descriptors.PolymorphicKind import kotlinx.serialization.descriptors.SerialDescriptor -import kotlinx.serialization.descriptors.StructureKind import kotlinx.serialization.encoding.AbstractEncoder import kotlinx.serialization.encoding.CompositeEncoder import kotlinx.serialization.modules.SerializersModule import org.apache.avro.Schema -import org.apache.avro.generic.GenericFixed import org.apache.avro.util.Utf8 -import java.nio.ByteBuffer internal class AvroValueDirectEncoder( override var currentWriterSchema: Schema, @@ -29,407 +20,82 @@ internal class AvroValueDirectEncoder( internal sealed class AbstractAvroDirectEncoder( protected val avro: Avro, protected val binaryEncoder: org.apache.avro.io.Encoder, -) : AbstractEncoder(), AvroEncoder, UnionEncoder { - private var selectedUnionIndex: Int = -1 - - abstract override var currentWriterSchema: Schema - +) : AbstractAvroEncoder() { override val serializersModule: SerializersModule get() = avro.serializersModule - override fun encodeSerializableValue( - serializer: SerializationStrategy, - value: T, - ) { - SerializerLocatorMiddleware.apply(serializer) - .serialize(this, value) + override fun getRecordEncoder(descriptor: SerialDescriptor): CompositeEncoder { + return RecordDirectEncoder(descriptor, currentWriterSchema, avro, binaryEncoder) } - override fun beginStructure(descriptor: SerialDescriptor): CompositeEncoder { - return when (descriptor.kind) { - StructureKind.CLASS, - StructureKind.OBJECT, - -> - encodeResolving( - { BadEncodedValueError(null, currentWriterSchema, Schema.Type.RECORD) } - ) { schema -> - if (schema.type == Schema.Type.RECORD && schema.isFullNameOrAliasMatch(descriptor)) { - { - val elementDescriptors = avro.recordResolver.resolveFields(schema, descriptor) - RecordDirectEncoder(elementDescriptors, schema, avro, binaryEncoder) - } - } else { - null - } - } - - is PolymorphicKind -> PolymorphicDirectEncoder(avro, currentWriterSchema, binaryEncoder) - else -> throw SerializationException("Unsupported structure kind: $descriptor") - } + override fun 
getPolymorphicEncoder(descriptor: SerialDescriptor): CompositeEncoder { + return PolymorphicDirectEncoder(avro, currentWriterSchema, binaryEncoder) } - override fun beginCollection( + override fun getArrayEncoder( descriptor: SerialDescriptor, collectionSize: Int, ): CompositeEncoder { - return when (descriptor.kind) { - StructureKind.LIST -> - encodeResolving({ BadEncodedValueError(emptyList(), currentWriterSchema, Schema.Type.ARRAY) }) { schema -> - when (schema.type) { - Schema.Type.ARRAY -> { - { ArrayDirectEncoder(schema, collectionSize, avro, binaryEncoder) } - } - - else -> null - } - } - - StructureKind.MAP -> - encodeResolving({ BadEncodedValueError(emptyMap(), currentWriterSchema, Schema.Type.MAP) }) { schema -> - when (schema.type) { - Schema.Type.MAP -> { - { MapDirectEncoder(schema, collectionSize, avro, binaryEncoder) } - } - - else -> null - } - } - - else -> throw SerializationException("Unsupported collection kind: $descriptor") - } - } - - override fun encodeUnionIndex(index: Int) { - if (selectedUnionIndex > -1) { - throw SerializationException("Already selected union index: $selectedUnionIndex, got $index, for selected schema $currentWriterSchema") - } - if (currentWriterSchema.isUnion) { - binaryEncoder.writeIndex(index) - selectedUnionIndex = index - currentWriterSchema = currentWriterSchema.types[index] - } else { - throw SerializationException("Cannot select union index for non-union schema: $currentWriterSchema") - } + return ArrayDirectEncoder(currentWriterSchema, collectionSize, avro, binaryEncoder) } - override fun encodeElement( + override fun getMapEncoder( descriptor: SerialDescriptor, - index: Int, - ): Boolean { - selectedUnionIndex = -1 - return true - } - - override fun encodeNull() { - encodeResolving( - { BadEncodedValueError(null, currentWriterSchema, Schema.Type.NULL) } - ) { - when (it.type) { - Schema.Type.NULL -> { - { binaryEncoder.writeNull() } - } - - else -> null - } - } - } - - override fun encodeBytes(value: ByteBuffer) { - encodeResolving( - { BadEncodedValueError(value, currentWriterSchema, Schema.Type.BYTES, Schema.Type.STRING, Schema.Type.FIXED) } - ) { - when (it.type) { - Schema.Type.BYTES -> { - { binaryEncoder.writeBytes(value) } - } - - Schema.Type.STRING -> { - { binaryEncoder.writeString(Utf8(value.array())) } - } - - Schema.Type.FIXED -> { - if (value.remaining() == it.fixedSize) { - { binaryEncoder.writeFixed(value.array()) } - } else { - null - } - } - - else -> null - } - } - } - - override fun encodeBytes(value: ByteArray) { - encodeResolving( - { BadEncodedValueError(value, currentWriterSchema, Schema.Type.BYTES, Schema.Type.STRING, Schema.Type.FIXED) } - ) { - when (it.type) { - Schema.Type.BYTES -> { - { binaryEncoder.writeBytes(value) } - } - - Schema.Type.STRING -> { - { binaryEncoder.writeString(Utf8(value)) } - } - - Schema.Type.FIXED -> { - if (value.size == it.fixedSize) { - { binaryEncoder.writeFixed(value) } - } else { - null - } - } - - else -> null - } - } + collectionSize: Int, + ): CompositeEncoder { + return MapDirectEncoder(currentWriterSchema, collectionSize, avro, binaryEncoder) } - override fun encodeFixed(value: GenericFixed) { - encodeResolving( - { BadEncodedValueError(value, currentWriterSchema, Schema.Type.FIXED, Schema.Type.STRING, Schema.Type.BYTES) } - ) { - when (it.type) { - Schema.Type.FIXED -> { - if (it.fullName == value.schema.fullName && it.fixedSize == value.bytes().size) { - { binaryEncoder.writeFixed(value.bytes()) } - } else { - null - } - } - - Schema.Type.BYTES -> { - { 
binaryEncoder.writeBytes(value.bytes()) } - } - - Schema.Type.STRING -> { - { binaryEncoder.writeString(Utf8(value.bytes())) } - } - - else -> null - } - } + override fun encodeUnionIndexInternal(index: Int) { + binaryEncoder.writeIndex(index) } - override fun encodeFixed(value: ByteArray) { - encodeResolving( - { BadEncodedValueError(value, currentWriterSchema, Schema.Type.FIXED, Schema.Type.STRING, Schema.Type.BYTES) } - ) { - when (it.type) { - Schema.Type.FIXED -> - if (it.fixedSize == value.size) { - { binaryEncoder.writeFixed(value) } - } else { - null - } - - Schema.Type.BYTES -> { - { binaryEncoder.writeBytes(value) } - } - - Schema.Type.STRING -> { - { binaryEncoder.writeString(Utf8(value)) } - } - - else -> null - } - } + override fun encodeNullUnchecked() { + binaryEncoder.writeNull() } - override fun encodeEnum( - enumDescriptor: SerialDescriptor, - index: Int, - ) { - val enumName = enumDescriptor.getElementName(index) - encodeResolving( - { BadEncodedValueError(index, currentWriterSchema, Schema.Type.ENUM, Schema.Type.STRING) } - ) { - when (it.type) { - Schema.Type.ENUM -> - if (it.isFullNameOrAliasMatch(enumDescriptor)) { - { binaryEncoder.writeEnum(it.getEnumOrdinal(enumName)) } - } else { - null - } - - Schema.Type.STRING -> { - { binaryEncoder.writeString(enumName) } - } - - else -> null - } - } + override fun encodeBytesUnchecked(value: ByteArray) { + binaryEncoder.writeBytes(value) } - override fun encodeBoolean(value: Boolean) { - encodeResolving( - { BadEncodedValueError(value, currentWriterSchema, Schema.Type.BOOLEAN, Schema.Type.STRING) } - ) { - when (it.type) { - Schema.Type.BOOLEAN -> { - { binaryEncoder.writeBoolean(value) } - } - - Schema.Type.STRING -> { - { binaryEncoder.writeString(value.toString()) } - } - - else -> null - } - } + override fun encodeBooleanUnchecked(value: Boolean) { + binaryEncoder.writeBoolean(value) } - override fun encodeByte(value: Byte) { - encodeInt(value.toInt()) + override fun encodeIntUnchecked(value: Int) { + binaryEncoder.writeInt(value) } - override fun encodeShort(value: Short) { - encodeInt(value.toInt()) + override fun encodeLongUnchecked(value: Long) { + binaryEncoder.writeLong(value) } - override fun encodeInt(value: Int) { - encodeResolving( - { BadEncodedValueError(value, currentWriterSchema, Schema.Type.INT, Schema.Type.LONG, Schema.Type.FLOAT, Schema.Type.DOUBLE, Schema.Type.STRING) } - ) { - when (it.type) { - Schema.Type.INT -> { - { binaryEncoder.writeInt(value) } - } - - Schema.Type.LONG -> { - { binaryEncoder.writeLong(value.toLong()) } - } - - Schema.Type.FLOAT -> { - { binaryEncoder.writeFloat(value.toFloat()) } - } - - Schema.Type.DOUBLE -> { - { binaryEncoder.writeDouble(value.toDouble()) } - } - - Schema.Type.STRING -> { - { binaryEncoder.writeString(value.toString()) } - } - - else -> null - } - } + override fun encodeFloatUnchecked(value: Float) { + binaryEncoder.writeFloat(value) } - override fun encodeLong(value: Long) { - encodeResolving( - { BadEncodedValueError(value, currentWriterSchema, Schema.Type.LONG, Schema.Type.FLOAT, Schema.Type.DOUBLE, Schema.Type.STRING) } - ) { - when (it.type) { - Schema.Type.LONG -> { - { binaryEncoder.writeLong(value) } - } - - Schema.Type.FLOAT -> { - { binaryEncoder.writeFloat(value.toFloat()) } - } - - Schema.Type.DOUBLE -> { - { binaryEncoder.writeDouble(value.toDouble()) } - } - - Schema.Type.STRING -> { - { binaryEncoder.writeString(value.toString()) } - } - - else -> null - } - } + override fun encodeDoubleUnchecked(value: Double) { + 
binaryEncoder.writeDouble(value) } - override fun encodeFloat(value: Float) { - encodeResolving( - { BadEncodedValueError(value, currentWriterSchema, Schema.Type.FLOAT, Schema.Type.DOUBLE, Schema.Type.STRING) } - ) { - when (it.type) { - Schema.Type.FLOAT -> { - { binaryEncoder.writeFloat(value) } - } - - Schema.Type.DOUBLE -> { - { binaryEncoder.writeDouble(value.toDouble()) } - } - - Schema.Type.STRING -> { - { binaryEncoder.writeString(value.toString()) } - } - - else -> null - } - } + override fun encodeStringUnchecked(value: Utf8) { + binaryEncoder.writeString(value) } - override fun encodeDouble(value: Double) { - encodeResolving( - { BadEncodedValueError(value, currentWriterSchema, Schema.Type.DOUBLE, Schema.Type.STRING) } - ) { - when (it.type) { - Schema.Type.DOUBLE -> { - { binaryEncoder.writeDouble(value) } - } - - Schema.Type.STRING -> { - { binaryEncoder.writeString(value.toString()) } - } - - else -> null - } - } + override fun encodeEnumUnchecked(symbol: String) { + binaryEncoder.writeEnum(currentWriterSchema.getEnumOrdinalChecked(symbol)) } - override fun encodeChar(value: Char) { - encodeResolving( - { BadEncodedValueError(value, currentWriterSchema, Schema.Type.INT, Schema.Type.STRING) } - ) { - when (it.type) { - Schema.Type.INT -> { - { binaryEncoder.writeInt(value.code) } - } - - Schema.Type.STRING -> { - { binaryEncoder.writeString(value.toString()) } - } - - else -> null - } - } + override fun encodeFixedUnchecked(value: ByteArray) { + binaryEncoder.writeFixed(value) } +} - override fun encodeString(value: String) { - encodeResolving( - { BadEncodedValueError(value, currentWriterSchema, Schema.Type.STRING, Schema.Type.BYTES, Schema.Type.FIXED, Schema.Type.ENUM) } - ) { - when (it.type) { - Schema.Type.STRING -> { - { binaryEncoder.writeString(value) } - } - - Schema.Type.BYTES -> { - { binaryEncoder.writeBytes(value.encodeToByteArray()) } - } - - Schema.Type.FIXED -> { - if (value.length == it.fixedSize) { - { binaryEncoder.writeFixed(value.encodeToByteArray()) } - } else { - null - } - } - - Schema.Type.ENUM -> { - { binaryEncoder.writeEnum(it.getEnumOrdinal(value)) } - } - - else -> null - } - } +private fun Schema.getEnumOrdinalChecked(symbol: String): Int { + return try { + getEnumOrdinal(symbol) + } catch (e: NullPointerException) { + throw SerializationException("Enum symbol $symbol not found in schema $this", e) } } diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/internal/encoder/direct/RecordDirectEncoder.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/internal/encoder/direct/RecordDirectEncoder.kt index 5c0e1fc8..5c1f67e7 100644 --- a/src/main/kotlin/com/github/avrokotlin/avro4k/internal/encoder/direct/RecordDirectEncoder.kt +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/internal/encoder/direct/RecordDirectEncoder.kt @@ -1,39 +1,42 @@ package com.github.avrokotlin.avro4k.internal.encoder.direct import com.github.avrokotlin.avro4k.Avro -import com.github.avrokotlin.avro4k.AvroEncoder -import com.github.avrokotlin.avro4k.UnionEncoder -import com.github.avrokotlin.avro4k.internal.ClassDescriptorForWriterSchema -import com.github.avrokotlin.avro4k.internal.EncodingStep +import com.github.avrokotlin.avro4k.internal.EncodingWorkflow +import com.github.avrokotlin.avro4k.internal.encoder.ReorderingCompositeEncoder import kotlinx.serialization.SerializationException -import kotlinx.serialization.SerializationStrategy import kotlinx.serialization.descriptors.SerialDescriptor -import kotlinx.serialization.encoding.AbstractEncoder import 
kotlinx.serialization.encoding.CompositeEncoder -import kotlinx.serialization.modules.SerializersModule import org.apache.avro.Schema -import org.apache.avro.generic.GenericFixed -import java.nio.ByteBuffer @Suppress("FunctionName") internal fun RecordDirectEncoder( - classDescriptor: ClassDescriptorForWriterSchema, + descriptor: SerialDescriptor, schema: Schema, avro: Avro, binaryEncoder: org.apache.avro.io.Encoder, ): CompositeEncoder { - return if (classDescriptor.sequentialEncoding) { - RecordSequentialDirectEncoder(classDescriptor, schema, avro, binaryEncoder) - } else { - RecordBadOrderDirectEncoder(classDescriptor, schema, avro, binaryEncoder) + val encodingWorkflow = avro.recordResolver.resolveFields(schema, descriptor).encoding + when (encodingWorkflow) { + is EncodingWorkflow.ExactMatch -> return RecordExactDirectEncoder(schema, avro, binaryEncoder) + is EncodingWorkflow.ContiguousWithSkips -> return RecordSkippingDirectEncoder(encodingWorkflow.fieldsToSkip, schema, avro, binaryEncoder) + is EncodingWorkflow.NonContiguous -> return ReorderingCompositeEncoder( + schema.fields.size, + RecordNonContiguousDirectEncoder( + encodingWorkflow.descriptorToWriterFieldIndex, + schema, + avro, + binaryEncoder + ) + ) { _, index -> + encodingWorkflow.descriptorToWriterFieldIndex[index] + } + + is EncodingWorkflow.MissingWriterFields -> throw SerializationException("Invalid encoding workflow") } } -/** - * Consider that the descriptor elements are in the same order as the schema fields, and all the fields are represented by an element. - */ -private class RecordSequentialDirectEncoder( - private val classDescriptor: ClassDescriptorForWriterSchema, +private class RecordNonContiguousDirectEncoder( + private val descriptorToWriterFieldIndex: IntArray, private val schema: Schema, avro: Avro, binaryEncoder: org.apache.avro.io.Encoder, @@ -44,167 +47,50 @@ private class RecordSequentialDirectEncoder( descriptor: SerialDescriptor, index: Int, ): Boolean { - super.encodeElement(descriptor, index) - // index == elementIndex == writerFieldIndex, so the written field is already in the good order - return when (val step = classDescriptor.encodingSteps[index]) { - is EncodingStep.SerializeWriterField -> { - currentWriterSchema = schema.fields[step.writerFieldIndex].schema() - true - } - - is EncodingStep.IgnoreElement -> { - false - } - - is EncodingStep.MissingWriterFieldFailure -> { - throw SerializationException("No serializable element found for writer field ${step.writerFieldIndex} in schema $schema") - } - } - } - - override fun endStructure(descriptor: SerialDescriptor) { - if (classDescriptor.hasMissingWriterField) { - throw SerializationException("The descriptor is not writing all the expected fields of writer schema. Schema: $schema, descriptor: $descriptor") + val writerFieldIndex = descriptorToWriterFieldIndex[index] + if (writerFieldIndex == -1) { + return false } + super.encodeElement(descriptor, index) + currentWriterSchema = schema.fields[writerFieldIndex].schema() + return true } } -/** - * This handles the case where the descriptor elements are not in the same order as the schema fields. - * - * First we buffer all the element encodings to the corresponding field indexes, then we encode them for real in the correct order using [RecordSequentialDirectEncoder]. - * - * Not implementing [UnionEncoder] as all the encoding is delegated to the [RecordSequentialDirectEncoder] which already handles union encoding. 
- */ -private class RecordBadOrderDirectEncoder( - private val classDescriptor: ClassDescriptorForWriterSchema, +private class RecordSkippingDirectEncoder( + private val skippedElements: BooleanArray, private val schema: Schema, - private val avro: Avro, - private val binaryEncoder: org.apache.avro.io.Encoder, -) : AbstractEncoder(), AvroEncoder { - // Each time we encode a field, if the next expected schema field index is not the good one, it is buffered until it's the time to encode it - private var bufferedFields = Array(schema.fields.size) { null } - private lateinit var encodingStepToBuffer: EncodingStep.SerializeWriterField - - data class BufferedField( - val step: EncodingStep.SerializeWriterField, - val encoder: AvroEncoder.() -> Unit, - ) - - override val currentWriterSchema: Schema - get() = encodingStepToBuffer.schema - - override val serializersModule: SerializersModule - get() = avro.serializersModule + avro: Avro, + binaryEncoder: org.apache.avro.io.Encoder, +) : AbstractAvroDirectEncoder(avro, binaryEncoder) { + override lateinit var currentWriterSchema: Schema override fun encodeElement( descriptor: SerialDescriptor, index: Int, ): Boolean { - return when (val step = classDescriptor.encodingSteps[index]) { - is EncodingStep.SerializeWriterField -> { - encodingStepToBuffer = step - true - } - - is EncodingStep.IgnoreElement -> { - false - } - - is EncodingStep.MissingWriterFieldFailure -> { - throw SerializationException("No serializable element found for writer field ${step.writerFieldIndex} in schema $schema") - } - } - } - - private inline fun bufferEncoding(crossinline encoder: AvroEncoder.() -> Unit) { - bufferedFields[encodingStepToBuffer.writerFieldIndex] = BufferedField(encodingStepToBuffer) { encoder() } - } - - override fun endStructure(descriptor: SerialDescriptor) { - encodeBufferedFields(descriptor) - } - - private fun encodeBufferedFields(descriptor: SerialDescriptor) { - val recordEncoder = RecordSequentialDirectEncoder(classDescriptor, schema, avro, binaryEncoder) - bufferedFields.forEach { fieldToEncode -> - if (fieldToEncode == null) { - throw SerializationException("The writer field is missing in the buffered fields, it hasn't been encoded yet") - } - // To simulate the behavior of regular element encoding - // We don't use the return of encodeElement because we know it's always true - recordEncoder.encodeElement(descriptor, fieldToEncode.step.elementIndex) - fieldToEncode.encoder(recordEncoder) + if (skippedElements[index]) { + return false } + super.encodeElement(descriptor, index) + currentWriterSchema = schema.fields[index].schema() + return true } +} - override fun encodeSerializableValue( - serializer: SerializationStrategy, - value: T, - ) { - bufferEncoding { encodeSerializableValue(serializer, value) } - } - - override fun encodeNull() { - bufferEncoding { encodeNull() } - } - - override fun encodeBytes(value: ByteArray) { - bufferEncoding { encodeBytes(value) } - } - - override fun encodeBytes(value: ByteBuffer) { - bufferEncoding { encodeBytes(value) } - } - - override fun encodeFixed(value: GenericFixed) { - bufferEncoding { encodeFixed(value) } - } - - override fun encodeFixed(value: ByteArray) { - bufferEncoding { encodeFixed(value) } - } - - override fun encodeBoolean(value: Boolean) { - bufferEncoding { encodeBoolean(value) } - } - - override fun encodeByte(value: Byte) { - bufferEncoding { encodeByte(value) } - } - - override fun encodeShort(value: Short) { - bufferEncoding { encodeShort(value) } - } - - override fun encodeInt(value: Int) 
{ - bufferEncoding { encodeInt(value) } - } - - override fun encodeLong(value: Long) { - bufferEncoding { encodeLong(value) } - } - - override fun encodeFloat(value: Float) { - bufferEncoding { encodeFloat(value) } - } - - override fun encodeDouble(value: Double) { - bufferEncoding { encodeDouble(value) } - } - - override fun encodeChar(value: Char) { - bufferEncoding { encodeChar(value) } - } - - override fun encodeString(value: String) { - bufferEncoding { encodeString(value) } - } +private class RecordExactDirectEncoder( + private val schema: Schema, + avro: Avro, + binaryEncoder: org.apache.avro.io.Encoder, +) : AbstractAvroDirectEncoder(avro, binaryEncoder) { + override lateinit var currentWriterSchema: Schema - override fun encodeEnum( - enumDescriptor: SerialDescriptor, + override fun encodeElement( + descriptor: SerialDescriptor, index: Int, - ) { - bufferEncoding { encodeEnum(enumDescriptor, index) } + ): Boolean { + super.encodeElement(descriptor, index) + currentWriterSchema = schema.fields[index].schema() + return true } } \ No newline at end of file diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/internal/encoder/generic/AbstractAvroGenericEncoder.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/internal/encoder/generic/AbstractAvroGenericEncoder.kt index a064891b..06a5f416 100644 --- a/src/main/kotlin/com/github/avrokotlin/avro4k/internal/encoder/generic/AbstractAvroGenericEncoder.kt +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/internal/encoder/generic/AbstractAvroGenericEncoder.kt @@ -1,440 +1,81 @@ package com.github.avrokotlin.avro4k.internal.encoder.generic import com.github.avrokotlin.avro4k.Avro -import com.github.avrokotlin.avro4k.AvroEncoder -import com.github.avrokotlin.avro4k.UnionEncoder -import com.github.avrokotlin.avro4k.encodeResolving -import com.github.avrokotlin.avro4k.internal.BadEncodedValueError -import com.github.avrokotlin.avro4k.internal.SerializerLocatorMiddleware -import com.github.avrokotlin.avro4k.internal.isFullNameOrAliasMatch -import com.github.avrokotlin.avro4k.internal.toIntExact -import kotlinx.serialization.SerializationException -import kotlinx.serialization.SerializationStrategy -import kotlinx.serialization.descriptors.PolymorphicKind +import com.github.avrokotlin.avro4k.internal.encoder.AbstractAvroEncoder import kotlinx.serialization.descriptors.SerialDescriptor -import kotlinx.serialization.descriptors.StructureKind -import kotlinx.serialization.encoding.AbstractEncoder import kotlinx.serialization.encoding.CompositeEncoder import kotlinx.serialization.modules.SerializersModule -import org.apache.avro.Schema import org.apache.avro.generic.GenericData -import org.apache.avro.generic.GenericFixed +import org.apache.avro.util.Utf8 import java.nio.ByteBuffer -internal abstract class AbstractAvroGenericEncoder : AbstractEncoder(), AvroEncoder, UnionEncoder { +internal abstract class AbstractAvroGenericEncoder : AbstractAvroEncoder() { abstract val avro: Avro - abstract override var currentWriterSchema: Schema - abstract override fun encodeValue(value: Any) - abstract override fun encodeNull() - - override fun encodeElement( - descriptor: SerialDescriptor, - index: Int, - ): Boolean { - selectedUnionIndex = -1 - return true - } - - private var selectedUnionIndex: Int = -1 - - override fun encodeUnionIndex(index: Int) { - if (selectedUnionIndex > -1) { - throw SerializationException("Already selected union index: $selectedUnionIndex, got $index, for selected schema $currentWriterSchema") - } - if (currentWriterSchema.isUnion) 
{ - selectedUnionIndex = index - currentWriterSchema = currentWriterSchema.types[index] - } else { - throw SerializationException("Cannot select union index for non-union schema: $currentWriterSchema") - } - } - override val serializersModule: SerializersModule get() = avro.serializersModule - override fun encodeSerializableValue( - serializer: SerializationStrategy, - value: T, - ) { - SerializerLocatorMiddleware.apply(serializer) - .serialize(this, value) + override fun getRecordEncoder(descriptor: SerialDescriptor): CompositeEncoder { + return RecordGenericEncoder(avro, descriptor, currentWriterSchema) { encodeValue(it) } } - override fun beginStructure(descriptor: SerialDescriptor): CompositeEncoder { - return when (descriptor.kind) { - StructureKind.CLASS, - StructureKind.OBJECT, - -> - encodeResolving( - { BadEncodedValueError(null, currentWriterSchema, Schema.Type.RECORD) } - ) { schema -> - if (schema.type == Schema.Type.RECORD && schema.isFullNameOrAliasMatch(descriptor)) { - { RecordGenericEncoder(avro, descriptor, schema) { encodeValue(it) } } - } else { - null - } - } - - is PolymorphicKind -> - PolymorphicEncoder(avro, currentWriterSchema) { - encodeValue(it) - } - - else -> throw SerializationException("Unsupported structure kind: $descriptor") - } + override fun getPolymorphicEncoder(descriptor: SerialDescriptor): CompositeEncoder { + return PolymorphicEncoder(avro, currentWriterSchema) { encodeValue(it) } } - override fun beginCollection( + override fun getArrayEncoder( descriptor: SerialDescriptor, collectionSize: Int, ): CompositeEncoder { - return when (descriptor.kind) { - StructureKind.LIST -> - encodeResolving( - { BadEncodedValueError(emptyList(), currentWriterSchema, Schema.Type.ARRAY, Schema.Type.BYTES, Schema.Type.FIXED) } - ) { schema -> - when (schema.type) { - Schema.Type.ARRAY -> { - { ArrayGenericEncoder(avro, collectionSize, schema) { encodeValue(it) } } - } - - Schema.Type.BYTES -> { - { BytesGenericEncoder(avro, collectionSize) { encodeValue(it) } } - } - - Schema.Type.FIXED -> { - { FixedGenericEncoder(avro, collectionSize, schema) { encodeValue(it) } } - } - - else -> null - } - } - - StructureKind.MAP -> - encodeResolving( - { BadEncodedValueError(emptyMap(), currentWriterSchema, Schema.Type.MAP) } - ) { schema -> - when (schema.type) { - Schema.Type.MAP -> { - { MapGenericEncoder(avro, collectionSize, schema) { encodeValue(it) } } - } - - else -> null - } - } - - else -> throw SerializationException("Unsupported collection kind: $descriptor") - } + return ArrayGenericEncoder(avro, collectionSize, currentWriterSchema) { encodeValue(it) } } - override fun encodeBytes(value: ByteBuffer) { - encodeResolving( - { BadEncodedValueError(value, currentWriterSchema, Schema.Type.STRING, Schema.Type.BYTES, Schema.Type.FIXED) } - ) { schema -> - when (schema.type) { - Schema.Type.BYTES -> { - { encodeValue(value) } - } - - Schema.Type.FIXED -> { - if (value.remaining() == schema.fixedSize) { - { encodeValue(value.array()) } - } else { - null - } - } - - Schema.Type.STRING -> { - { encodeValue(value.array().decodeToString()) } - } - - else -> null - } - } - } - - override fun encodeBytes(value: ByteArray) { - encodeResolving( - { BadEncodedValueError(value, currentWriterSchema, Schema.Type.STRING, Schema.Type.BYTES, Schema.Type.FIXED) } - ) { schema -> - when (schema.type) { - Schema.Type.BYTES -> { - { encodeValue(ByteBuffer.wrap(value)) } - } - - Schema.Type.FIXED -> { - if (value.size == schema.fixedSize) { - { encodeValue(value) } - } else { - null - } - } - - 
Schema.Type.STRING -> { - { encodeValue(value.decodeToString()) } - } - - else -> null - } - } - } - - override fun encodeFixed(value: GenericFixed) { - encodeResolving( - { BadEncodedValueError(value, currentWriterSchema, Schema.Type.STRING, Schema.Type.BYTES, Schema.Type.FIXED) } - ) { schema -> - when (schema.type) { - Schema.Type.FIXED -> - if (schema.fullName == value.schema.fullName && schema.fixedSize == value.bytes().size) { - { encodeValue(value) } - } else { - null - } - - Schema.Type.BYTES -> { - { encodeValue(ByteBuffer.wrap(value.bytes())) } - } - - Schema.Type.STRING -> { - { encodeValue(value.bytes().decodeToString()) } - } - - else -> null - } - } - } - - override fun encodeFixed(value: ByteArray) { - encodeResolving( - { BadEncodedValueError(value, currentWriterSchema, Schema.Type.STRING, Schema.Type.BYTES, Schema.Type.FIXED) } - ) { schema -> - when (schema.type) { - Schema.Type.FIXED -> { - if (value.size == schema.fixedSize) { - { encodeValue(value) } - } else { - null - } - } - - Schema.Type.BYTES -> { - { encodeValue(ByteBuffer.wrap(value)) } - } - - Schema.Type.STRING -> { - { encodeValue(value.decodeToString()) } - } - - else -> null - } - } + override fun getMapEncoder( + descriptor: SerialDescriptor, + collectionSize: Int, + ): CompositeEncoder { + return MapGenericEncoder(avro, collectionSize, currentWriterSchema) { encodeValue(it) } } - override fun encodeBoolean(value: Boolean) { - encodeResolving( - { BadEncodedValueError(value, currentWriterSchema, Schema.Type.BOOLEAN, Schema.Type.STRING) } - ) { schema -> - when (schema.type) { - Schema.Type.BOOLEAN -> { - { encodeValue(value) } - } - - Schema.Type.STRING -> { - { encodeValue(value.toString()) } - } - - else -> null - } - } + override fun encodeBytesUnchecked(value: ByteArray) { + encodeValue(ByteBuffer.wrap(value)) } - override fun encodeByte(value: Byte) { - encodeInt(value.toInt()) + override fun encodeBooleanUnchecked(value: Boolean) { + encodeValue(value) } - override fun encodeShort(value: Short) { - encodeInt(value.toInt()) + override fun encodeStringUnchecked(value: Utf8) { + encodeValue(value) } - override fun encodeInt(value: Int) { - encodeResolving( - { BadEncodedValueError(value, currentWriterSchema, Schema.Type.LONG, Schema.Type.INT, Schema.Type.FLOAT, Schema.Type.DOUBLE, Schema.Type.STRING) } - ) { schema -> - when (schema.type) { - Schema.Type.INT -> { - { encodeValue(value) } - } - - Schema.Type.LONG -> { - { encodeValue(value.toLong()) } - } - - Schema.Type.FLOAT -> { - { encodeValue(value.toFloat()) } - } - - Schema.Type.DOUBLE -> { - { encodeValue(value.toDouble()) } - } - - Schema.Type.STRING -> { - { encodeValue(value.toString()) } - } - - else -> null - } - } + override fun encodeUnionIndexInternal(index: Int) { + // nothing to do } - override fun encodeLong(value: Long) { - encodeResolving( - { BadEncodedValueError(value, currentWriterSchema, Schema.Type.LONG, Schema.Type.INT, Schema.Type.FLOAT, Schema.Type.DOUBLE, Schema.Type.STRING) } - ) { schema -> - when (schema.type) { - Schema.Type.LONG -> { - { encodeValue(value) } - } - - Schema.Type.INT -> { - { encodeValue(value.toIntExact()) } - } - - Schema.Type.FLOAT -> { - { encodeValue(value.toFloat()) } - } - - Schema.Type.DOUBLE -> { - { encodeValue(value.toDouble()) } - } - - Schema.Type.STRING -> { - { encodeValue(value.toString()) } - } - - else -> null - } - } + override fun encodeFixedUnchecked(value: ByteArray) { + encodeValue(GenericData.Fixed(currentWriterSchema, value)) } - override fun encodeFloat(value: Float) { - 
encodeResolving( - { BadEncodedValueError(value, currentWriterSchema, Schema.Type.STRING, Schema.Type.DOUBLE, Schema.Type.FLOAT) } - ) { schema -> - when (schema.type) { - Schema.Type.FLOAT -> { - { encodeValue(value) } - } - - Schema.Type.DOUBLE -> { - { encodeValue(value.toDouble()) } - } - - Schema.Type.STRING -> { - { encodeValue(value.toString()) } - } - - else -> null - } - } + override fun encodeIntUnchecked(value: Int) { + encodeValue(value) } - override fun encodeDouble(value: Double) { - encodeResolving( - { BadEncodedValueError(value, currentWriterSchema, Schema.Type.STRING, Schema.Type.DOUBLE) } - ) { schema -> - when (schema.type) { - Schema.Type.DOUBLE -> { - { encodeValue(value) } - } - - Schema.Type.STRING -> { - { encodeValue(value.toString()) } - } - - else -> null - } - } + override fun encodeLongUnchecked(value: Long) { + encodeValue(value) } - override fun encodeChar(value: Char) { - encodeResolving( - { BadEncodedValueError(value, currentWriterSchema, Schema.Type.INT, Schema.Type.STRING) } - ) { schema -> - when (schema.type) { - Schema.Type.INT -> { - { encodeValue(value.code) } - } - - Schema.Type.STRING -> { - { encodeValue(value.toString()) } - } - - else -> null - } - } + override fun encodeFloatUnchecked(value: Float) { + encodeValue(value) } - override fun encodeString(value: String) { - encodeResolving( - { BadEncodedValueError(value, currentWriterSchema, Schema.Type.STRING, Schema.Type.BYTES, Schema.Type.FIXED, Schema.Type.ENUM) } - ) { schema -> - when (schema.type) { - Schema.Type.STRING -> { - { encodeValue(value) } - } - - Schema.Type.BYTES -> { - { encodeValue(value.encodeToByteArray()) } - } - - Schema.Type.FIXED -> { - if (value.length == schema.fixedSize) { - { encodeValue(value.encodeToByteArray()) } - } else { - null - } - } - - Schema.Type.ENUM -> { - { encodeValue(GenericData.EnumSymbol(schema, value)) } - } - - else -> null - } - } + override fun encodeDoubleUnchecked(value: Double) { + encodeValue(value) } - override fun encodeEnum( - enumDescriptor: SerialDescriptor, - index: Int, - ) { - /* - We allow enums as ENUM (must match the descriptor's full name), STRING or UNION. - For UNION, we look for an enum with the descriptor's full name, otherwise a string. 
- */ - val value = enumDescriptor.getElementName(index) - - encodeResolving( - { BadEncodedValueError(value, currentWriterSchema, Schema.Type.STRING, Schema.Type.ENUM) } - ) { schema -> - when (schema.type) { - Schema.Type.STRING -> { - { encodeValue(value) } - } - - Schema.Type.ENUM -> { - if (schema.isFullNameOrAliasMatch(enumDescriptor)) { - { encodeValue(GenericData.EnumSymbol(schema, value)) } - } else { - null - } - } - - else -> null - } - } + override fun encodeEnumUnchecked(symbol: String) { + encodeValue(GenericData.EnumSymbol(currentWriterSchema, symbol)) } } \ No newline at end of file diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/internal/encoder/generic/ArrayGenericEncoder.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/internal/encoder/generic/ArrayGenericEncoder.kt index 925d252b..b00a241d 100644 --- a/src/main/kotlin/com/github/avrokotlin/avro4k/internal/encoder/generic/ArrayGenericEncoder.kt +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/internal/encoder/generic/ArrayGenericEncoder.kt @@ -1,8 +1,6 @@ package com.github.avrokotlin.avro4k.internal.encoder.generic import com.github.avrokotlin.avro4k.Avro -import com.github.avrokotlin.avro4k.encodeResolving -import com.github.avrokotlin.avro4k.internal.BadEncodedValueError import kotlinx.serialization.descriptors.SerialDescriptor import org.apache.avro.Schema import org.apache.avro.generic.GenericArray @@ -36,17 +34,7 @@ internal class ArrayGenericEncoder( values[index++] = value } - override fun encodeNull() { - encodeResolving( - { BadEncodedValueError(null, currentWriterSchema, Schema.Type.NULL) } - ) { - when (it.type) { - Schema.Type.NULL -> { - { values[index++] = null } - } - - else -> null - } - } + override fun encodeNullUnchecked() { + values[index++] = null } } \ No newline at end of file diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/internal/encoder/generic/AvroValueGenericEncoder.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/internal/encoder/generic/AvroValueGenericEncoder.kt index eea26709..20901c98 100644 --- a/src/main/kotlin/com/github/avrokotlin/avro4k/internal/encoder/generic/AvroValueGenericEncoder.kt +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/internal/encoder/generic/AvroValueGenericEncoder.kt @@ -1,8 +1,6 @@ package com.github.avrokotlin.avro4k.internal.encoder.generic import com.github.avrokotlin.avro4k.Avro -import com.github.avrokotlin.avro4k.encodeResolving -import com.github.avrokotlin.avro4k.internal.BadEncodedValueError import org.apache.avro.Schema internal class AvroValueGenericEncoder( @@ -14,16 +12,7 @@ internal class AvroValueGenericEncoder( onEncoded(value) } - override fun encodeNull() { - encodeResolving( - { BadEncodedValueError(null, currentWriterSchema, Schema.Type.NULL) } - ) { - when (it.type) { - Schema.Type.NULL -> { - { onEncoded(null) } - } - else -> null - } - } + override fun encodeNullUnchecked() { + onEncoded(null) } } \ No newline at end of file diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/internal/encoder/generic/BytesGenericEncoder.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/internal/encoder/generic/BytesGenericEncoder.kt deleted file mode 100644 index 8a0f9161..00000000 --- a/src/main/kotlin/com/github/avrokotlin/avro4k/internal/encoder/generic/BytesGenericEncoder.kt +++ /dev/null @@ -1,26 +0,0 @@ -package com.github.avrokotlin.avro4k.internal.encoder.generic - -import com.github.avrokotlin.avro4k.Avro -import kotlinx.serialization.descriptors.SerialDescriptor -import 
kotlinx.serialization.encoding.AbstractEncoder -import kotlinx.serialization.modules.SerializersModule -import java.nio.ByteBuffer - -internal class BytesGenericEncoder( - private val avro: Avro, - arraySize: Int, - private val onEncoded: (ByteBuffer) -> Unit, -) : AbstractEncoder() { - private val output: ByteBuffer = ByteBuffer.allocate(arraySize) - - override val serializersModule: SerializersModule - get() = avro.serializersModule - - override fun endStructure(descriptor: SerialDescriptor) { - onEncoded(output.rewind()) - } - - override fun encodeByte(value: Byte) { - output.put(value) - } -} \ No newline at end of file diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/internal/encoder/generic/FixedGenericEncoder.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/internal/encoder/generic/FixedGenericEncoder.kt deleted file mode 100644 index e6f02b1c..00000000 --- a/src/main/kotlin/com/github/avrokotlin/avro4k/internal/encoder/generic/FixedGenericEncoder.kt +++ /dev/null @@ -1,37 +0,0 @@ -package com.github.avrokotlin.avro4k.internal.encoder.generic - -import com.github.avrokotlin.avro4k.Avro -import kotlinx.serialization.SerializationException -import kotlinx.serialization.descriptors.SerialDescriptor -import kotlinx.serialization.encoding.AbstractEncoder -import kotlinx.serialization.modules.SerializersModule -import org.apache.avro.Schema -import org.apache.avro.generic.GenericData -import org.apache.avro.generic.GenericFixed - -internal class FixedGenericEncoder( - private val avro: Avro, - arraySize: Int, - private val schema: Schema, - private val onEncoded: (GenericFixed) -> Unit, -) : AbstractEncoder() { - private val buffer = ByteArray(schema.fixedSize) - private var pos = 0 - - init { - if (arraySize != schema.fixedSize) { - throw SerializationException("Actual collection size $arraySize is greater than schema fixed size $schema") - } - } - - override val serializersModule: SerializersModule - get() = avro.serializersModule - - override fun endStructure(descriptor: SerialDescriptor) { - onEncoded(GenericData.Fixed(schema, buffer)) - } - - override fun encodeByte(value: Byte) { - buffer[pos++] = value - } -} \ No newline at end of file diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/internal/encoder/generic/MapGenericEncoder.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/internal/encoder/generic/MapGenericEncoder.kt index c5f98f46..3954d49a 100644 --- a/src/main/kotlin/com/github/avrokotlin/avro4k/internal/encoder/generic/MapGenericEncoder.kt +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/internal/encoder/generic/MapGenericEncoder.kt @@ -46,7 +46,7 @@ internal class MapGenericEncoder( } } - override fun encodeNull() { + override fun encodeNullUnchecked() { val key = currentKey ?: throw SerializationException("Map key cannot be null") entries.add(key to null) } diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/internal/encoder/generic/RecordGenericEncoder.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/internal/encoder/generic/RecordGenericEncoder.kt index 1ae05fa7..5f7706ad 100644 --- a/src/main/kotlin/com/github/avrokotlin/avro4k/internal/encoder/generic/RecordGenericEncoder.kt +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/internal/encoder/generic/RecordGenericEncoder.kt @@ -2,7 +2,7 @@ package com.github.avrokotlin.avro4k.internal.encoder.generic import com.github.avrokotlin.avro4k.Avro import com.github.avrokotlin.avro4k.ListRecord -import com.github.avrokotlin.avro4k.internal.EncodingStep +import 
com.github.avrokotlin.avro4k.internal.EncodingWorkflow import kotlinx.serialization.SerializationException import kotlinx.serialization.descriptors.SerialDescriptor import org.apache.avro.Schema @@ -16,7 +16,7 @@ internal class RecordGenericEncoder( ) : AbstractAvroGenericEncoder() { private val fieldValues: Array = Array(schema.fields.size) { null } - private val classDescriptor = avro.recordResolver.resolveFields(schema, descriptor) + private val encodingWorkflow = avro.recordResolver.resolveFields(schema, descriptor).encoding private lateinit var currentField: Schema.Field override lateinit var currentWriterSchema: Schema @@ -26,22 +26,31 @@ internal class RecordGenericEncoder( index: Int, ): Boolean { super.encodeElement(descriptor, index) - return when (val step = classDescriptor.encodingSteps[index]) { - is EncodingStep.SerializeWriterField -> { - val field = schema.fields[step.writerFieldIndex] - currentField = field - currentWriterSchema = field.schema() - true - } + val writerFieldIndex = + when (encodingWorkflow) { + EncodingWorkflow.ExactMatch -> index - is EncodingStep.IgnoreElement -> { - false - } + is EncodingWorkflow.ContiguousWithSkips -> { + if (encodingWorkflow.fieldsToSkip[index]) { + return false + } + index + } + + is EncodingWorkflow.NonContiguous -> { + val writerFieldIndex = encodingWorkflow.descriptorToWriterFieldIndex[index] + if (writerFieldIndex == -1) { + return false + } + writerFieldIndex + } - is EncodingStep.MissingWriterFieldFailure -> { - throw SerializationException("No serializable element found for writer field ${step.writerFieldIndex} in schema $schema") + is EncodingWorkflow.MissingWriterFields -> throw SerializationException("Invalid encoding workflow") } - } + val field = schema.fields[writerFieldIndex] + currentField = field + currentWriterSchema = field.schema() + return true } override fun endStructure(descriptor: SerialDescriptor) { @@ -52,7 +61,7 @@ internal class RecordGenericEncoder( fieldValues[currentField.pos()] = value } - override fun encodeNull() { + override fun encodeNullUnchecked() { fieldValues[currentField.pos()] = null } } \ No newline at end of file diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/internal/exceptions.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/internal/exceptions.kt index 1455fd3d..aac2364c 100644 --- a/src/main/kotlin/com/github/avrokotlin/avro4k/internal/exceptions.kt +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/internal/exceptions.kt @@ -7,7 +7,6 @@ import kotlinx.serialization.SerializationException import kotlinx.serialization.descriptors.SerialDescriptor import kotlinx.serialization.descriptors.SerialKind import kotlinx.serialization.encoding.Decoder -import kotlinx.serialization.encoding.Encoder import org.apache.avro.Schema import kotlin.reflect.KClass @@ -73,21 +72,4 @@ internal fun AvroDecoder.UnexpectedDecodeSchemaError( return SerializationException( "For $actualType, expected type one of $allExpectedTypes, but had writer schema $currentWriterSchema" ) -} - -context(Encoder) -internal fun BadEncodedValueError( - value: Any?, - writerSchema: Schema, - firstExpectedType: Schema.Type, - vararg expectedTypes: Schema.Type, -): SerializationException { - val allExpectedTypes = listOf(firstExpectedType) + expectedTypes - return if (value == null) { - SerializationException("Encoded null value, expected one of $allExpectedTypes, actual writer schema $writerSchema") - } else { - SerializationException( - "Encoded value '$value' of type ${value::class.qualifiedName}, expected one of 
$allExpectedTypes, actual writer schema $writerSchema" - ) - } } \ No newline at end of file diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/internal/helpers.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/internal/helpers.kt index 3ecfdd3b..079cc23a 100644 --- a/src/main/kotlin/com/github/avrokotlin/avro4k/internal/helpers.kt +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/internal/helpers.kt @@ -52,7 +52,14 @@ internal fun Schema.isNamedSchema(): Boolean { } internal fun Schema.isFullNameOrAliasMatch(descriptor: SerialDescriptor): Boolean { - return isFullNameMatch(descriptor.nonNullSerialName) || descriptor.aliases.any { isFullNameMatch(it) } + return isFullNameOrAliasMatch(descriptor.nonNullSerialName, descriptor::aliases) +} + +internal fun Schema.isFullNameOrAliasMatch( + fullName: String, + aliases: () -> Set, +): Boolean { + return isFullNameMatch(fullName) || aliases().any { isFullNameMatch(it) } } internal fun Schema.isFullNameMatch(fullNameToMatch: String): Boolean { diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/serializer/AvroDuration.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/serializer/AvroDuration.kt index 21c8e433..7761a3ca 100644 --- a/src/main/kotlin/com/github/avrokotlin/avro4k/serializer/AvroDuration.kt +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/serializer/AvroDuration.kt @@ -4,9 +4,16 @@ import com.github.avrokotlin.avro4k.AnyValueDecoder import com.github.avrokotlin.avro4k.AvroDecoder import com.github.avrokotlin.avro4k.AvroEncoder import com.github.avrokotlin.avro4k.decodeResolvingAny -import com.github.avrokotlin.avro4k.encodeResolving -import com.github.avrokotlin.avro4k.internal.BadEncodedValueError +import com.github.avrokotlin.avro4k.ensureFixedSize +import com.github.avrokotlin.avro4k.fullNameOrAliasMismatchError import com.github.avrokotlin.avro4k.internal.UnexpectedDecodeSchemaError +import com.github.avrokotlin.avro4k.internal.isFullNameOrAliasMatch +import com.github.avrokotlin.avro4k.trySelectLogicalTypeFromUnion +import com.github.avrokotlin.avro4k.trySelectNamedSchema +import com.github.avrokotlin.avro4k.trySelectSingleNonNullTypeFromUnion +import com.github.avrokotlin.avro4k.trySelectTypeFromUnion +import com.github.avrokotlin.avro4k.typeNotFoundInUnionError +import com.github.avrokotlin.avro4k.unsupportedWriterTypeError import kotlinx.serialization.ExperimentalSerializationApi import kotlinx.serialization.Serializable import kotlinx.serialization.SerializationException @@ -116,8 +123,9 @@ public class AvroDurationParseException(value: String) : SerializationException( internal object AvroDurationSerializer : AvroSerializer(AvroDuration::class.qualifiedName!!) 
{ private const val LOGICAL_TYPE_NAME = "duration" private const val DURATION_BYTES = 12 + private const val DEFAULT_DURATION_FULL_NAME = "time.Duration" internal val DURATION_SCHEMA = - Schema.createFixed("time.Duration", "A 12-byte byte array encoding a duration in months, days and milliseconds.", null, DURATION_BYTES).also { + Schema.createFixed(DEFAULT_DURATION_FULL_NAME, "A 12-byte byte array encoding a duration in months, days and milliseconds.", null, DURATION_BYTES).also { LogicalType(LOGICAL_TYPE_NAME).addToSchema(it) } @@ -132,21 +140,22 @@ internal object AvroDurationSerializer : AvroSerializer(AvroDurati value: AvroDuration, ) { with(encoder) { - encodeResolving({ BadEncodedValueError(value, currentWriterSchema, Schema.Type.FIXED, Schema.Type.STRING) }) { - when (it.type) { - Schema.Type.FIXED -> - if (it.logicalType?.name == LOGICAL_TYPE_NAME && it.fixedSize == DURATION_BYTES) { - { encodeFixed(encodeDuration(value)) } - } else { - null - } - - Schema.Type.STRING -> { - { encoder.encodeString(value.toString()) } + if (currentWriterSchema.isUnion && !trySelectSingleNonNullTypeFromUnion()) { + trySelectNamedSchema(DEFAULT_DURATION_FULL_NAME) || + trySelectLogicalTypeFromUnion(LOGICAL_TYPE_NAME, Schema.Type.FIXED) || + trySelectTypeFromUnion(Schema.Type.STRING) || + throw typeNotFoundInUnionError(Schema.Type.FIXED, Schema.Type.STRING) + } + when (currentWriterSchema.type) { + Schema.Type.FIXED -> + if (currentWriterSchema.logicalType?.name == LOGICAL_TYPE_NAME || currentWriterSchema.isFullNameOrAliasMatch(DEFAULT_DURATION_FULL_NAME, ::emptySet)) { + encodeFixed(ensureFixedSize(encodeDuration(value))) + } else { + throw fullNameOrAliasMismatchError(DEFAULT_DURATION_FULL_NAME, emptySet()) } - else -> null - } + Schema.Type.STRING -> encodeString(value.toString()) + else -> throw unsupportedWriterTypeError(Schema.Type.FIXED, Schema.Type.STRING) } } } diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/serializer/JavaStdLibSerializers.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/serializer/JavaStdLibSerializers.kt index 5eea0d1d..f4cf796c 100644 --- a/src/main/kotlin/com/github/avrokotlin/avro4k/serializer/JavaStdLibSerializers.kt +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/serializer/JavaStdLibSerializers.kt @@ -5,11 +5,14 @@ import com.github.avrokotlin.avro4k.AvroDecimal import com.github.avrokotlin.avro4k.AvroDecoder import com.github.avrokotlin.avro4k.AvroEncoder import com.github.avrokotlin.avro4k.decodeResolvingAny -import com.github.avrokotlin.avro4k.encodeResolving import com.github.avrokotlin.avro4k.internal.AvroSchemaGenerationException -import com.github.avrokotlin.avro4k.internal.BadEncodedValueError import com.github.avrokotlin.avro4k.internal.UnexpectedDecodeSchemaError import com.github.avrokotlin.avro4k.internal.copy +import com.github.avrokotlin.avro4k.trySelectLogicalTypeFromUnion +import com.github.avrokotlin.avro4k.trySelectSingleNonNullTypeFromUnion +import com.github.avrokotlin.avro4k.trySelectTypeFromUnion +import com.github.avrokotlin.avro4k.typeNotFoundInUnionError +import com.github.avrokotlin.avro4k.unsupportedWriterTypeError import kotlinx.serialization.KSerializer import kotlinx.serialization.descriptors.PrimitiveKind import kotlinx.serialization.descriptors.PrimitiveSerialDescriptor @@ -91,41 +94,18 @@ public object BigIntegerSerializer : AvroSerializer(BigInteger::clas encoder: AvroEncoder, value: BigInteger, ) { - encoder.encodeResolving({ - with(encoder) { - BadEncodedValueError( - value, - encoder.currentWriterSchema, - 
Schema.Type.STRING, - Schema.Type.INT, - Schema.Type.LONG, - Schema.Type.FLOAT, - Schema.Type.DOUBLE - ) + with(encoder) { + if (currentWriterSchema.isUnion && !trySelectSingleNonNullTypeFromUnion()) { + trySelectTypeFromUnion(Schema.Type.STRING, Schema.Type.INT, Schema.Type.LONG, Schema.Type.FLOAT, Schema.Type.DOUBLE) || + throw typeNotFoundInUnionError(Schema.Type.STRING, Schema.Type.INT, Schema.Type.LONG, Schema.Type.FLOAT, Schema.Type.DOUBLE) } - }) { schema -> - when (schema.type) { - Schema.Type.STRING -> { - { encoder.encodeString(value.toString()) } - } - - Schema.Type.INT -> { - { encoder.encodeInt(value.intValueExact()) } - } - - Schema.Type.LONG -> { - { encoder.encodeLong(value.longValueExact()) } - } - - Schema.Type.FLOAT -> { - { encoder.encodeFloat(value.toFloat()) } - } - - Schema.Type.DOUBLE -> { - { encoder.encodeDouble(value.toDouble()) } - } - - else -> null + when (currentWriterSchema.type) { + Schema.Type.STRING -> encodeString(value.toString()) + Schema.Type.INT -> encodeInt(value.intValueExact()) + Schema.Type.LONG -> encodeLong(value.longValueExact()) + Schema.Type.FLOAT -> encodeFloat(value.toFloat()) + Schema.Type.DOUBLE -> encodeDouble(value.toDouble()) + else -> throw unsupportedWriterTypeError(Schema.Type.STRING, Schema.Type.INT, Schema.Type.LONG, Schema.Type.FLOAT, Schema.Type.DOUBLE) } } } @@ -202,11 +182,29 @@ public object BigDecimalSerializer : AvroSerializer(BigDecimal::clas encoder: AvroEncoder, value: BigDecimal, ) { - encoder.encodeResolving({ - with(encoder) { - BadEncodedValueError( - value, - encoder.currentWriterSchema, + with(encoder) { + if (currentWriterSchema.isUnion && !trySelectSingleNonNullTypeFromUnion()) { + trySelectLogicalTypeFromUnion(converter.logicalTypeName, Schema.Type.BYTES, Schema.Type.FIXED) || + trySelectTypeFromUnion(Schema.Type.STRING, Schema.Type.INT, Schema.Type.LONG, Schema.Type.FLOAT, Schema.Type.DOUBLE) || + throw typeNotFoundInUnionError( + Schema.Type.BYTES, + Schema.Type.FIXED, + Schema.Type.STRING, + Schema.Type.INT, + Schema.Type.LONG, + Schema.Type.FLOAT, + Schema.Type.DOUBLE + ) + } + when (currentWriterSchema.type) { + Schema.Type.BYTES -> encodeBytes(converter.toBytes(value, currentWriterSchema, currentWriterSchema.logicalType).array()) + Schema.Type.FIXED -> encodeFixed(converter.toFixed(value, currentWriterSchema, currentWriterSchema.logicalType).bytes()) + Schema.Type.STRING -> encodeString(value.toString()) + Schema.Type.INT -> encodeInt(value.intValueExact()) + Schema.Type.LONG -> encodeLong(value.longValueExact()) + Schema.Type.FLOAT -> encodeFloat(value.toFloat()) + Schema.Type.DOUBLE -> encodeDouble(value.toDouble()) + else -> throw unsupportedWriterTypeError( Schema.Type.BYTES, Schema.Type.FIXED, Schema.Type.STRING, @@ -216,48 +214,6 @@ public object BigDecimalSerializer : AvroSerializer(BigDecimal::clas Schema.Type.DOUBLE ) } - }) { schema -> - when (schema.type) { - Schema.Type.BYTES -> - when (schema.logicalType) { - is LogicalTypes.Decimal -> { - { encoder.encodeBytes(converter.toBytes(value, schema, schema.logicalType)) } - } - - else -> null - } - - Schema.Type.FIXED -> - when (schema.logicalType) { - is LogicalTypes.Decimal -> { - { encoder.encodeFixed(converter.toFixed(value, schema, schema.logicalType)) } - } - - else -> null - } - - Schema.Type.STRING -> { - { encoder.encodeString(value.toString()) } - } - - Schema.Type.INT -> { - { encoder.encodeInt(value.intValueExact()) } - } - - Schema.Type.LONG -> { - { encoder.encodeLong(value.longValueExact()) } - } - - Schema.Type.FLOAT -> { - { 
encoder.encodeFloat(value.toFloat()) } - } - - Schema.Type.DOUBLE -> { - { encoder.encodeDouble(value.toDouble()) } - } - - else -> null - } } } diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/serializer/JavaTimeSerializers.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/serializer/JavaTimeSerializers.kt index 9d7d711e..4d369a45 100644 --- a/src/main/kotlin/com/github/avrokotlin/avro4k/serializer/JavaTimeSerializers.kt +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/serializer/JavaTimeSerializers.kt @@ -4,10 +4,13 @@ import com.github.avrokotlin.avro4k.AnyValueDecoder import com.github.avrokotlin.avro4k.AvroDecoder import com.github.avrokotlin.avro4k.AvroEncoder import com.github.avrokotlin.avro4k.decodeResolvingAny -import com.github.avrokotlin.avro4k.encodeResolving -import com.github.avrokotlin.avro4k.internal.BadEncodedValueError import com.github.avrokotlin.avro4k.internal.UnexpectedDecodeSchemaError import com.github.avrokotlin.avro4k.internal.copy +import com.github.avrokotlin.avro4k.logicalTypeMismatchError +import com.github.avrokotlin.avro4k.trySelectSingleNonNullTypeFromUnion +import com.github.avrokotlin.avro4k.trySelectTypeFromUnion +import com.github.avrokotlin.avro4k.typeNotFoundInUnionError +import com.github.avrokotlin.avro4k.unsupportedWriterTypeError import kotlinx.serialization.SerializationException import kotlinx.serialization.encoding.Decoder import kotlinx.serialization.encoding.Encoder @@ -51,36 +54,20 @@ public object LocalDateSerializer : AvroSerializer(LocalDate::class.q encoder: AvroEncoder, value: LocalDate, ) { - encoder.encodeResolving({ - with(encoder) { - BadEncodedValueError(value, encoder.currentWriterSchema, Schema.Type.INT, Schema.Type.LONG) + with(encoder) { + if (currentWriterSchema.isUnion && !trySelectSingleNonNullTypeFromUnion()) { + trySelectTypeFromUnion(Schema.Type.INT, Schema.Type.STRING) || + throw typeNotFoundInUnionError(Schema.Type.INT, Schema.Type.STRING) } - }) { schema -> - when (schema.type) { + when (currentWriterSchema.type) { Schema.Type.INT -> - when (schema.logicalType?.name) { - LOGICAL_TYPE_NAME_DATE, null -> { - { encoder.encodeInt(value.toEpochDay().toInt()) } - } - - else -> null - } - - Schema.Type.LONG -> - when (schema.logicalType) { - // Date is not compatible with LONG, so we require a null logical type to encode the timestamp - null -> { - { encoder.encodeLong(value.toEpochDay()) } - } - - else -> null + when (currentWriterSchema.logicalType?.name) { + LOGICAL_TYPE_NAME_DATE -> encodeInt(value.toEpochDay().toInt()) + else -> throw logicalTypeMismatchError(LOGICAL_TYPE_NAME_DATE, Schema.Type.INT) } - Schema.Type.STRING -> { - { encoder.encodeString(value.toString()) } - } - - else -> null + Schema.Type.STRING -> encodeString(value.toString()) + else -> throw unsupportedWriterTypeError(Schema.Type.INT, Schema.Type.STRING) } } } @@ -148,39 +135,25 @@ public object LocalTimeSerializer : AvroSerializer(LocalTime::class.q value: LocalTime, ) { with(encoder) { - encodeResolving({ - BadEncodedValueError(value, encoder.currentWriterSchema, Schema.Type.INT, Schema.Type.LONG, Schema.Type.STRING) - }) { schema -> - when (schema.type) { - Schema.Type.INT -> - when (schema.logicalType?.name) { - LOGICAL_TYPE_NAME_TIME_MILLIS, null -> { - { encoder.encodeInt(value.toMillisOfDay()) } - } - - else -> null - } - - Schema.Type.LONG -> - when (schema.logicalType?.name) { - // TimeMillis is not compatible with LONG, so we require a null logical type to encode the timestamp - null -> { - { 
encoder.encodeLong(value.toMillisOfDay().toLong()) } - } - - LOGICAL_TYPE_NAME_TIME_MICROS -> { - { encoder.encodeLong(value.toMicroOfDay()) } - } - - else -> null - } + if (currentWriterSchema.isUnion && !trySelectSingleNonNullTypeFromUnion()) { + trySelectTypeFromUnion(Schema.Type.INT, Schema.Type.LONG, Schema.Type.STRING) || + throw typeNotFoundInUnionError(Schema.Type.INT, Schema.Type.LONG, Schema.Type.STRING) + } + when (currentWriterSchema.type) { + Schema.Type.INT -> + when (currentWriterSchema.logicalType?.name) { + LOGICAL_TYPE_NAME_TIME_MILLIS -> encodeInt(value.toMillisOfDay()) + else -> throw logicalTypeMismatchError(LOGICAL_TYPE_NAME_TIME_MILLIS, Schema.Type.INT) + } - Schema.Type.STRING -> { - { encoder.encodeString(value.toString()) } + Schema.Type.LONG -> + when (currentWriterSchema.logicalType?.name) { + LOGICAL_TYPE_NAME_TIME_MICROS -> encodeLong(value.toMicroOfDay()) + else -> throw logicalTypeMismatchError(LOGICAL_TYPE_NAME_TIME_MICROS, Schema.Type.LONG) } - else -> null - } + Schema.Type.STRING -> encodeString(value.toString()) + else -> throw unsupportedWriterTypeError(Schema.Type.INT, Schema.Type.LONG, Schema.Type.STRING) } } } @@ -257,30 +230,21 @@ public object LocalDateTimeSerializer : AvroSerializer(LocalDateT encoder: AvroEncoder, value: LocalDateTime, ) { - encoder.encodeResolving({ - with(encoder) { - BadEncodedValueError(value, encoder.currentWriterSchema, Schema.Type.LONG, Schema.Type.STRING) + with(encoder) { + if (currentWriterSchema.isUnion && !trySelectSingleNonNullTypeFromUnion()) { + trySelectTypeFromUnion(Schema.Type.LONG, Schema.Type.STRING) || + throw typeNotFoundInUnionError(Schema.Type.LONG, Schema.Type.STRING) } - }) { - when (it.type) { + when (currentWriterSchema.type) { Schema.Type.LONG -> - when (it.logicalType?.name) { - LOGICAL_TYPE_NAME_TIMESTAMP_MILLIS, null -> { - { encoder.encodeLong(value.toInstant(ZoneOffset.UTC).toEpochMilli()) } - } - - LOGICAL_TYPE_NAME_TIMESTAMP_MICROS -> { - { encoder.encodeLong(value.toInstant(ZoneOffset.UTC).toEpochMicros()) } - } - - else -> null + when (currentWriterSchema.logicalType?.name) { + LOGICAL_TYPE_NAME_TIMESTAMP_MICROS -> encodeLong(value.toInstant(ZoneOffset.UTC).toEpochMicros()) + LOGICAL_TYPE_NAME_TIMESTAMP_MILLIS -> encodeLong(value.toInstant(ZoneOffset.UTC).toEpochMilli()) + else -> throw logicalTypeMismatchError(LOGICAL_TYPE_NAME_TIMESTAMP_MILLIS, Schema.Type.LONG) } - Schema.Type.STRING -> { - { encoder.encodeString(value.toString()) } - } - - else -> null + Schema.Type.STRING -> encodeString(value.toString()) + else -> throw unsupportedWriterTypeError(Schema.Type.LONG, Schema.Type.STRING) } } } @@ -335,30 +299,21 @@ public object InstantSerializer : AvroSerializer(Instant::class.qualifi encoder: AvroEncoder, value: Instant, ) { - encoder.encodeResolving({ - with(encoder) { - BadEncodedValueError(value, encoder.currentWriterSchema, Schema.Type.LONG, Schema.Type.STRING) + with(encoder) { + if (currentWriterSchema.isUnion && !trySelectSingleNonNullTypeFromUnion()) { + trySelectTypeFromUnion(Schema.Type.LONG, Schema.Type.STRING) || + throw typeNotFoundInUnionError(Schema.Type.LONG, Schema.Type.STRING) } - }) { - when (it.type) { + when (currentWriterSchema.type) { Schema.Type.LONG -> - when (it.logicalType?.name) { - LOGICAL_TYPE_NAME_TIMESTAMP_MILLIS, null -> { - { encoder.encodeLong(value.toEpochMilli()) } - } - - LOGICAL_TYPE_NAME_TIMESTAMP_MICROS -> { - { encoder.encodeLong(value.toEpochMicros()) } - } - - else -> null + when (currentWriterSchema.logicalType?.name) { + 
LOGICAL_TYPE_NAME_TIMESTAMP_MICROS -> encodeLong(value.toEpochMicros()) + LOGICAL_TYPE_NAME_TIMESTAMP_MILLIS -> encodeLong(value.toEpochMilli()) + else -> throw logicalTypeMismatchError(LOGICAL_TYPE_NAME_TIMESTAMP_MILLIS, Schema.Type.LONG) } - Schema.Type.STRING -> { - { encoder.encodeString(value.toString()) } - } - - else -> null + Schema.Type.STRING -> encodeString(value.toString()) + else -> throw unsupportedWriterTypeError(Schema.Type.LONG, Schema.Type.STRING) } } } @@ -412,30 +367,21 @@ public object InstantToMicroSerializer : AvroSerializer(Instant::class. encoder: AvroEncoder, value: Instant, ) { - encoder.encodeResolving({ - with(encoder) { - BadEncodedValueError(value, encoder.currentWriterSchema, Schema.Type.LONG, Schema.Type.STRING) + with(encoder) { + if (currentWriterSchema.isUnion && !trySelectSingleNonNullTypeFromUnion()) { + trySelectTypeFromUnion(Schema.Type.LONG, Schema.Type.STRING) || + throw typeNotFoundInUnionError(Schema.Type.LONG, Schema.Type.STRING) } - }) { - when (it.type) { + when (currentWriterSchema.type) { Schema.Type.LONG -> - when (it.logicalType?.name) { - LOGICAL_TYPE_NAME_TIMESTAMP_MICROS, null -> { - { encoder.encodeLong(value.toEpochMicros()) } - } - - LOGICAL_TYPE_NAME_TIMESTAMP_MILLIS -> { - { encoder.encodeLong(value.toEpochMilli()) } - } - - else -> null + when (currentWriterSchema.logicalType?.name) { + LOGICAL_TYPE_NAME_TIMESTAMP_MICROS -> encodeLong(value.toEpochMicros()) + LOGICAL_TYPE_NAME_TIMESTAMP_MILLIS -> encodeLong(value.toEpochMilli()) + else -> throw logicalTypeMismatchError(LOGICAL_TYPE_NAME_TIMESTAMP_MICROS, Schema.Type.LONG) } - Schema.Type.STRING -> { - { encoder.encodeString(value.toString()) } - } - - else -> null + Schema.Type.STRING -> encodeString(value.toString()) + else -> throw unsupportedWriterTypeError(Schema.Type.LONG, Schema.Type.STRING) } } } From 7a1c8cd7d36c5a363cec06d06474feade403a856 Mon Sep 17 00:00:00 2001 From: Chuckame Date: Sun, 15 Sep 2024 10:06:31 +0200 Subject: [PATCH 03/13] refactor: rework direct decoding for more clear & compact resolving unions --- .../github/avrokotlin/avro4k/AvroDecoder.kt | 11 + .../direct/AbstractAvroDirectDecoder.kt | 413 +++++------------- .../encoder/direct/RecordDirectEncoder.kt | 43 +- .../generic/AbstractAvroGenericEncoder.kt | 2 +- .../encoder/generic/RecordGenericEncoder.kt | 102 +++-- .../avrokotlin/avro4k/internal/helpers.kt | 6 +- 6 files changed, 219 insertions(+), 358 deletions(-) diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/AvroDecoder.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/AvroDecoder.kt index 569bf6a4..7305c4eb 100644 --- a/src/main/kotlin/com/github/avrokotlin/avro4k/AvroDecoder.kt +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/AvroDecoder.kt @@ -1,6 +1,7 @@ package com.github.avrokotlin.avro4k import kotlinx.serialization.ExperimentalSerializationApi +import kotlinx.serialization.SerializationException import kotlinx.serialization.encoding.Decoder import org.apache.avro.Schema import org.apache.avro.generic.GenericFixed @@ -317,4 +318,14 @@ internal inline fun AvroDecoder.findValueDecoder( resolver(schema) } return foundResolver ?: throw error() +} + +internal fun AvroDecoder.unsupportedWriterTypeError( + mainType: Schema.Type, + vararg fallbackTypes: Schema.Type, +): Throwable { + val fallbacksStr = if (fallbackTypes.isNotEmpty()) ", and also not matching to any compatible type (one of ${fallbackTypes.joinToString()})." 
else "" + return SerializationException( + "Unsupported schema '${currentWriterSchema.fullName}' for decoded type of ${mainType.getName()}$fallbacksStr. Actual schema: $currentWriterSchema" + ) } \ No newline at end of file diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/internal/decoder/direct/AbstractAvroDirectDecoder.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/internal/decoder/direct/AbstractAvroDirectDecoder.kt index c8fdde60..31f1b6ff 100644 --- a/src/main/kotlin/com/github/avrokotlin/avro4k/internal/decoder/direct/AbstractAvroDirectDecoder.kt +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/internal/decoder/direct/AbstractAvroDirectDecoder.kt @@ -1,31 +1,14 @@ package com.github.avrokotlin.avro4k.internal.decoder.direct -import com.github.avrokotlin.avro4k.AnyValueDecoder import com.github.avrokotlin.avro4k.Avro -import com.github.avrokotlin.avro4k.BooleanValueDecoder -import com.github.avrokotlin.avro4k.CharValueDecoder -import com.github.avrokotlin.avro4k.DoubleValueDecoder -import com.github.avrokotlin.avro4k.FloatValueDecoder -import com.github.avrokotlin.avro4k.IntValueDecoder -import com.github.avrokotlin.avro4k.LongValueDecoder import com.github.avrokotlin.avro4k.UnionDecoder -import com.github.avrokotlin.avro4k.decodeResolvingAny -import com.github.avrokotlin.avro4k.decodeResolvingBoolean -import com.github.avrokotlin.avro4k.decodeResolvingChar -import com.github.avrokotlin.avro4k.decodeResolvingDouble -import com.github.avrokotlin.avro4k.decodeResolvingFloat -import com.github.avrokotlin.avro4k.decodeResolvingInt -import com.github.avrokotlin.avro4k.decodeResolvingLong import com.github.avrokotlin.avro4k.internal.SerializerLocatorMiddleware -import com.github.avrokotlin.avro4k.internal.UnexpectedDecodeSchemaError import com.github.avrokotlin.avro4k.internal.decoder.AbstractPolymorphicDecoder -import com.github.avrokotlin.avro4k.internal.getElementIndexNullable import com.github.avrokotlin.avro4k.internal.isFullNameOrAliasMatch -import com.github.avrokotlin.avro4k.internal.nonNullSerialName import com.github.avrokotlin.avro4k.internal.toByteExact -import com.github.avrokotlin.avro4k.internal.toFloatExact import com.github.avrokotlin.avro4k.internal.toIntExact import com.github.avrokotlin.avro4k.internal.toShortExact +import com.github.avrokotlin.avro4k.unsupportedWriterTypeError import kotlinx.serialization.DeserializationStrategy import kotlinx.serialization.SerializationException import kotlinx.serialization.descriptors.PolymorphicKind @@ -59,38 +42,39 @@ internal abstract class AbstractAvroDirectDecoder( } override fun beginStructure(descriptor: SerialDescriptor): CompositeDecoder { + decodeAndResolveUnion() + return when (descriptor.kind) { StructureKind.LIST -> - decodeResolvingAny({ UnexpectedDecodeSchemaError(descriptor.nonNullSerialName, Schema.Type.ARRAY) }) { - when (it.type) { - Schema.Type.ARRAY -> { - AnyValueDecoder { ArrayBlockDirectDecoder(it, decodeFirstBlock = decodedCollectionSize == -1, { decodedCollectionSize = it }, avro, binaryDecoder) } - } - - else -> null - } + when (currentWriterSchema.type) { + Schema.Type.ARRAY -> + ArrayBlockDirectDecoder( + currentWriterSchema, + decodeFirstBlock = decodedCollectionSize == -1, + { decodedCollectionSize = it }, + avro, + binaryDecoder + ) + else -> throw unsupportedWriterTypeError(Schema.Type.ARRAY) } StructureKind.MAP -> - decodeResolvingAny({ UnexpectedDecodeSchemaError(descriptor.nonNullSerialName, Schema.Type.MAP) }) { - when (it.type) { - Schema.Type.MAP -> { - AnyValueDecoder { 
MapBlockDirectDecoder(it, decodeFirstBlock = decodedCollectionSize == -1, { decodedCollectionSize = it }, avro, binaryDecoder) } - } - - else -> null - } + when (currentWriterSchema.type) { + Schema.Type.MAP -> + MapBlockDirectDecoder( + currentWriterSchema, + decodeFirstBlock = decodedCollectionSize == -1, + { decodedCollectionSize = it }, + avro, + binaryDecoder + ) + else -> throw unsupportedWriterTypeError(Schema.Type.MAP) } StructureKind.CLASS, StructureKind.OBJECT -> - decodeResolvingAny({ UnexpectedDecodeSchemaError(descriptor.nonNullSerialName, Schema.Type.RECORD) }) { - when (it.type) { - Schema.Type.RECORD -> { - AnyValueDecoder { RecordDirectDecoder(it, descriptor, avro, binaryDecoder) } - } - - else -> null - } + when (currentWriterSchema.type) { + Schema.Type.RECORD -> RecordDirectDecoder(currentWriterSchema, descriptor, avro, binaryDecoder) + else -> throw unsupportedWriterTypeError(Schema.Type.RECORD) } is PolymorphicKind -> PolymorphicDecoder(avro, descriptor, currentWriterSchema, binaryDecoder) @@ -106,46 +90,27 @@ internal abstract class AbstractAvroDirectDecoder( override fun decodeNotNullMark(): Boolean { decodeAndResolveUnion() + return currentWriterSchema.type != Schema.Type.NULL } override fun decodeNull(): Nothing? { - decodeResolvingAny({ - UnexpectedDecodeSchemaError( - "null", - Schema.Type.NULL - ) - }) { - when (it.type) { - Schema.Type.NULL -> { - AnyValueDecoder { binaryDecoder.readNull() } - } + decodeAndResolveUnion() - else -> null - } + if (currentWriterSchema.type != Schema.Type.NULL) { + throw unsupportedWriterTypeError(Schema.Type.NULL) } + binaryDecoder.readNull() return null } override fun decodeBoolean(): Boolean { - return decodeResolvingBoolean({ - UnexpectedDecodeSchemaError( - "boolean", - Schema.Type.BOOLEAN, - Schema.Type.STRING - ) - }) { - when (it.type) { - Schema.Type.BOOLEAN -> { - BooleanValueDecoder { binaryDecoder.readBoolean() } - } - - Schema.Type.STRING -> { - BooleanValueDecoder { binaryDecoder.readString().toBooleanStrict() } - } + decodeAndResolveUnion() - else -> null - } + return when (currentWriterSchema.type) { + Schema.Type.BOOLEAN -> binaryDecoder.readBoolean() + Schema.Type.STRING -> binaryDecoder.readString().toBooleanStrict() + else -> throw unsupportedWriterTypeError(Schema.Type.BOOLEAN, Schema.Type.STRING) } } @@ -158,284 +123,124 @@ internal abstract class AbstractAvroDirectDecoder( } override fun decodeInt(): Int { - return decodeResolvingInt({ - UnexpectedDecodeSchemaError( - "int", - Schema.Type.INT, - Schema.Type.LONG, - Schema.Type.FLOAT, - Schema.Type.DOUBLE, - Schema.Type.STRING - ) - }) { - when (it.type) { - Schema.Type.INT -> { - IntValueDecoder { binaryDecoder.readInt() } - } - - Schema.Type.LONG -> { - IntValueDecoder { binaryDecoder.readLong().toIntExact() } - } - - Schema.Type.FLOAT -> { - IntValueDecoder { binaryDecoder.readDouble().toInt() } - } - - Schema.Type.DOUBLE -> { - IntValueDecoder { binaryDecoder.readDouble().toInt() } - } - - Schema.Type.STRING -> { - IntValueDecoder { binaryDecoder.readString().toInt() } - } + decodeAndResolveUnion() - else -> null - } + return when (currentWriterSchema.type) { + Schema.Type.INT -> binaryDecoder.readInt() + Schema.Type.LONG -> binaryDecoder.readLong().toIntExact() + Schema.Type.STRING -> binaryDecoder.readString().toInt() + else -> throw unsupportedWriterTypeError(Schema.Type.INT, Schema.Type.LONG, Schema.Type.STRING) } } override fun decodeLong(): Long { - return decodeResolvingLong({ - UnexpectedDecodeSchemaError( - "long", - Schema.Type.INT, - 
Schema.Type.LONG, - Schema.Type.FLOAT, - Schema.Type.DOUBLE, - Schema.Type.STRING - ) - }) { - when (it.type) { - Schema.Type.INT -> { - LongValueDecoder { binaryDecoder.readInt().toLong() } - } - - Schema.Type.LONG -> { - LongValueDecoder { binaryDecoder.readLong() } - } - - Schema.Type.FLOAT -> { - LongValueDecoder { binaryDecoder.readFloat().toLong() } - } - - Schema.Type.DOUBLE -> { - LongValueDecoder { binaryDecoder.readDouble().toLong() } - } - - Schema.Type.STRING -> { - LongValueDecoder { binaryDecoder.readString().toLong() } - } + decodeAndResolveUnion() - else -> null - } + return when (currentWriterSchema.type) { + Schema.Type.INT -> binaryDecoder.readInt().toLong() + Schema.Type.LONG -> binaryDecoder.readLong() + Schema.Type.STRING -> binaryDecoder.readString().toLong() + else -> throw unsupportedWriterTypeError(Schema.Type.LONG, Schema.Type.INT, Schema.Type.STRING) } } override fun decodeFloat(): Float { - return decodeResolvingFloat({ - UnexpectedDecodeSchemaError( - "float", - Schema.Type.INT, - Schema.Type.LONG, - Schema.Type.FLOAT, - Schema.Type.DOUBLE, - Schema.Type.STRING - ) - }) { - when (it.type) { - Schema.Type.INT -> { - FloatValueDecoder { binaryDecoder.readInt().toFloat() } - } - - Schema.Type.LONG -> { - FloatValueDecoder { binaryDecoder.readLong().toFloat() } - } - - Schema.Type.FLOAT -> { - FloatValueDecoder { binaryDecoder.readFloat() } - } - - Schema.Type.DOUBLE -> { - FloatValueDecoder { binaryDecoder.readDouble().toFloatExact() } - } - - Schema.Type.STRING -> { - FloatValueDecoder { binaryDecoder.readString().toFloat() } - } + decodeAndResolveUnion() - else -> null - } + return when (currentWriterSchema.type) { + Schema.Type.INT -> binaryDecoder.readInt().toFloat() + Schema.Type.LONG -> binaryDecoder.readLong().toFloat() + Schema.Type.FLOAT -> binaryDecoder.readFloat() + Schema.Type.STRING -> binaryDecoder.readString().toFloat() + else -> throw unsupportedWriterTypeError(Schema.Type.FLOAT, Schema.Type.INT, Schema.Type.LONG, Schema.Type.STRING) } } override fun decodeDouble(): Double { - return decodeResolvingDouble({ - UnexpectedDecodeSchemaError( - "double", - Schema.Type.INT, - Schema.Type.LONG, - Schema.Type.FLOAT, - Schema.Type.DOUBLE, - Schema.Type.STRING - ) - }) { - when (it.type) { - Schema.Type.INT -> { - DoubleValueDecoder { binaryDecoder.readInt().toDouble() } - } - - Schema.Type.LONG -> { - DoubleValueDecoder { binaryDecoder.readLong().toDouble() } - } - - Schema.Type.FLOAT -> { - DoubleValueDecoder { binaryDecoder.readFloat().toDouble() } - } - - Schema.Type.DOUBLE -> { - DoubleValueDecoder { binaryDecoder.readDouble() } - } - - Schema.Type.STRING -> { - DoubleValueDecoder { binaryDecoder.readString().toDouble() } - } + decodeAndResolveUnion() - else -> null - } + return when (currentWriterSchema.type) { + Schema.Type.INT -> binaryDecoder.readInt().toDouble() + Schema.Type.LONG -> binaryDecoder.readLong().toDouble() + Schema.Type.FLOAT -> binaryDecoder.readFloat().toDouble() + Schema.Type.DOUBLE -> binaryDecoder.readDouble() + Schema.Type.STRING -> binaryDecoder.readString().toDouble() + else -> throw unsupportedWriterTypeError(Schema.Type.DOUBLE, Schema.Type.INT, Schema.Type.LONG, Schema.Type.FLOAT, Schema.Type.STRING) } } override fun decodeChar(): Char { - return decodeResolvingChar({ - UnexpectedDecodeSchemaError( - "char", - Schema.Type.INT, - Schema.Type.STRING - ) - }) { - when (it.type) { - Schema.Type.INT -> { - CharValueDecoder { binaryDecoder.readInt().toChar() } - } - - Schema.Type.STRING -> { - CharValueDecoder { 
binaryDecoder.readString(null).single() } - } + decodeAndResolveUnion() - else -> null - } + return when (currentWriterSchema.type) { + Schema.Type.INT -> binaryDecoder.readInt().toChar() + Schema.Type.STRING -> binaryDecoder.readString(null).single() + else -> throw unsupportedWriterTypeError(Schema.Type.INT, Schema.Type.STRING) } } override fun decodeString(): String { - return decodeResolvingAny({ - UnexpectedDecodeSchemaError( - "string", - Schema.Type.STRING, - Schema.Type.BYTES, - Schema.Type.FIXED - ) - }) { - when (it.type) { - Schema.Type.STRING, - Schema.Type.BYTES, - -> { - AnyValueDecoder { binaryDecoder.readString() } - } - - Schema.Type.FIXED -> { - AnyValueDecoder { ByteArray(it.fixedSize).also { buf -> binaryDecoder.readFixed(buf) }.decodeToString() } - } + decodeAndResolveUnion() - else -> null - } + return when (currentWriterSchema.type) { + Schema.Type.STRING -> binaryDecoder.readString(null).toString() + Schema.Type.BYTES -> binaryDecoder.readBytes(null).array().decodeToString() + Schema.Type.FIXED -> ByteArray(currentWriterSchema.fixedSize).also { buf -> binaryDecoder.readFixed(buf) }.decodeToString() + else -> throw unsupportedWriterTypeError(Schema.Type.STRING, Schema.Type.BYTES, Schema.Type.FIXED) } } override fun decodeEnum(enumDescriptor: SerialDescriptor): Int { - return decodeResolvingInt({ - UnexpectedDecodeSchemaError( - enumDescriptor.nonNullSerialName, - Schema.Type.ENUM, - Schema.Type.STRING - ) - }) { - when (it.type) { - Schema.Type.ENUM -> - if (it.isFullNameOrAliasMatch(enumDescriptor)) { - IntValueDecoder { - val enumName = it.enumSymbols[binaryDecoder.readEnum()] - enumDescriptor.getElementIndexNullable(enumName) - ?: avro.enumResolver.getDefaultValueIndex(enumDescriptor) - ?: throw SerializationException( - "Unknown enum symbol name '$enumName' for Enum '${enumDescriptor.serialName}' for writer schema $currentWriterSchema" - ) - } - } else { - null - } + decodeAndResolveUnion() - Schema.Type.STRING -> { - IntValueDecoder { - val enumSymbol = binaryDecoder.readString() - enumDescriptor.getElementIndex(enumSymbol) - .takeIf { index -> index >= 0 } - ?: avro.enumResolver.getDefaultValueIndex(enumDescriptor) - ?: throw SerializationException("Unknown enum symbol '$enumSymbol' for Enum '${enumDescriptor.serialName}'") + return when (currentWriterSchema.type) { + Schema.Type.ENUM -> + if (currentWriterSchema.isFullNameOrAliasMatch(enumDescriptor)) { + val enumName = currentWriterSchema.enumSymbols[binaryDecoder.readEnum()] + val idx = enumDescriptor.getElementIndex(enumName) + if (idx >= 0) { + idx + } else { + avro.enumResolver.getDefaultValueIndex(enumDescriptor) + ?: throw SerializationException("Unknown enum symbol name '$enumName' for Enum '${enumDescriptor.serialName}' for writer schema $currentWriterSchema") } + } else { + throw unsupportedWriterTypeError(Schema.Type.ENUM, Schema.Type.STRING) } - else -> null + Schema.Type.STRING -> { + val enumSymbol = binaryDecoder.readString() + val idx = enumDescriptor.getElementIndex(enumSymbol) + if (idx >= 0) { + idx + } else { + avro.enumResolver.getDefaultValueIndex(enumDescriptor) + ?: throw SerializationException("Unknown enum symbol '$enumSymbol' for Enum '${enumDescriptor.serialName}' for writer schema $currentWriterSchema") + } } + + else -> throw unsupportedWriterTypeError(Schema.Type.ENUM, Schema.Type.STRING) } } override fun decodeBytes(): ByteArray { - return decodeResolvingAny({ - UnexpectedDecodeSchemaError( - "ByteArray", - Schema.Type.BYTES, - Schema.Type.FIXED, - Schema.Type.STRING - ) - }) { 
- when (it.type) { - Schema.Type.BYTES -> { - AnyValueDecoder { binaryDecoder.readBytes(null).array() } - } - - Schema.Type.FIXED -> { - AnyValueDecoder { ByteArray(it.fixedSize).also { buf -> binaryDecoder.readFixed(buf) } } - } - - Schema.Type.STRING -> { - AnyValueDecoder { binaryDecoder.readString(null).bytes } - } + decodeAndResolveUnion() - else -> null - } + return when (currentWriterSchema.type) { + Schema.Type.BYTES -> binaryDecoder.readBytes(null).array() + Schema.Type.FIXED -> ByteArray(currentWriterSchema.fixedSize).also { buf -> binaryDecoder.readFixed(buf) } + Schema.Type.STRING -> binaryDecoder.readString(null).bytes + else -> throw unsupportedWriterTypeError(Schema.Type.BYTES, Schema.Type.FIXED, Schema.Type.STRING) } } override fun decodeFixed(): GenericFixed { - return decodeResolvingAny({ - UnexpectedDecodeSchemaError( - "GenericFixed", - Schema.Type.BYTES, - Schema.Type.FIXED - ) - }) { - when (it.type) { - Schema.Type.BYTES -> { - AnyValueDecoder { GenericData.Fixed(it, binaryDecoder.readBytes(null).array()) } - } - - Schema.Type.FIXED -> { - AnyValueDecoder { GenericData.Fixed(it, ByteArray(it.fixedSize).also { buf -> binaryDecoder.readFixed(buf) }) } - } + decodeAndResolveUnion() - else -> null - } + return when (currentWriterSchema.type) { + Schema.Type.BYTES -> GenericData.Fixed(currentWriterSchema, binaryDecoder.readBytes(null).array()) + Schema.Type.FIXED -> GenericData.Fixed(currentWriterSchema, ByteArray(currentWriterSchema.fixedSize).also { buf -> binaryDecoder.readFixed(buf) }) + else -> throw unsupportedWriterTypeError(Schema.Type.BYTES, Schema.Type.FIXED) } } } diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/internal/encoder/direct/RecordDirectEncoder.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/internal/encoder/direct/RecordDirectEncoder.kt index 5c1f67e7..086830ac 100644 --- a/src/main/kotlin/com/github/avrokotlin/avro4k/internal/encoder/direct/RecordDirectEncoder.kt +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/internal/encoder/direct/RecordDirectEncoder.kt @@ -15,27 +15,29 @@ internal fun RecordDirectEncoder( avro: Avro, binaryEncoder: org.apache.avro.io.Encoder, ): CompositeEncoder { - val encodingWorkflow = avro.recordResolver.resolveFields(schema, descriptor).encoding - when (encodingWorkflow) { - is EncodingWorkflow.ExactMatch -> return RecordExactDirectEncoder(schema, avro, binaryEncoder) - is EncodingWorkflow.ContiguousWithSkips -> return RecordSkippingDirectEncoder(encodingWorkflow.fieldsToSkip, schema, avro, binaryEncoder) - is EncodingWorkflow.NonContiguous -> return ReorderingCompositeEncoder( - schema.fields.size, - RecordNonContiguousDirectEncoder( - encodingWorkflow.descriptorToWriterFieldIndex, - schema, - avro, - binaryEncoder - ) - ) { _, index -> - encodingWorkflow.descriptorToWriterFieldIndex[index] - } + return when (val encodingWorkflow = avro.recordResolver.resolveFields(schema, descriptor).encoding) { + is EncodingWorkflow.ExactMatch -> RecordContiguousExactEncoder(schema, avro, binaryEncoder) + is EncodingWorkflow.ContiguousWithSkips -> RecordContiguousSkippingEncoder(encodingWorkflow.fieldsToSkip, schema, avro, binaryEncoder) + is EncodingWorkflow.NonContiguous -> + ReorderingCompositeEncoder( + schema.fields.size, + RecordNonContiguousEncoder( + encodingWorkflow.descriptorToWriterFieldIndex, + schema, + avro, + binaryEncoder + ) + ) { _, index -> + encodingWorkflow.descriptorToWriterFieldIndex[index] + } - is EncodingWorkflow.MissingWriterFields -> throw SerializationException("Invalid encoding workflow") + is 
EncodingWorkflow.MissingWriterFields -> throw SerializationException( + "Missing writer fields ${schema.fields.filter { it.pos() in encodingWorkflow.missingWriterFields }} from the descriptor $descriptor" + ) } } -private class RecordNonContiguousDirectEncoder( +private class RecordNonContiguousEncoder( private val descriptorToWriterFieldIndex: IntArray, private val schema: Schema, avro: Avro, @@ -57,12 +59,13 @@ private class RecordNonContiguousDirectEncoder( } } -private class RecordSkippingDirectEncoder( +private class RecordContiguousSkippingEncoder( private val skippedElements: BooleanArray, private val schema: Schema, avro: Avro, binaryEncoder: org.apache.avro.io.Encoder, ) : AbstractAvroDirectEncoder(avro, binaryEncoder) { + private var nextWriterFieldIndex = 0 override lateinit var currentWriterSchema: Schema override fun encodeElement( @@ -73,12 +76,12 @@ private class RecordSkippingDirectEncoder( return false } super.encodeElement(descriptor, index) - currentWriterSchema = schema.fields[index].schema() + currentWriterSchema = schema.fields[nextWriterFieldIndex++].schema() return true } } -private class RecordExactDirectEncoder( +private class RecordContiguousExactEncoder( private val schema: Schema, avro: Avro, binaryEncoder: org.apache.avro.io.Encoder, diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/internal/encoder/generic/AbstractAvroGenericEncoder.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/internal/encoder/generic/AbstractAvroGenericEncoder.kt index 06a5f416..a10eab41 100644 --- a/src/main/kotlin/com/github/avrokotlin/avro4k/internal/encoder/generic/AbstractAvroGenericEncoder.kt +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/internal/encoder/generic/AbstractAvroGenericEncoder.kt @@ -18,7 +18,7 @@ internal abstract class AbstractAvroGenericEncoder : AbstractAvroEncoder() { get() = avro.serializersModule override fun getRecordEncoder(descriptor: SerialDescriptor): CompositeEncoder { - return RecordGenericEncoder(avro, descriptor, currentWriterSchema) { encodeValue(it) } + return RecordGenericEncoder(descriptor, currentWriterSchema, avro) { encodeValue(it) } } override fun getPolymorphicEncoder(descriptor: SerialDescriptor): CompositeEncoder { diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/internal/encoder/generic/RecordGenericEncoder.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/internal/encoder/generic/RecordGenericEncoder.kt index 5f7706ad..9d44d596 100644 --- a/src/main/kotlin/com/github/avrokotlin/avro4k/internal/encoder/generic/RecordGenericEncoder.kt +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/internal/encoder/generic/RecordGenericEncoder.kt @@ -5,52 +5,98 @@ import com.github.avrokotlin.avro4k.ListRecord import com.github.avrokotlin.avro4k.internal.EncodingWorkflow import kotlinx.serialization.SerializationException import kotlinx.serialization.descriptors.SerialDescriptor +import kotlinx.serialization.encoding.CompositeeEncoder import org.apache.avro.Schema import org.apache.avro.generic.GenericRecord -internal class RecordGenericEncoder( - override val avro: Avro, +@Suppress("FunctionName") +internal fun RecordGenericEncoder( descriptor: SerialDescriptor, - private val schema: Schema, - private val onEncoded: (GenericRecord) -> Unit, -) : AbstractAvroGenericEncoder() { - private val fieldValues: Array<Any?> = Array(schema.fields.size) { null } + schema: Schema, + avro: Avro, + onEncoded: (GenericRecord) -> Unit, +): CompositeEncoder { + return when (val encodingWorkflow = avro.recordResolver.resolveFields(schema, 
descriptor).encoding) { + is EncodingWorkflow.ExactMatch -> RecordContiguousExactEncoder(schema, avro, onEncoded) + is EncodingWorkflow.ContiguousWithSkips -> RecordContiguousSkippingEncoder(encodingWorkflow.fieldsToSkip, schema, avro, onEncoded) + is EncodingWorkflow.NonContiguous -> RecordNonContiguousEncoder(encodingWorkflow.descriptorToWriterFieldIndex, schema, avro, onEncoded) + is EncodingWorkflow.MissingWriterFields -> throw SerializationException( + "Missing writer fields ${schema.fields.filter { it.pos() in encodingWorkflow.missingWriterFields }} from the descriptor $descriptor" + ) + } +} - private val encodingWorkflow = avro.recordResolver.resolveFields(schema, descriptor).encoding - private lateinit var currentField: Schema.Field +private class RecordNonContiguousEncoder( + private val descriptorToWriterFieldIndex: IntArray, + schema: Schema, + avro: Avro, + onEncoded: (GenericRecord) -> Unit, +) : AbstractRecordGenericEncoder(avro, schema, onEncoded) { + override fun encodeElement( + descriptor: SerialDescriptor, + index: Int, + ): Boolean { + val writerFieldIndex = descriptorToWriterFieldIndex[index] + if (writerFieldIndex == -1) { + return false + } + super.encodeElement(descriptor, index) + setWriterField(writerFieldIndex) + return true + } +} - override lateinit var currentWriterSchema: Schema +private class RecordContiguousSkippingEncoder( + private val skippedElements: BooleanArray, + schema: Schema, + avro: Avro, + onEncoded: (GenericRecord) -> Unit, +) : AbstractRecordGenericEncoder(avro, schema, onEncoded) { + private var nextWriterFieldIndex = 0 + + override fun encodeElement( + descriptor: SerialDescriptor, + index: Int, + ): Boolean { + if (skippedElements[index]) { + return false + } + super.encodeElement(descriptor, index) + setWriterField(nextWriterFieldIndex++) + return true + } +} +private class RecordContiguousExactEncoder( + schema: Schema, + avro: Avro, + onEncoded: (GenericRecord) -> Unit, +) : AbstractRecordGenericEncoder(avro, schema, onEncoded) { override fun encodeElement( descriptor: SerialDescriptor, index: Int, ): Boolean { super.encodeElement(descriptor, index) - val writerFieldIndex = - when (encodingWorkflow) { - EncodingWorkflow.ExactMatch -> index + setWriterField(index) + return true + } +} - is EncodingWorkflow.ContiguousWithSkips -> { - if (encodingWorkflow.fieldsToSkip[index]) { - return false - } - index - } + +private abstract class AbstractRecordGenericEncoder( + override val avro: Avro, + private val schema: Schema, + private val onEncoded: (GenericRecord) -> Unit, +) : AbstractAvroGenericEncoder() { + private val fieldValues: Array<Any?> = Array(schema.fields.size) { null } - is EncodingWorkflow.NonContiguous -> { - val writerFieldIndex = encodingWorkflow.descriptorToWriterFieldIndex[index] - if (writerFieldIndex == -1) { - return false - } - writerFieldIndex - } + private lateinit var currentField: Schema.Field + + override lateinit var currentWriterSchema: Schema - is EncodingWorkflow.MissingWriterFields -> throw SerializationException("Invalid encoding workflow") - } + protected fun setWriterField(writerFieldIndex: Int) { val field = schema.fields[writerFieldIndex] currentField = field currentWriterSchema = field.schema() - return true } override fun endStructure(descriptor: SerialDescriptor) { diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/internal/helpers.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/internal/helpers.kt index 079cc23a..02f17b10 100644 --- a/src/main/kotlin/com/github/avrokotlin/avro4k/internal/helpers.kt 
+++ b/src/main/kotlin/com/github/avrokotlin/avro4k/internal/helpers.kt @@ -183,8 +183,4 @@ internal val AvroProp.jsonNode: JsonNode return objectMapper.readTree(value) } return TextNode.valueOf(value) - } - -internal fun SerialDescriptor.getElementIndexNullable(name: String): Int? { - return getElementIndex(name).takeIf { it >= 0 } -} \ No newline at end of file + } \ No newline at end of file From 068c8627421e762173f19203707d3a1f53c50b06 Mon Sep 17 00:00:00 2001 From: Chuckame Date: Tue, 17 Sep 2024 11:08:01 +0200 Subject: [PATCH 04/13] tests: improve test coverage --- .../avro4k/encoding/PrimitiveEncodingTest.kt | 29 +++++++++++++ .../avro4k/encoding/RecordEncodingTest.kt | 42 +++++++++++++++++++ 2 files changed, 71 insertions(+) diff --git a/src/test/kotlin/com/github/avrokotlin/avro4k/encoding/PrimitiveEncodingTest.kt b/src/test/kotlin/com/github/avrokotlin/avro4k/encoding/PrimitiveEncodingTest.kt index d98295ca..2b8f7eda 100644 --- a/src/test/kotlin/com/github/avrokotlin/avro4k/encoding/PrimitiveEncodingTest.kt +++ b/src/test/kotlin/com/github/avrokotlin/avro4k/encoding/PrimitiveEncodingTest.kt @@ -1,6 +1,7 @@ package com.github.avrokotlin.avro4k.encoding import com.github.avrokotlin.avro4k.AvroAssertions +import com.github.avrokotlin.avro4k.SomeEnum import com.github.avrokotlin.avro4k.WrappedBoolean import com.github.avrokotlin.avro4k.WrappedByte import com.github.avrokotlin.avro4k.WrappedChar @@ -10,9 +11,13 @@ import com.github.avrokotlin.avro4k.WrappedInt import com.github.avrokotlin.avro4k.WrappedLong import com.github.avrokotlin.avro4k.WrappedShort import com.github.avrokotlin.avro4k.WrappedString +import com.github.avrokotlin.avro4k.internal.nullable import com.github.avrokotlin.avro4k.record import io.kotest.core.spec.style.StringSpec +import kotlinx.serialization.InternalSerializationApi import kotlinx.serialization.Serializable +import kotlinx.serialization.serializer +import org.apache.avro.Schema import java.nio.ByteBuffer internal class PrimitiveEncodingTest : StringSpec({ @@ -31,6 +36,30 @@ internal class PrimitiveEncodingTest : StringSpec({ .isEncodedAs(false) } + @OptIn(InternalSerializationApi::class) + listOf( + true, + false, + 1.toByte(), + 2.toShort(), + 3, + 4L, + 5.0F, + 6.0, + 'A', + SomeEnum.B + ).forEach { + "coerce ${it::class.simpleName} $it to string" { + AvroAssertions.assertThat(it, it::class.serializer()) + .isEncodedAs(it.toString(), writerSchema = Schema.create(Schema.Type.STRING)) + } + + "coerce ${it::class.simpleName} $it to nullable string" { + AvroAssertions.assertThat(it, it::class.serializer()) + .isEncodedAs(it.toString(), writerSchema = Schema.create(Schema.Type.STRING).nullable) + } + } + "read write out bytes" { AvroAssertions.assertThat(ByteTest(3)) .isEncodedAs(record(3)) diff --git a/src/test/kotlin/com/github/avrokotlin/avro4k/encoding/RecordEncodingTest.kt b/src/test/kotlin/com/github/avrokotlin/avro4k/encoding/RecordEncodingTest.kt index f07816b2..d80241f4 100644 --- a/src/test/kotlin/com/github/avrokotlin/avro4k/encoding/RecordEncodingTest.kt +++ b/src/test/kotlin/com/github/avrokotlin/avro4k/encoding/RecordEncodingTest.kt @@ -183,6 +183,48 @@ internal class RecordEncodingTest : StringSpec({ AvroAssertions.assertThat(input) .isDecodedAs(MissingFields(true)) } + "support decoding from a writer schema with missing descriptor fields (just skipping, no reordering)" { + @Serializable + @SerialName("TheClass") + data class TheClass( + val a: String?, + val b: Boolean?, + val c: Int, + ) + + @Serializable + @SerialName("TheClass") + data 
class TheLightClass( + val b: Boolean?, + ) + + val writerSchema = + SchemaBuilder.record("TheClass").fields() + .name("a").type(Schema.create(Schema.Type.STRING).nullable).withDefault(null) + .name("b").type().booleanType().noDefault() + .name("c").type().intType().intDefault(42) + .endRecord() + + AvroAssertions.assertThat(TheClass("hello", true, 42)) + .isDecodedAs(TheLightClass(true), writerSchema = writerSchema) + } + "support encoding & decoding with additional descriptor optional fields (no reordering)" { + @Serializable + @SerialName("TheClass") + data class TheClass( + val a: String? = null, + val b: Boolean?, + val c: Int = 42, + ) + + val writerSchema = + SchemaBuilder.record("TheClass").fields() + .name("b").type().booleanType().noDefault() + .endRecord() + + AvroAssertions.assertThat(TheClass("hello", true, 17)) + .isEncodedAs(record(true), expectedDecodedValue = TheClass(null, true, 42), writerSchema = writerSchema) + } "should fail when trying to write a data class but missing the last schema field" { @Serializable @SerialName("Base") From 0fe6920d0fe0dad09ecfd2a33947d062734c2176 Mon Sep 17 00:00:00 2001 From: Chuckame Date: Wed, 25 Sep 2024 23:23:14 +0200 Subject: [PATCH 05/13] tests: improve test coverage for scalar types --- .../direct/AbstractAvroDirectDecoder.kt | 51 ++++++----- .../avrokotlin/avro4k/AvroAssertions.kt | 85 ++++++++++++++++++ .../avro4k/encoding/EnumEncodingTest.kt | 87 +++++++------------ 3 files changed, 146 insertions(+), 77 deletions(-) diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/internal/decoder/direct/AbstractAvroDirectDecoder.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/internal/decoder/direct/AbstractAvroDirectDecoder.kt index 31f1b6ff..3b055ee0 100644 --- a/src/main/kotlin/com/github/avrokotlin/avro4k/internal/decoder/direct/AbstractAvroDirectDecoder.kt +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/internal/decoder/direct/AbstractAvroDirectDecoder.kt @@ -184,8 +184,8 @@ internal abstract class AbstractAvroDirectDecoder( return when (currentWriterSchema.type) { Schema.Type.STRING -> binaryDecoder.readString(null).toString() - Schema.Type.BYTES -> binaryDecoder.readBytes(null).array().decodeToString() - Schema.Type.FIXED -> ByteArray(currentWriterSchema.fixedSize).also { buf -> binaryDecoder.readFixed(buf) }.decodeToString() + Schema.Type.BYTES -> binaryDecoder.readBytes().decodeToString() + Schema.Type.FIXED -> binaryDecoder.readFixedBytes(currentWriterSchema.fixedSize).decodeToString() else -> throw unsupportedWriterTypeError(Schema.Type.STRING, Schema.Type.BYTES, Schema.Type.FIXED) } } @@ -194,41 +194,40 @@ internal abstract class AbstractAvroDirectDecoder( decodeAndResolveUnion() return when (currentWriterSchema.type) { - Schema.Type.ENUM -> + Schema.Type.ENUM -> { if (currentWriterSchema.isFullNameOrAliasMatch(enumDescriptor)) { - val enumName = currentWriterSchema.enumSymbols[binaryDecoder.readEnum()] - val idx = enumDescriptor.getElementIndex(enumName) - if (idx >= 0) { - idx - } else { - avro.enumResolver.getDefaultValueIndex(enumDescriptor) - ?: throw SerializationException("Unknown enum symbol name '$enumName' for Enum '${enumDescriptor.serialName}' for writer schema $currentWriterSchema") - } + val enumSymbol = currentWriterSchema.enumSymbols[binaryDecoder.readEnum()] + enumDescriptor.getEnumIndex(enumSymbol) } else { throw unsupportedWriterTypeError(Schema.Type.ENUM, Schema.Type.STRING) } + } Schema.Type.STRING -> { val enumSymbol = binaryDecoder.readString() - val idx = 
enumDescriptor.getElementIndex(enumSymbol) - if (idx >= 0) { - idx - } else { - avro.enumResolver.getDefaultValueIndex(enumDescriptor) - ?: throw SerializationException("Unknown enum symbol '$enumSymbol' for Enum '${enumDescriptor.serialName}' for writer schema $currentWriterSchema") - } + enumDescriptor.getEnumIndex(enumSymbol) } else -> throw unsupportedWriterTypeError(Schema.Type.ENUM, Schema.Type.STRING) } } + private fun SerialDescriptor.getEnumIndex(enumName: String): Int { + val idx = getElementIndex(enumName) + return if (idx >= 0) { + idx + } else { + avro.enumResolver.getDefaultValueIndex(this) + ?: throw SerializationException("Unknown enum symbol name '$enumName' for Enum '${this.serialName}' for writer schema $currentWriterSchema") + } + } + override fun decodeBytes(): ByteArray { decodeAndResolveUnion() return when (currentWriterSchema.type) { - Schema.Type.BYTES -> binaryDecoder.readBytes(null).array() - Schema.Type.FIXED -> ByteArray(currentWriterSchema.fixedSize).also { buf -> binaryDecoder.readFixed(buf) } + Schema.Type.BYTES -> binaryDecoder.readBytes() + Schema.Type.FIXED -> binaryDecoder.readFixedBytes(currentWriterSchema.fixedSize) Schema.Type.STRING -> binaryDecoder.readString(null).bytes else -> throw unsupportedWriterTypeError(Schema.Type.BYTES, Schema.Type.FIXED, Schema.Type.STRING) } @@ -238,13 +237,21 @@ internal abstract class AbstractAvroDirectDecoder( decodeAndResolveUnion() return when (currentWriterSchema.type) { - Schema.Type.BYTES -> GenericData.Fixed(currentWriterSchema, binaryDecoder.readBytes(null).array()) - Schema.Type.FIXED -> GenericData.Fixed(currentWriterSchema, ByteArray(currentWriterSchema.fixedSize).also { buf -> binaryDecoder.readFixed(buf) }) + Schema.Type.BYTES -> GenericData.Fixed(currentWriterSchema, binaryDecoder.readBytes()) + Schema.Type.FIXED -> GenericData.Fixed(currentWriterSchema, binaryDecoder.readFixedBytes(currentWriterSchema.fixedSize)) else -> throw unsupportedWriterTypeError(Schema.Type.BYTES, Schema.Type.FIXED) } } } +private fun org.apache.avro.io.Decoder.readFixedBytes(size: Int): ByteArray { + return ByteArray(size).also { buf -> readFixed(buf) } +} + +private fun org.apache.avro.io.Decoder.readBytes(): ByteArray { + return readBytes(null).array() +} + private class PolymorphicDecoder( avro: Avro, descriptor: SerialDescriptor, diff --git a/src/test/kotlin/com/github/avrokotlin/avro4k/AvroAssertions.kt b/src/test/kotlin/com/github/avrokotlin/avro4k/AvroAssertions.kt index 15f47e7e..0648ac3a 100644 --- a/src/test/kotlin/com/github/avrokotlin/avro4k/AvroAssertions.kt +++ b/src/test/kotlin/com/github/avrokotlin/avro4k/AvroAssertions.kt @@ -1,15 +1,20 @@ package com.github.avrokotlin.avro4k +import com.github.avrokotlin.avro4k.internal.nullable import io.kotest.assertions.Actual import io.kotest.assertions.Expected import io.kotest.assertions.failure import io.kotest.assertions.print.Printed import io.kotest.assertions.withClue +import io.kotest.core.spec.style.scopes.StringSpecRootScope import io.kotest.matchers.shouldBe import kotlinx.serialization.KSerializer +import kotlinx.serialization.SerialName +import kotlinx.serialization.Serializable import kotlinx.serialization.serializer import org.apache.avro.Conversions import org.apache.avro.Schema +import org.apache.avro.SchemaBuilder import org.apache.avro.generic.GenericData import org.apache.avro.generic.GenericDatumReader import org.apache.avro.generic.GenericDatumWriter @@ -202,4 +207,84 @@ internal object AvroAssertions { ): AvroEncodingAssertions { return 
AvroEncodingAssertions(value, serializer as KSerializer) } +} + +fun encodeToBytesUsingApacheLib( + schema: Schema, + toEncode: Any?, +): ByteArray { + return ByteArrayOutputStream().use { + GenericData.get().createDatumWriter(schema).write(toEncode, EncoderFactory.get().directBinaryEncoder(it, null)) + it.toByteArray() + } +} + +internal inline fun StringSpecRootScope.basicScalarEncodeDecodeTests(value: T, schema: Schema, apacheCompatibleValue: Any? = value) { + "support scalar type ${schema.type} serialization" { + testEncodeDecode(schema, value, apacheCompatibleValue = apacheCompatibleValue) + testEncodeDecode(schema, TestGenericValueClass(value), apacheCompatibleValue = apacheCompatibleValue) + + testEncodeDecode(schema.nullable, value, apacheCompatibleValue = apacheCompatibleValue) + testEncodeDecode(schema.nullable, null) + + testEncodeDecode(schema.nullable, TestGenericValueClass(value), apacheCompatibleValue = apacheCompatibleValue) + testEncodeDecode(schema.nullable, TestGenericValueClass(null), apacheCompatibleValue = null) + testEncodeDecode?>(schema.nullable, null) + } + "scalar type ${schema.type} in record" { + val record = + SchemaBuilder.record("theRecord").fields() + .name("field").type(schema).noDefault() + .endRecord() + + testEncodeDecode(record, TestGenericRecord(value), apacheCompatibleValue = GenericData.Record(record).also { it.put(0, apacheCompatibleValue) }) + testEncodeDecode(record, TestGenericRecord(TestGenericValueClass(value)), apacheCompatibleValue = GenericData.Record(record).also { it.put(0, apacheCompatibleValue) }) + + val recordNullable = + SchemaBuilder.record("theRecord").fields() + .name("field").type(schema.nullable).noDefault() + .endRecord() + testEncodeDecode(recordNullable, TestGenericRecord(value), apacheCompatibleValue = GenericData.Record(recordNullable).also { it.put(0, apacheCompatibleValue) }) + testEncodeDecode(recordNullable, TestGenericRecord(null), apacheCompatibleValue = GenericData.Record(recordNullable).also { it.put(0, null) }) + testEncodeDecode(recordNullable, TestGenericRecord(TestGenericValueClass(value)), apacheCompatibleValue = GenericData.Record(recordNullable).also { it.put(0, apacheCompatibleValue) }) + testEncodeDecode(recordNullable, TestGenericRecord(TestGenericValueClass(null)), apacheCompatibleValue = GenericData.Record(recordNullable).also { it.put(0, null) }) + } + "scalar type ${schema.type} in map" { + val map = SchemaBuilder.map().values(schema) + testEncodeDecode(map, mapOf("key" to value), apacheCompatibleValue = mapOf("key" to apacheCompatibleValue)) + testEncodeDecode(map, mapOf("key" to TestGenericValueClass(value)), apacheCompatibleValue = mapOf("key" to apacheCompatibleValue)) + + val mapNullable = SchemaBuilder.map().values(schema.nullable) + testEncodeDecode(mapNullable, mapOf("key" to TestGenericValueClass(value)), apacheCompatibleValue = mapOf("key" to apacheCompatibleValue)) + testEncodeDecode(mapNullable, mapOf("key" to TestGenericValueClass(null)), apacheCompatibleValue = mapOf("key" to null)) + } + "scalar type ${schema.type} in array" { + val array = SchemaBuilder.array().items(schema) + testEncodeDecode(array, listOf(value), apacheCompatibleValue = listOf(apacheCompatibleValue)) + testEncodeDecode(array, listOf(TestGenericValueClass(value)), apacheCompatibleValue = listOf(apacheCompatibleValue)) + + val arrayNullable = SchemaBuilder.array().items(schema.nullable) + testEncodeDecode(arrayNullable, listOf(TestGenericValueClass(value)), apacheCompatibleValue = listOf(apacheCompatibleValue)) + 
testEncodeDecode(arrayNullable, listOf(TestGenericValueClass(null)), apacheCompatibleValue = listOf(null)) + } +} + +@Serializable +@SerialName("theRecord") +internal data class TestGenericRecord(val field: T) + +@JvmInline +@Serializable +internal value class TestGenericValueClass(val value: T) + +inline fun testEncodeDecode( + schema: Schema, + toEncode: T, + decoded: Any? = toEncode, + apacheCompatibleValue: Any? = toEncode, + serializer: KSerializer = Avro.serializersModule.serializer(), + expectedBytes: ByteArray = encodeToBytesUsingApacheLib(schema, apacheCompatibleValue), +) { + Avro.encodeToByteArray(schema, serializer, toEncode) shouldBe expectedBytes + Avro.decodeFromByteArray(schema, serializer, expectedBytes) shouldBe decoded } \ No newline at end of file diff --git a/src/test/kotlin/com/github/avrokotlin/avro4k/encoding/EnumEncodingTest.kt b/src/test/kotlin/com/github/avrokotlin/avro4k/encoding/EnumEncodingTest.kt index a42ad64f..4858dff9 100644 --- a/src/test/kotlin/com/github/avrokotlin/avro4k/encoding/EnumEncodingTest.kt +++ b/src/test/kotlin/com/github/avrokotlin/avro4k/encoding/EnumEncodingTest.kt @@ -5,55 +5,26 @@ package com.github.avrokotlin.avro4k.encoding import com.github.avrokotlin.avro4k.Avro import com.github.avrokotlin.avro4k.AvroAssertions import com.github.avrokotlin.avro4k.AvroEnumDefault +import com.github.avrokotlin.avro4k.basicScalarEncodeDecodeTests +import com.github.avrokotlin.avro4k.encodeToByteArray +import com.github.avrokotlin.avro4k.encodeToBytesUsingApacheLib import com.github.avrokotlin.avro4k.record import com.github.avrokotlin.avro4k.schema import com.github.avrokotlin.avro4k.serializer.UUIDSerializer +import io.kotest.assertions.throwables.shouldThrow import io.kotest.core.spec.style.StringSpec import kotlinx.serialization.SerialName import kotlinx.serialization.Serializable +import kotlinx.serialization.SerializationException import kotlinx.serialization.UseSerializers +import org.apache.avro.SchemaBuilder import org.apache.avro.generic.GenericData internal class EnumEncodingTest : StringSpec({ - "read / write enums" { - AvroAssertions.assertThat(EnumTest(Cream.Bruce, BBM.Moore)) - .isEncodedAs(record(GenericData.EnumSymbol(Avro.schema(), "Bruce"), GenericData.EnumSymbol(Avro.schema(), "Moore"))) + basicScalarEncodeDecodeTests(Cream.Bruce, Avro.schema(), apacheCompatibleValue = GenericData.EnumSymbol(Avro.schema(), "Bruce")) - AvroAssertions.assertThat(Cream.Bruce) - .isEncodedAs(GenericData.EnumSymbol(Avro.schema(), "Bruce")) - AvroAssertions.assertThat(CreamValueClass(Cream.Bruce)) - .isEncodedAs(GenericData.EnumSymbol(Avro.schema(), "Bruce")) - } - - "read / write list of enums" { - AvroAssertions.assertThat(EnumListTest(listOf(Cream.Bruce, Cream.Clapton))) - .isEncodedAs(record(listOf(GenericData.EnumSymbol(Avro.schema(), "Bruce"), GenericData.EnumSymbol(Avro.schema(), "Clapton")))) - - AvroAssertions.assertThat(listOf(Cream.Bruce, Cream.Clapton)) - .isEncodedAs(listOf(GenericData.EnumSymbol(Avro.schema(), "Bruce"), GenericData.EnumSymbol(Avro.schema(), "Clapton"))) - AvroAssertions.assertThat(listOf(CreamValueClass(Cream.Bruce), CreamValueClass(Cream.Clapton))) - .isEncodedAs(listOf(GenericData.EnumSymbol(Avro.schema(), "Bruce"), GenericData.EnumSymbol(Avro.schema(), "Clapton"))) - } - - "read / write nullable enums" { - AvroAssertions.assertThat(NullableEnumTest(null)) - .isEncodedAs(record(null)) - AvroAssertions.assertThat(NullableEnumTest(Cream.Bruce)) - .isEncodedAs(record(GenericData.EnumSymbol(Avro.schema(), "Bruce"))) - - 
AvroAssertions.assertThat(Cream.Bruce) - .isEncodedAs(GenericData.EnumSymbol(Avro.schema(), "Bruce")) - AvroAssertions.assertThat(null) - .isEncodedAs(null) - - AvroAssertions.assertThat(CreamValueClass(Cream.Bruce)) - .isEncodedAs(GenericData.EnumSymbol(Avro.schema(), "Bruce")) - AvroAssertions.assertThat(null) - .isEncodedAs(null) - } - - "Decoding enum with an unknown uses @AvroEnumDefault value" { + "Decoding enum with an unknown symbol uses @AvroEnumDefault value" { AvroAssertions.assertThat(EnumV2WrapperRecord(EnumV2.B)) .isEncodedAs(record(GenericData.EnumSymbol(Avro.schema(), "B"))) .isDecodedAs(EnumV1WrapperRecord(EnumV1.UNKNOWN)) @@ -62,6 +33,23 @@ internal class EnumEncodingTest : StringSpec({ .isEncodedAs(GenericData.EnumSymbol(Avro.schema(), "B")) .isDecodedAs(EnumV1.UNKNOWN) } + + "Decoding enum with an unknown symbol fails without @AvroEnumDefault, also ignoring default symbol in writer schema" { + val schema = SchemaBuilder.enumeration("Enum").defaultSymbol("Z").symbols("X", "Z") + + val bytes = encodeToBytesUsingApacheLib(schema, GenericData.EnumSymbol(schema, "X")) + shouldThrow { + Avro.decodeFromByteArray(schema, EnumV1WithoutDefault.serializer(), bytes) + } + } + + "Encoding enum with an unknown symbol fails even with default in writer schema" { + val schema = SchemaBuilder.enumeration("Enum").defaultSymbol("Z").symbols("X", "Z") + + shouldThrow { + Avro.encodeToByteArray(schema, EnumV1WithoutDefault.A) + } + } }) { @Serializable @SerialName("EnumWrapper") @@ -83,6 +71,13 @@ internal class EnumEncodingTest : StringSpec({ A, } + @Serializable + @SerialName("Enum") + private enum class EnumV1WithoutDefault { + UNKNOWN, + A, + } + @Serializable @SerialName("Enum") private enum class EnumV2 { @@ -93,27 +88,9 @@ internal class EnumEncodingTest : StringSpec({ } @Serializable - private data class EnumTest(val a: Cream, val b: BBM) - - @JvmInline - @Serializable - private value class CreamValueClass(val a: Cream) - - @Serializable - private data class EnumListTest(val a: List) - - @Serializable - private data class NullableEnumTest(val a: Cream?) 
- private enum class Cream { Bruce, Baker, Clapton, } - - private enum class BBM { - Bruce, - Baker, - Moore, - } } \ No newline at end of file From 6428298bb9bea4e267a2825a742bc0d0f76f0a57 Mon Sep 17 00:00:00 2001 From: Chuckame Date: Wed, 25 Sep 2024 23:55:20 +0200 Subject: [PATCH 06/13] tests: improve test coverage for enums --- .editorconfig | 4 + .../avrokotlin/avro4k/AvroAssertions.kt | 127 ++++++++++--- .../avrokotlin/avro4k/dataClassesForTests.kt | 6 +- .../avro4k/encoding/AvroAliasEncodingTest.kt | 40 ---- .../avro4k/encoding/EnumEncodingTest.kt | 96 ---------- .../avrokotlin/avro4k/encoding/EnumTest.kt | 174 ++++++++++++++++++ .../avro4k/schema/EnumSchemaTest.kt | 56 ------ 7 files changed, 282 insertions(+), 221 deletions(-) delete mode 100644 src/test/kotlin/com/github/avrokotlin/avro4k/encoding/EnumEncodingTest.kt create mode 100644 src/test/kotlin/com/github/avrokotlin/avro4k/encoding/EnumTest.kt delete mode 100644 src/test/kotlin/com/github/avrokotlin/avro4k/schema/EnumSchemaTest.kt diff --git a/.editorconfig b/.editorconfig index d24e3874..281f3d13 100644 --- a/.editorconfig +++ b/.editorconfig @@ -23,6 +23,10 @@ ij_editorconfig_spaces_around_assignment_operators = true [{*.kt,*.kts}] ktlint_standard_filename = disabled +ktlint_standard_class-signature = disabled +ktlint_standard_function-signature = disabled +ktlint_standard_chain-method-continuation = disabled +ktlint_standard_function-expression-body = disabled ij_kotlin_align_in_columns_case_branch = false ij_kotlin_align_multiline_binary_operation = false ij_kotlin_align_multiline_extends_list = false diff --git a/src/test/kotlin/com/github/avrokotlin/avro4k/AvroAssertions.kt b/src/test/kotlin/com/github/avrokotlin/avro4k/AvroAssertions.kt index 0648ac3a..8612b41b 100644 --- a/src/test/kotlin/com/github/avrokotlin/avro4k/AvroAssertions.kt +++ b/src/test/kotlin/com/github/avrokotlin/avro4k/AvroAssertions.kt @@ -9,12 +9,11 @@ import io.kotest.assertions.withClue import io.kotest.core.spec.style.scopes.StringSpecRootScope import io.kotest.matchers.shouldBe import kotlinx.serialization.KSerializer -import kotlinx.serialization.SerialName -import kotlinx.serialization.Serializable import kotlinx.serialization.serializer import org.apache.avro.Conversions import org.apache.avro.Schema import org.apache.avro.SchemaBuilder +import org.apache.avro.generic.GenericContainer import org.apache.avro.generic.GenericData import org.apache.avro.generic.GenericDatumReader import org.apache.avro.generic.GenericDatumWriter @@ -221,61 +220,133 @@ fun encodeToBytesUsingApacheLib( internal inline fun StringSpecRootScope.basicScalarEncodeDecodeTests(value: T, schema: Schema, apacheCompatibleValue: Any? 
= value) { "support scalar type ${schema.type} serialization" { + Avro.schema() shouldBe schema testEncodeDecode(schema, value, apacheCompatibleValue = apacheCompatibleValue) - testEncodeDecode(schema, TestGenericValueClass(value), apacheCompatibleValue = apacheCompatibleValue) + Avro.schema>() shouldBe schema + testEncodeDecode(schema, ValueClassWithGenericField(value), apacheCompatibleValue = apacheCompatibleValue) + } + "support scalar type ${schema.type} serialization as nullable" { + Avro.schema() shouldBe schema.nullable testEncodeDecode(schema.nullable, value, apacheCompatibleValue = apacheCompatibleValue) testEncodeDecode(schema.nullable, null) - testEncodeDecode(schema.nullable, TestGenericValueClass(value), apacheCompatibleValue = apacheCompatibleValue) - testEncodeDecode(schema.nullable, TestGenericValueClass(null), apacheCompatibleValue = null) - testEncodeDecode?>(schema.nullable, null) + Avro.schema>() shouldBe schema.nullable + testEncodeDecode(schema.nullable, ValueClassWithGenericField(value), apacheCompatibleValue = apacheCompatibleValue) + testEncodeDecode(schema.nullable, ValueClassWithGenericField(null), apacheCompatibleValue = null) + + Avro.schema?>() shouldBe schema.nullable + testEncodeDecode?>(schema.nullable, null) + + Avro.schema?>() shouldBe schema.nullable + testEncodeDecode?>(schema.nullable, null) } "scalar type ${schema.type} in record" { val record = - SchemaBuilder.record("theRecord").fields() + SchemaBuilder.record("RecordWithGenericField").fields() .name("field").type(schema).noDefault() .endRecord() - - testEncodeDecode(record, TestGenericRecord(value), apacheCompatibleValue = GenericData.Record(record).also { it.put(0, apacheCompatibleValue) }) - testEncodeDecode(record, TestGenericRecord(TestGenericValueClass(value)), apacheCompatibleValue = GenericData.Record(record).also { it.put(0, apacheCompatibleValue) }) + Avro.schema>() shouldBe record + Avro.schema>>() shouldBe record + testEncodeDecode(record, + RecordWithGenericField(value), apacheCompatibleValue = GenericData.Record(record).also { it.put(0, apacheCompatibleValue) }) + testEncodeDecode(record, + RecordWithGenericField(ValueClassWithGenericField(value)), apacheCompatibleValue = GenericData.Record(record).also { it.put(0, apacheCompatibleValue) }) + } + "scalar type ${schema.type} in record as nullable field" { + val expectedRecordSchemaNullable = + SchemaBuilder.record("RecordWithGenericField").fields() + .name("field").type(schema.nullable).withDefault(null) + .endRecord() + Avro.schema>() shouldBe expectedRecordSchemaNullable + Avro.schema>>() shouldBe expectedRecordSchemaNullable + Avro.schema?>>() shouldBe expectedRecordSchemaNullable + Avro.schema?>>() shouldBe expectedRecordSchemaNullable val recordNullable = - SchemaBuilder.record("theRecord").fields() + SchemaBuilder.record("RecordWithGenericField").fields() .name("field").type(schema.nullable).noDefault() .endRecord() - testEncodeDecode(recordNullable, TestGenericRecord(value), apacheCompatibleValue = GenericData.Record(recordNullable).also { it.put(0, apacheCompatibleValue) }) - testEncodeDecode(recordNullable, TestGenericRecord(null), apacheCompatibleValue = GenericData.Record(recordNullable).also { it.put(0, null) }) - testEncodeDecode(recordNullable, TestGenericRecord(TestGenericValueClass(value)), apacheCompatibleValue = GenericData.Record(recordNullable).also { it.put(0, apacheCompatibleValue) }) - testEncodeDecode(recordNullable, TestGenericRecord(TestGenericValueClass(null)), apacheCompatibleValue = 
GenericData.Record(recordNullable).also { it.put(0, null) }) + testEncodeDecode(recordNullable, + RecordWithGenericField(value), apacheCompatibleValue = GenericData.Record(recordNullable).also { it.put(0, apacheCompatibleValue) }) + testEncodeDecode(recordNullable, + RecordWithGenericField(null), apacheCompatibleValue = GenericData.Record(recordNullable).also { it.put(0, null) }) + testEncodeDecode( + recordNullable, + RecordWithGenericField(ValueClassWithGenericField(value)), + apacheCompatibleValue = GenericData.Record(recordNullable).also { it.put(0, apacheCompatibleValue) } + ) + testEncodeDecode(recordNullable, + RecordWithGenericField(ValueClassWithGenericField(null)), apacheCompatibleValue = GenericData.Record(recordNullable).also { it.put(0, null) }) } "scalar type ${schema.type} in map" { val map = SchemaBuilder.map().values(schema) + Avro.schema>() shouldBe map + Avro.schema>>() shouldBe map + Avro.schema>>() shouldBe map + Avro.schema, ValueClassWithGenericField>>() shouldBe map + Avro.schema>() shouldBe map + Avro.schema, T>>() shouldBe map testEncodeDecode(map, mapOf("key" to value), apacheCompatibleValue = mapOf("key" to apacheCompatibleValue)) - testEncodeDecode(map, mapOf("key" to TestGenericValueClass(value)), apacheCompatibleValue = mapOf("key" to apacheCompatibleValue)) + testEncodeDecode(map, mapOf("key" to ValueClassWithGenericField(value)), apacheCompatibleValue = mapOf("key" to apacheCompatibleValue)) val mapNullable = SchemaBuilder.map().values(schema.nullable) - testEncodeDecode(mapNullable, mapOf("key" to TestGenericValueClass(value)), apacheCompatibleValue = mapOf("key" to apacheCompatibleValue)) - testEncodeDecode(mapNullable, mapOf("key" to TestGenericValueClass(null)), apacheCompatibleValue = mapOf("key" to null)) + Avro.schema>() shouldBe mapNullable + Avro.schema>>() shouldBe mapNullable + Avro.schema?>>() shouldBe mapNullable + Avro.schema?>>() shouldBe mapNullable + Avro.schema>>() shouldBe mapNullable + Avro.schema?>>() shouldBe mapNullable + Avro.schema?>>() shouldBe mapNullable + Avro.schema, ValueClassWithGenericField>>() shouldBe mapNullable + Avro.schema, ValueClassWithGenericField?>>() shouldBe mapNullable + Avro.schema, ValueClassWithGenericField?>>() shouldBe mapNullable + Avro.schema>() shouldBe mapNullable + Avro.schema, T?>>() shouldBe mapNullable + testEncodeDecode(mapNullable, mapOf("key" to ValueClassWithGenericField(value)), apacheCompatibleValue = mapOf("key" to apacheCompatibleValue)) + testEncodeDecode(mapNullable, mapOf("key" to ValueClassWithGenericField(null)), apacheCompatibleValue = mapOf("key" to null)) } "scalar type ${schema.type} in array" { val array = SchemaBuilder.array().items(schema) + Avro.schema>() shouldBe array + Avro.schema>>() shouldBe array + Avro.schema>() shouldBe array + Avro.schema>>() shouldBe array + Avro.schema>() shouldBe array + Avro.schema>>() shouldBe array testEncodeDecode(array, listOf(value), apacheCompatibleValue = listOf(apacheCompatibleValue)) - testEncodeDecode(array, listOf(TestGenericValueClass(value)), apacheCompatibleValue = listOf(apacheCompatibleValue)) + testEncodeDecode(array, listOf(ValueClassWithGenericField(value)), apacheCompatibleValue = listOf(apacheCompatibleValue)) val arrayNullable = SchemaBuilder.array().items(schema.nullable) - testEncodeDecode(arrayNullable, listOf(TestGenericValueClass(value)), apacheCompatibleValue = listOf(apacheCompatibleValue)) - testEncodeDecode(arrayNullable, listOf(TestGenericValueClass(null)), apacheCompatibleValue = listOf(null)) + Avro.schema>() 
shouldBe arrayNullable + Avro.schema>>() shouldBe arrayNullable + Avro.schema?>>() shouldBe arrayNullable + Avro.schema?>>() shouldBe arrayNullable + Avro.schema>() shouldBe arrayNullable + Avro.schema?>>() shouldBe arrayNullable + Avro.schema?>>() shouldBe arrayNullable + Avro.schema?>>() shouldBe arrayNullable + Avro.schema>() shouldBe arrayNullable + Avro.schema?>>() shouldBe arrayNullable + Avro.schema?>>() shouldBe arrayNullable + Avro.schema?>>() shouldBe arrayNullable + testEncodeDecode(arrayNullable, listOf(ValueClassWithGenericField(value)), apacheCompatibleValue = listOf(apacheCompatibleValue)) + testEncodeDecode(arrayNullable, listOf(ValueClassWithGenericField(null)), apacheCompatibleValue = listOf(null)) } } -@Serializable -@SerialName("theRecord") -internal data class TestGenericRecord(val field: T) - -@JvmInline -@Serializable -internal value class TestGenericValueClass(val value: T) +internal inline fun StringSpecRootScope.testSerializationTypeCompatibility(logicalValue: T, encodedAsValue: R) { + val schema = when { + encodedAsValue is GenericContainer -> encodedAsValue.schema + else -> Avro.schema() + } + "Support ${logicalValue::class.simpleName} serialization as ${schema.type}" { + testEncodeDecode(schema, logicalValue, apacheCompatibleValue = encodedAsValue) + } + "Support ${logicalValue::class.simpleName} serialization as nullable ${schema.type}" { + testEncodeDecode(schema.nullable, logicalValue, apacheCompatibleValue = encodedAsValue) + } +} inline fun testEncodeDecode( schema: Schema, diff --git a/src/test/kotlin/com/github/avrokotlin/avro4k/dataClassesForTests.kt b/src/test/kotlin/com/github/avrokotlin/avro4k/dataClassesForTests.kt index 552a3110..6a6a2d62 100644 --- a/src/test/kotlin/com/github/avrokotlin/avro4k/dataClassesForTests.kt +++ b/src/test/kotlin/com/github/avrokotlin/avro4k/dataClassesForTests.kt @@ -12,7 +12,11 @@ internal enum class SomeEnum { @Serializable @SerialName("RecordWithGenericField") -internal data class RecordWithGenericField(val value: T) +internal data class RecordWithGenericField(val field: T) + +@JvmInline +@Serializable +internal value class ValueClassWithGenericField(val value: T) @Serializable @JvmInline diff --git a/src/test/kotlin/com/github/avrokotlin/avro4k/encoding/AvroAliasEncodingTest.kt b/src/test/kotlin/com/github/avrokotlin/avro4k/encoding/AvroAliasEncodingTest.kt index dad51225..b6a6155e 100644 --- a/src/test/kotlin/com/github/avrokotlin/avro4k/encoding/AvroAliasEncodingTest.kt +++ b/src/test/kotlin/com/github/avrokotlin/avro4k/encoding/AvroAliasEncodingTest.kt @@ -2,7 +2,6 @@ package com.github.avrokotlin.avro4k.encoding import com.github.avrokotlin.avro4k.AvroAlias import com.github.avrokotlin.avro4k.AvroAssertions -import com.github.avrokotlin.avro4k.SomeEnum import com.github.avrokotlin.avro4k.record import com.github.avrokotlin.avro4k.recordWithSchema import io.kotest.core.spec.style.StringSpec @@ -10,7 +9,6 @@ import kotlinx.serialization.SerialName import kotlinx.serialization.Serializable import org.apache.avro.Schema import org.apache.avro.SchemaBuilder -import org.apache.avro.generic.GenericData internal class AvroAliasEncodingTest : StringSpec({ "support alias on field" { @@ -37,38 +35,6 @@ internal class AvroAliasEncodingTest : StringSpec({ .isEncodedAs(recordWithSchema(writerSchema.types[1], "hello"), writerSchema = writerSchema) .isDecodedAs(DecodedRecordWithAlias("hello")) } - - "support alias on enum" { - val writerSchema = - SchemaBuilder.record("EnumWrapperRecord").fields() - .name("value") - .type( - 
SchemaBuilder.enumeration("UnknownEnum").aliases("com.github.avrokotlin.avro4k.SomeEnum").symbols("A", "B", "C") - ) - .noDefault() - .endRecord() - AvroAssertions.assertThat(EnumWrapperRecord(SomeEnum.A)) - .isEncodedAs(record(GenericData.EnumSymbol(writerSchema.fields[0].schema(), "A")), writerSchema = writerSchema) - } - - "support alias on enum inside an union" { - val writerSchema = - SchemaBuilder.record("EnumWrapperRecord").fields() - .name("value") - .type( - Schema.createUnion( - SchemaBuilder.enumeration("OtherEnum").symbols("OTHER"), - SchemaBuilder.record("UnknownRecord").aliases("RecordA") - .fields().name("field").type().stringType().noDefault() - .endRecord(), - SchemaBuilder.enumeration("UnknownEnum").aliases("com.github.avrokotlin.avro4k.SomeEnum").symbols("A", "B", "C") - ) - ) - .noDefault() - .endRecord() - AvroAssertions.assertThat(EnumWrapperRecord(SomeEnum.A)) - .isEncodedAs(record(GenericData.EnumSymbol(writerSchema.fields[0].schema().types[2], "A")), writerSchema = writerSchema) - } }) { @Serializable @SerialName("Record") @@ -94,10 +60,4 @@ internal class AvroAliasEncodingTest : StringSpec({ private data class DecodedRecordWithAlias( val field: String, ) - - @Serializable - @SerialName("EnumWrapperRecord") - private data class EnumWrapperRecord( - val value: SomeEnum, - ) } \ No newline at end of file diff --git a/src/test/kotlin/com/github/avrokotlin/avro4k/encoding/EnumEncodingTest.kt b/src/test/kotlin/com/github/avrokotlin/avro4k/encoding/EnumEncodingTest.kt deleted file mode 100644 index 4858dff9..00000000 --- a/src/test/kotlin/com/github/avrokotlin/avro4k/encoding/EnumEncodingTest.kt +++ /dev/null @@ -1,96 +0,0 @@ -@file:UseSerializers(UUIDSerializer::class) - -package com.github.avrokotlin.avro4k.encoding - -import com.github.avrokotlin.avro4k.Avro -import com.github.avrokotlin.avro4k.AvroAssertions -import com.github.avrokotlin.avro4k.AvroEnumDefault -import com.github.avrokotlin.avro4k.basicScalarEncodeDecodeTests -import com.github.avrokotlin.avro4k.encodeToByteArray -import com.github.avrokotlin.avro4k.encodeToBytesUsingApacheLib -import com.github.avrokotlin.avro4k.record -import com.github.avrokotlin.avro4k.schema -import com.github.avrokotlin.avro4k.serializer.UUIDSerializer -import io.kotest.assertions.throwables.shouldThrow -import io.kotest.core.spec.style.StringSpec -import kotlinx.serialization.SerialName -import kotlinx.serialization.Serializable -import kotlinx.serialization.SerializationException -import kotlinx.serialization.UseSerializers -import org.apache.avro.SchemaBuilder -import org.apache.avro.generic.GenericData - -internal class EnumEncodingTest : StringSpec({ - - basicScalarEncodeDecodeTests(Cream.Bruce, Avro.schema(), apacheCompatibleValue = GenericData.EnumSymbol(Avro.schema(), "Bruce")) - - "Decoding enum with an unknown symbol uses @AvroEnumDefault value" { - AvroAssertions.assertThat(EnumV2WrapperRecord(EnumV2.B)) - .isEncodedAs(record(GenericData.EnumSymbol(Avro.schema(), "B"))) - .isDecodedAs(EnumV1WrapperRecord(EnumV1.UNKNOWN)) - - AvroAssertions.assertThat(EnumV2.B) - .isEncodedAs(GenericData.EnumSymbol(Avro.schema(), "B")) - .isDecodedAs(EnumV1.UNKNOWN) - } - - "Decoding enum with an unknown symbol fails without @AvroEnumDefault, also ignoring default symbol in writer schema" { - val schema = SchemaBuilder.enumeration("Enum").defaultSymbol("Z").symbols("X", "Z") - - val bytes = encodeToBytesUsingApacheLib(schema, GenericData.EnumSymbol(schema, "X")) - shouldThrow { - Avro.decodeFromByteArray(schema, 
EnumV1WithoutDefault.serializer(), bytes) - } - } - - "Encoding enum with an unknown symbol fails even with default in writer schema" { - val schema = SchemaBuilder.enumeration("Enum").defaultSymbol("Z").symbols("X", "Z") - - shouldThrow { - Avro.encodeToByteArray(schema, EnumV1WithoutDefault.A) - } - } -}) { - @Serializable - @SerialName("EnumWrapper") - private data class EnumV1WrapperRecord( - val value: EnumV1, - ) - - @Serializable - @SerialName("EnumWrapper") - private data class EnumV2WrapperRecord( - val value: EnumV2, - ) - - @Serializable - @SerialName("Enum") - private enum class EnumV1 { - @AvroEnumDefault - UNKNOWN, - A, - } - - @Serializable - @SerialName("Enum") - private enum class EnumV1WithoutDefault { - UNKNOWN, - A, - } - - @Serializable - @SerialName("Enum") - private enum class EnumV2 { - @AvroEnumDefault - UNKNOWN, - A, - B, - } - - @Serializable - private enum class Cream { - Bruce, - Baker, - Clapton, - } -} \ No newline at end of file diff --git a/src/test/kotlin/com/github/avrokotlin/avro4k/encoding/EnumTest.kt b/src/test/kotlin/com/github/avrokotlin/avro4k/encoding/EnumTest.kt new file mode 100644 index 00000000..2280233e --- /dev/null +++ b/src/test/kotlin/com/github/avrokotlin/avro4k/encoding/EnumTest.kt @@ -0,0 +1,174 @@ +@file:UseSerializers(UUIDSerializer::class) + +package com.github.avrokotlin.avro4k.encoding + +import com.github.avrokotlin.avro4k.Avro +import com.github.avrokotlin.avro4k.AvroAlias +import com.github.avrokotlin.avro4k.AvroAssertions +import com.github.avrokotlin.avro4k.AvroDoc +import com.github.avrokotlin.avro4k.AvroEnumDefault +import com.github.avrokotlin.avro4k.RecordWithGenericField +import com.github.avrokotlin.avro4k.SomeEnum +import com.github.avrokotlin.avro4k.ValueClassWithGenericField +import com.github.avrokotlin.avro4k.basicScalarEncodeDecodeTests +import com.github.avrokotlin.avro4k.encodeToByteArray +import com.github.avrokotlin.avro4k.encodeToBytesUsingApacheLib +import com.github.avrokotlin.avro4k.record +import com.github.avrokotlin.avro4k.schema +import com.github.avrokotlin.avro4k.serializer.UUIDSerializer +import com.github.avrokotlin.avro4k.testSerializationTypeCompatibility +import io.kotest.assertions.throwables.shouldThrow +import io.kotest.core.spec.style.StringSpec +import io.kotest.matchers.shouldBe +import kotlinx.serialization.SerialName +import kotlinx.serialization.Serializable +import kotlinx.serialization.SerializationException +import kotlinx.serialization.UseSerializers +import org.apache.avro.Schema +import org.apache.avro.SchemaBuilder +import org.apache.avro.generic.GenericData + +internal class EnumTest : StringSpec({ + val expectedEnumSchema = SchemaBuilder.enumeration(Cream::class.qualifiedName).aliases("TheCream").doc("documentation").symbols("Bruce", "Baker", "Clapton") + basicScalarEncodeDecodeTests(Cream.Bruce, expectedEnumSchema, apacheCompatibleValue = GenericData.EnumSymbol(expectedEnumSchema, "Bruce")) + testSerializationTypeCompatibility(Cream.Baker, "Baker") + + // TODO test alias decoding + // TODO test decoding from union (name resolution) + + "Only allow 1 @AvroEnumDefault at max" { + shouldThrow { + Avro.schema() + } + shouldThrow { + Avro.schema>() + } + shouldThrow { + Avro.schema>() + } + } + + "Decoding enum with an unknown symbol uses @AvroEnumDefault value" { + Avro.schema() shouldBe + SchemaBuilder.enumeration("Enum") + .defaultSymbol("UNKNOWN") + .symbols("UNKNOWN", "A", "B") + + AvroAssertions.assertThat(EnumV2WrapperRecord(EnumV2.B)) + 
.isEncodedAs(record(GenericData.EnumSymbol(Avro.schema(), "B"))) + .isDecodedAs(EnumV1WrapperRecord(EnumV1.UNKNOWN)) + + AvroAssertions.assertThat(EnumV2.B) + .isEncodedAs(GenericData.EnumSymbol(Avro.schema(), "B")) + .isDecodedAs(EnumV1.UNKNOWN) + } + + "Decoding enum with an unknown symbol fails without @AvroEnumDefault, also ignoring default symbol in writer schema" { + val schema = SchemaBuilder.enumeration("Enum").defaultSymbol("Z").symbols("X", "Z") + + val bytes = encodeToBytesUsingApacheLib(schema, GenericData.EnumSymbol(schema, "X")) + shouldThrow { + Avro.decodeFromByteArray(schema, EnumV1WithoutDefault.serializer(), bytes) + } + } + + "Encoding enum with an unknown symbol fails even with default in writer schema" { + val schema = SchemaBuilder.enumeration("Enum").defaultSymbol("Z").symbols("X", "Z") + + shouldThrow { + Avro.encodeToByteArray(schema, EnumV1WithoutDefault.A) + } + } + + "support alias on enum" { + val writerSchema = + SchemaBuilder.record("EnumWrapperRecord").fields() + .name("value") + .type( + SchemaBuilder.enumeration("UnknownEnum").aliases("com.github.avrokotlin.avro4k.SomeEnum").symbols("A", "B", "C") + ) + .noDefault() + .endRecord() + AvroAssertions.assertThat(EnumWrapperRecord(SomeEnum.A)) + .isEncodedAs(record(GenericData.EnumSymbol(writerSchema.fields[0].schema(), "A")), writerSchema = writerSchema) + } + + "support alias on enum inside an union" { + val writerSchema = + SchemaBuilder.record("EnumWrapperRecord").fields() + .name("value") + .type( + Schema.createUnion( + SchemaBuilder.enumeration("OtherEnum").symbols("OTHER"), + SchemaBuilder.record("UnknownRecord").aliases("RecordA") + .fields().name("field").type().stringType().noDefault() + .endRecord(), + SchemaBuilder.enumeration("UnknownEnum").aliases("com.github.avrokotlin.avro4k.SomeEnum").symbols("A", "B", "C") + ) + ) + .noDefault() + .endRecord() + AvroAssertions.assertThat(EnumWrapperRecord(SomeEnum.A)) + .isEncodedAs(record(GenericData.EnumSymbol(writerSchema.fields[0].schema().types[2], "A")), writerSchema = writerSchema) + } +}) { + @Serializable + @SerialName("EnumWrapper") + private data class EnumV1WrapperRecord( + val value: EnumV1, + ) + + @Serializable + @SerialName("EnumWrapper") + private data class EnumV2WrapperRecord( + val value: EnumV2, + ) + + @Serializable + @SerialName("EnumWrapperRecord") + private data class EnumWrapperRecord( + val value: SomeEnum, + ) + + @Serializable + @SerialName("Enum") + private enum class EnumV1 { + @AvroEnumDefault + UNKNOWN, + A, + } + + @Serializable + private enum class BadEnumWithManyDefaults { + @AvroEnumDefault + DEF1, + + @AvroEnumDefault + DEF2, + } + + @Serializable + @SerialName("Enum") + private enum class EnumV1WithoutDefault { + UNKNOWN, + A, + } + + @Serializable + @SerialName("Enum") + private enum class EnumV2 { + @AvroEnumDefault + UNKNOWN, + A, + B, + } + + @Serializable + @AvroAlias("TheCream") + @AvroDoc("documentation") + private enum class Cream { + Bruce, + Baker, + Clapton, + } +} \ No newline at end of file diff --git a/src/test/kotlin/com/github/avrokotlin/avro4k/schema/EnumSchemaTest.kt b/src/test/kotlin/com/github/avrokotlin/avro4k/schema/EnumSchemaTest.kt deleted file mode 100644 index 122943dc..00000000 --- a/src/test/kotlin/com/github/avrokotlin/avro4k/schema/EnumSchemaTest.kt +++ /dev/null @@ -1,56 +0,0 @@ -package com.github.avrokotlin.avro4k.schema - -import com.github.avrokotlin.avro4k.Avro -import com.github.avrokotlin.avro4k.AvroAlias -import com.github.avrokotlin.avro4k.AvroAssertions -import 
com.github.avrokotlin.avro4k.AvroDoc -import com.github.avrokotlin.avro4k.AvroEnumDefault -import com.github.avrokotlin.avro4k.RecordWithGenericField -import com.github.avrokotlin.avro4k.internal.nullable -import com.github.avrokotlin.avro4k.schema -import io.kotest.assertions.throwables.shouldThrow -import io.kotest.core.spec.style.StringSpec -import kotlinx.serialization.Serializable -import kotlin.io.path.Path - -internal class EnumSchemaTest : StringSpec({ - "should generate schema with alias, enum default and doc" { - AvroAssertions.assertThat() - .generatesSchema(Path("/enum_with_default.json")) - AvroAssertions.assertThat>() - .generatesSchema(Path("/enum_with_default_record.json")) - } - "should generate nullable schema" { - AvroAssertions.assertThat() - .generatesSchema(Path("/enum_with_default.json")) { it.nullable } - } - "fail with unknown values" { - shouldThrow { - Avro.schema() - } - shouldThrow { - Avro.schema>() - } - } -}) { - @Serializable - @AvroAlias("MySuit") - @AvroDoc("documentation") - private enum class Suit { - SPADES, - HEARTS, - - @AvroEnumDefault - DIAMONDS, - CLUBS, - } - - @Serializable - private enum class InvalidEnumDefault { - @AvroEnumDefault - VEGGIE, - - @AvroEnumDefault - MEAT, - } -} \ No newline at end of file From d840390049ada620e7c268c621f0d03a2c91d808 Mon Sep 17 00:00:00 2001 From: Chuckame Date: Wed, 9 Oct 2024 22:52:47 +0200 Subject: [PATCH 07/13] fix: do not coerce between int/long and decimals, also improve enum and scalar tests --- build.gradle.kts | 1 + .../avrokotlin/avro4k/internal/NumberUtils.kt | 34 +-- .../direct/AbstractAvroDirectDecoder.kt | 33 ++- .../internal/encoder/AbstractAvroEncoder.kt | 86 +++--- .../avro4k/internal/schema/MapVisitor.kt | 14 +- .../avrokotlin/avro4k/AvroAssertions.kt | 220 ++++++++++++---- .../avrokotlin/avro4k/dataClassesForTests.kt | 31 ++- .../avrokotlin/avro4k/encoding/EnumTest.kt | 55 ++-- .../avro4k/encoding/PrimitiveEncodingTest.kt | 249 ++++++------------ 9 files changed, 397 insertions(+), 326 deletions(-) diff --git a/build.gradle.kts b/build.gradle.kts index c05db046..7560f8ab 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -30,6 +30,7 @@ dependencies { testImplementation(libs.kotest.core) testImplementation(libs.kotest.json) testImplementation(libs.kotest.property) + testImplementation(kotlin("reflect")) } kotlin { diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/internal/NumberUtils.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/internal/NumberUtils.kt index 480385eb..f1e67803 100644 --- a/src/main/kotlin/com/github/avrokotlin/avro4k/internal/NumberUtils.kt +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/internal/NumberUtils.kt @@ -3,13 +3,6 @@ package com.github.avrokotlin.avro4k.internal import kotlinx.serialization.SerializationException import java.math.BigDecimal -internal fun BigDecimal.toLongExact(): Long { - if (this.toLong().toBigDecimal() != this) { - throw SerializationException("Value $this is not a valid Long") - } - return this.toLong() -} - internal fun Int.toByteExact(): Byte { if (this.toByte().toInt() != this) { throw SerializationException("Value $this is not a valid Byte") @@ -54,35 +47,44 @@ internal fun BigDecimal.toShortExact(): Short { internal fun Long.toIntExact(): Int { if (this.toInt().toLong() != this) { - throw SerializationException("Value $this is not a valid Int") + throw invalidType() } return this.toInt() } internal fun BigDecimal.toIntExact(): Int { if (this.toInt().toBigDecimal() != this) { - throw SerializationException("Value $this is 
not a valid Int") + throw invalidType() } return this.toInt() } -internal fun BigDecimal.toFloatExact(): Float { - if (this.toFloat().toBigDecimal() != this) { - throw SerializationException("Value $this is not a valid Float") +internal fun BigDecimal.toLongExact(): Long { + if (this.toLong().toBigDecimal() != this) { + throw invalidType() } - return this.toFloat() + return this.toLong() } internal fun Double.toFloatExact(): Float { if (this.toFloat().toDouble() != this) { - throw SerializationException("Value $this is not a valid Float") + throw invalidType() } return this.toFloat() } internal fun BigDecimal.toDoubleExact(): Double { if (this.toDouble().toBigDecimal() != this) { - throw SerializationException("Value $this is not a valid Double") + throw invalidType() } return this.toDouble() -} \ No newline at end of file +} + +internal fun BigDecimal.toFloatExact(): Float { + if (this.toFloat().toBigDecimal() != this) { + throw invalidType() + } + return this.toFloat() +} + +private inline fun Any.invalidType() = SerializationException("Value $this is not a valid ${T::class.simpleName}") \ No newline at end of file diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/internal/decoder/direct/AbstractAvroDirectDecoder.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/internal/decoder/direct/AbstractAvroDirectDecoder.kt index 3b055ee0..76421907 100644 --- a/src/main/kotlin/com/github/avrokotlin/avro4k/internal/decoder/direct/AbstractAvroDirectDecoder.kt +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/internal/decoder/direct/AbstractAvroDirectDecoder.kt @@ -6,6 +6,7 @@ import com.github.avrokotlin.avro4k.internal.SerializerLocatorMiddleware import com.github.avrokotlin.avro4k.internal.decoder.AbstractPolymorphicDecoder import com.github.avrokotlin.avro4k.internal.isFullNameOrAliasMatch import com.github.avrokotlin.avro4k.internal.toByteExact +import com.github.avrokotlin.avro4k.internal.toFloatExact import com.github.avrokotlin.avro4k.internal.toIntExact import com.github.avrokotlin.avro4k.internal.toShortExact import com.github.avrokotlin.avro4k.unsupportedWriterTypeError @@ -137,8 +138,8 @@ internal abstract class AbstractAvroDirectDecoder( decodeAndResolveUnion() return when (currentWriterSchema.type) { - Schema.Type.INT -> binaryDecoder.readInt().toLong() Schema.Type.LONG -> binaryDecoder.readLong() + Schema.Type.INT -> binaryDecoder.readInt().toLong() Schema.Type.STRING -> binaryDecoder.readString().toLong() else -> throw unsupportedWriterTypeError(Schema.Type.LONG, Schema.Type.INT, Schema.Type.STRING) } @@ -148,11 +149,10 @@ internal abstract class AbstractAvroDirectDecoder( decodeAndResolveUnion() return when (currentWriterSchema.type) { - Schema.Type.INT -> binaryDecoder.readInt().toFloat() - Schema.Type.LONG -> binaryDecoder.readLong().toFloat() Schema.Type.FLOAT -> binaryDecoder.readFloat() + Schema.Type.DOUBLE -> binaryDecoder.readDouble().toFloatExact() Schema.Type.STRING -> binaryDecoder.readString().toFloat() - else -> throw unsupportedWriterTypeError(Schema.Type.FLOAT, Schema.Type.INT, Schema.Type.LONG, Schema.Type.STRING) + else -> throw unsupportedWriterTypeError(Schema.Type.FLOAT, Schema.Type.DOUBLE, Schema.Type.STRING) } } @@ -160,12 +160,10 @@ internal abstract class AbstractAvroDirectDecoder( decodeAndResolveUnion() return when (currentWriterSchema.type) { - Schema.Type.INT -> binaryDecoder.readInt().toDouble() - Schema.Type.LONG -> binaryDecoder.readLong().toDouble() Schema.Type.FLOAT -> binaryDecoder.readFloat().toDouble() Schema.Type.DOUBLE -> 
binaryDecoder.readDouble() Schema.Type.STRING -> binaryDecoder.readString().toDouble() - else -> throw unsupportedWriterTypeError(Schema.Type.DOUBLE, Schema.Type.INT, Schema.Type.LONG, Schema.Type.FLOAT, Schema.Type.STRING) + else -> throw unsupportedWriterTypeError(Schema.Type.DOUBLE, Schema.Type.FLOAT, Schema.Type.STRING) } } @@ -183,10 +181,24 @@ internal abstract class AbstractAvroDirectDecoder( decodeAndResolveUnion() return when (currentWriterSchema.type) { - Schema.Type.STRING -> binaryDecoder.readString(null).toString() + Schema.Type.STRING -> binaryDecoder.readString() Schema.Type.BYTES -> binaryDecoder.readBytes().decodeToString() Schema.Type.FIXED -> binaryDecoder.readFixedBytes(currentWriterSchema.fixedSize).decodeToString() - else -> throw unsupportedWriterTypeError(Schema.Type.STRING, Schema.Type.BYTES, Schema.Type.FIXED) + Schema.Type.BOOLEAN -> binaryDecoder.readBoolean().toString() + Schema.Type.INT -> binaryDecoder.readInt().toString() + Schema.Type.LONG -> binaryDecoder.readLong().toString() + Schema.Type.FLOAT -> binaryDecoder.readFloat().toString() + Schema.Type.DOUBLE -> binaryDecoder.readDouble().toString() + else -> throw unsupportedWriterTypeError( + Schema.Type.STRING, + Schema.Type.BYTES, + Schema.Type.FIXED, + Schema.Type.BOOLEAN, + Schema.Type.INT, + Schema.Type.LONG, + Schema.Type.FLOAT, + Schema.Type.DOUBLE + ) } } @@ -239,7 +251,8 @@ internal abstract class AbstractAvroDirectDecoder( return when (currentWriterSchema.type) { Schema.Type.BYTES -> GenericData.Fixed(currentWriterSchema, binaryDecoder.readBytes()) Schema.Type.FIXED -> GenericData.Fixed(currentWriterSchema, binaryDecoder.readFixedBytes(currentWriterSchema.fixedSize)) - else -> throw unsupportedWriterTypeError(Schema.Type.BYTES, Schema.Type.FIXED) + Schema.Type.STRING -> GenericData.Fixed(currentWriterSchema, binaryDecoder.readString(null).bytes) + else -> throw unsupportedWriterTypeError(Schema.Type.FIXED, Schema.Type.BYTES, Schema.Type.STRING) } } } diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/internal/encoder/AbstractAvroEncoder.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/internal/encoder/AbstractAvroEncoder.kt index 75bb1b8e..9757a97b 100644 --- a/src/main/kotlin/com/github/avrokotlin/avro4k/internal/encoder/AbstractAvroEncoder.kt +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/internal/encoder/AbstractAvroEncoder.kt @@ -8,6 +8,8 @@ import com.github.avrokotlin.avro4k.internal.SerializerLocatorMiddleware import com.github.avrokotlin.avro4k.internal.aliases import com.github.avrokotlin.avro4k.internal.isFullNameOrAliasMatch import com.github.avrokotlin.avro4k.internal.nonNullSerialName +import com.github.avrokotlin.avro4k.internal.toFloatExact +import com.github.avrokotlin.avro4k.internal.toIntExact import com.github.avrokotlin.avro4k.namedSchemaNotFoundInUnionError import com.github.avrokotlin.avro4k.trySelectEnumSchemaForSymbol import com.github.avrokotlin.avro4k.trySelectFixedSchemaForSize @@ -230,10 +232,8 @@ internal abstract class AbstractAvroEncoder : AbstractEncoder(), AvroEncoder { override fun encodeBoolean(value: Boolean) { if (currentWriterSchema.isUnion && !trySelectSingleNonNullTypeFromUnion()) { - trySelectTypeFromUnion( - Schema.Type.BOOLEAN, - Schema.Type.STRING - ) || throw typeNotFoundInUnionError(Schema.Type.BOOLEAN, Schema.Type.STRING) + trySelectTypeFromUnion(Schema.Type.BOOLEAN, Schema.Type.STRING) || + throw typeNotFoundInUnionError(Schema.Type.BOOLEAN, Schema.Type.STRING) } when (currentWriterSchema.type) { Schema.Type.BOOLEAN -> 
encodeBooleanUnchecked(value) @@ -252,58 +252,27 @@ internal abstract class AbstractAvroEncoder : AbstractEncoder(), AvroEncoder { override fun encodeInt(value: Int) { if (currentWriterSchema.isUnion && !trySelectSingleNonNullTypeFromUnion()) { - trySelectTypeFromUnion( - Schema.Type.INT, - Schema.Type.LONG, - Schema.Type.FLOAT, - Schema.Type.DOUBLE, - Schema.Type.STRING - ) || - throw typeNotFoundInUnionError( - Schema.Type.INT, - Schema.Type.LONG, - Schema.Type.FLOAT, - Schema.Type.DOUBLE, - Schema.Type.STRING - ) + trySelectTypeFromUnion(Schema.Type.INT, Schema.Type.LONG, Schema.Type.STRING) || + throw typeNotFoundInUnionError(Schema.Type.INT, Schema.Type.LONG, Schema.Type.STRING) } when (currentWriterSchema.type) { Schema.Type.INT -> encodeIntUnchecked(value) Schema.Type.LONG -> encodeLongUnchecked(value.toLong()) - Schema.Type.FLOAT -> encodeFloatUnchecked(value.toFloat()) - Schema.Type.DOUBLE -> encodeDoubleUnchecked(value.toDouble()) Schema.Type.STRING -> encodeStringUnchecked(Utf8(value.toString())) - else -> throw unsupportedWriterTypeError( - Schema.Type.INT, - Schema.Type.LONG, - Schema.Type.FLOAT, - Schema.Type.DOUBLE, - Schema.Type.STRING - ) + else -> throw unsupportedWriterTypeError(Schema.Type.INT, Schema.Type.LONG, Schema.Type.STRING) } } override fun encodeLong(value: Long) { if (currentWriterSchema.isUnion && !trySelectSingleNonNullTypeFromUnion()) { - trySelectTypeFromUnion(Schema.Type.LONG, Schema.Type.FLOAT, Schema.Type.DOUBLE, Schema.Type.STRING) || - throw typeNotFoundInUnionError( - Schema.Type.LONG, - Schema.Type.FLOAT, - Schema.Type.DOUBLE, - Schema.Type.STRING - ) + trySelectTypeFromUnion(Schema.Type.LONG, Schema.Type.INT, Schema.Type.STRING) || + throw typeNotFoundInUnionError(Schema.Type.LONG, Schema.Type.INT, Schema.Type.STRING) } when (currentWriterSchema.type) { + Schema.Type.INT -> encodeIntUnchecked(value.toIntExact()) Schema.Type.LONG -> encodeLongUnchecked(value) - Schema.Type.FLOAT -> encodeFloatUnchecked(value.toFloat()) - Schema.Type.DOUBLE -> encodeDoubleUnchecked(value.toDouble()) Schema.Type.STRING -> encodeStringUnchecked(Utf8(value.toString())) - else -> throw unsupportedWriterTypeError( - Schema.Type.LONG, - Schema.Type.FLOAT, - Schema.Type.DOUBLE, - Schema.Type.STRING - ) + else -> throw unsupportedWriterTypeError(Schema.Type.INT, Schema.Type.LONG, Schema.Type.STRING) } } @@ -322,13 +291,14 @@ internal abstract class AbstractAvroEncoder : AbstractEncoder(), AvroEncoder { override fun encodeDouble(value: Double) { if (currentWriterSchema.isUnion && !trySelectSingleNonNullTypeFromUnion()) { - trySelectTypeFromUnion(Schema.Type.DOUBLE, Schema.Type.STRING) || - throw typeNotFoundInUnionError(Schema.Type.DOUBLE, Schema.Type.STRING) + trySelectTypeFromUnion(Schema.Type.DOUBLE, Schema.Type.FLOAT, Schema.Type.STRING) || + throw typeNotFoundInUnionError(Schema.Type.DOUBLE, Schema.Type.FLOAT, Schema.Type.STRING) } when (currentWriterSchema.type) { + Schema.Type.FLOAT -> encodeFloatUnchecked(value.toFloatExact()) Schema.Type.DOUBLE -> encodeDoubleUnchecked(value) Schema.Type.STRING -> encodeStringUnchecked(Utf8(value.toString())) - else -> throw unsupportedWriterTypeError(Schema.Type.DOUBLE, Schema.Type.STRING) + else -> throw unsupportedWriterTypeError(Schema.Type.DOUBLE, Schema.Type.FLOAT, Schema.Type.STRING) } } @@ -349,21 +319,41 @@ internal abstract class AbstractAvroEncoder : AbstractEncoder(), AvroEncoder { trySelectTypeFromUnion(Schema.Type.STRING, Schema.Type.BYTES) || trySelectFixedSchemaForSize(value.length) || 
trySelectEnumSchemaForSymbol(value) || + trySelectTypeFromUnion( + Schema.Type.BOOLEAN, + Schema.Type.INT, + Schema.Type.LONG, + Schema.Type.FLOAT, + Schema.Type.DOUBLE) || throw typeNotFoundInUnionError( + Schema.Type.BOOLEAN, + Schema.Type.INT, + Schema.Type.LONG, + Schema.Type.FLOAT, + Schema.Type.DOUBLE, Schema.Type.STRING, Schema.Type.BYTES, Schema.Type.FIXED, - Schema.Type.ENUM - ) + Schema.Type.ENUM) } when (currentWriterSchema.type) { + Schema.Type.BOOLEAN -> encodeBooleanUnchecked(value.toBooleanStrict()) + Schema.Type.INT -> encodeIntUnchecked(value.toInt()) + Schema.Type.LONG -> encodeLongUnchecked(value.toLong()) + Schema.Type.FLOAT -> encodeFloatUnchecked(value.toFloat()) + Schema.Type.DOUBLE -> encodeDoubleUnchecked(value.toDouble()) Schema.Type.STRING -> encodeStringUnchecked(Utf8(value)) Schema.Type.BYTES -> encodeBytesUnchecked(value.encodeToByteArray()) Schema.Type.FIXED -> encodeFixedUnchecked(ensureFixedSize(value.encodeToByteArray())) Schema.Type.ENUM -> encodeEnumUnchecked(value) else -> throw unsupportedWriterTypeError( - Schema.Type.BYTES, + Schema.Type.BOOLEAN, + Schema.Type.INT, + Schema.Type.LONG, + Schema.Type.FLOAT, + Schema.Type.DOUBLE, Schema.Type.STRING, + Schema.Type.BYTES, Schema.Type.FIXED, Schema.Type.ENUM ) diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/internal/schema/MapVisitor.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/internal/schema/MapVisitor.kt index 2a6681f4..ea8e32ec 100644 --- a/src/main/kotlin/com/github/avrokotlin/avro4k/internal/schema/MapVisitor.kt +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/internal/schema/MapVisitor.kt @@ -23,8 +23,8 @@ internal class MapVisitor( if (it.isNullable()) { throw AvroSchemaGenerationException("Map key cannot be nullable. Actual generated map key schema: $it") } - if (!it.isStringable()) { - throw AvroSchemaGenerationException("Map key must be string-able (boolean, number, enum, or string). Actual generated map key schema: $it") + if (!it.isNonNullScalarType()) { + throw AvroSchemaGenerationException("Map key must be a non-null scalar type (e.g. not a record, map or array). Actual generated map key schema: $it") } } @@ -40,7 +40,7 @@ internal class MapVisitor( } } -private fun Schema.isStringable(): Boolean = +internal fun Schema.isNonNullScalarType(): Boolean = when (type) { Schema.Type.BOOLEAN, Schema.Type.INT, @@ -49,18 +49,16 @@ private fun Schema.isStringable(): Boolean = Schema.Type.DOUBLE, Schema.Type.STRING, Schema.Type.ENUM, + Schema.Type.BYTES, + Schema.Type.FIXED, -> true Schema.Type.NULL, - // bytes could be stringified, but it's not a good idea as it can produce unreadable strings. - Schema.Type.BYTES, - // same, just bytes. Btw, if the user wants to stringify it, he can use @Contextual or custom @Serializable serializer. 
- Schema.Type.FIXED, Schema.Type.ARRAY, Schema.Type.MAP, Schema.Type.RECORD, null, -> false - Schema.Type.UNION -> types.all { it.isStringable() } + Schema.Type.UNION -> types.all { it.isNonNullScalarType() } } \ No newline at end of file diff --git a/src/test/kotlin/com/github/avrokotlin/avro4k/AvroAssertions.kt b/src/test/kotlin/com/github/avrokotlin/avro4k/AvroAssertions.kt index 8612b41b..1264cc35 100644 --- a/src/test/kotlin/com/github/avrokotlin/avro4k/AvroAssertions.kt +++ b/src/test/kotlin/com/github/avrokotlin/avro4k/AvroAssertions.kt @@ -13,7 +13,6 @@ import kotlinx.serialization.serializer import org.apache.avro.Conversions import org.apache.avro.Schema import org.apache.avro.SchemaBuilder -import org.apache.avro.generic.GenericContainer import org.apache.avro.generic.GenericData import org.apache.avro.generic.GenericDatumReader import org.apache.avro.generic.GenericDatumWriter @@ -218,45 +217,55 @@ fun encodeToBytesUsingApacheLib( } } -internal inline fun StringSpecRootScope.basicScalarEncodeDecodeTests(value: T, schema: Schema, apacheCompatibleValue: Any? = value) { - "support scalar type ${schema.type} serialization" { - Avro.schema() shouldBe schema - testEncodeDecode(schema, value, apacheCompatibleValue = apacheCompatibleValue) +internal inline fun StringSpecRootScope.basicScalarEncodeDecodeTests( + value: T, + expectedSchema: Schema, + apacheCompatibleValue: Any? = value, +) { + "support runtime ${expectedSchema.type} type ${value::class.qualifiedName} serialization" { + Avro.schema() shouldBe expectedSchema + testEncodeDecode(expectedSchema, value, apacheCompatibleValue = apacheCompatibleValue) - Avro.schema>() shouldBe schema - testEncodeDecode(schema, ValueClassWithGenericField(value), apacheCompatibleValue = apacheCompatibleValue) + Avro.schema>() shouldBe expectedSchema + testEncodeDecode(expectedSchema, ValueClassWithGenericField(value), apacheCompatibleValue = apacheCompatibleValue) } - "support scalar type ${schema.type} serialization as nullable" { - Avro.schema() shouldBe schema.nullable - testEncodeDecode(schema.nullable, value, apacheCompatibleValue = apacheCompatibleValue) - testEncodeDecode(schema.nullable, null) + "support runtime ${expectedSchema.type} type ${value::class.qualifiedName} serialization as nullable" { + Avro.schema() shouldBe expectedSchema.nullable + testEncodeDecode(expectedSchema.nullable, value, apacheCompatibleValue = apacheCompatibleValue) + testEncodeDecode(expectedSchema.nullable, null) - Avro.schema>() shouldBe schema.nullable - testEncodeDecode(schema.nullable, ValueClassWithGenericField(value), apacheCompatibleValue = apacheCompatibleValue) - testEncodeDecode(schema.nullable, ValueClassWithGenericField(null), apacheCompatibleValue = null) + Avro.schema>() shouldBe expectedSchema.nullable + testEncodeDecode(expectedSchema.nullable, ValueClassWithGenericField(value), apacheCompatibleValue = apacheCompatibleValue) + testEncodeDecode(expectedSchema.nullable, ValueClassWithGenericField(null), apacheCompatibleValue = null) - Avro.schema?>() shouldBe schema.nullable - testEncodeDecode?>(schema.nullable, null) + Avro.schema?>() shouldBe expectedSchema.nullable + testEncodeDecode?>(expectedSchema.nullable, null) - Avro.schema?>() shouldBe schema.nullable - testEncodeDecode?>(schema.nullable, null) + Avro.schema?>() shouldBe expectedSchema.nullable + testEncodeDecode?>(expectedSchema.nullable, null) } - "scalar type ${schema.type} in record" { + "support runtime ${expectedSchema.type} type ${value::class.qualifiedName} in record" { val record 
= SchemaBuilder.record("RecordWithGenericField").fields() - .name("field").type(schema).noDefault() + .name("field").type(expectedSchema).noDefault() .endRecord() Avro.schema>() shouldBe record Avro.schema>>() shouldBe record - testEncodeDecode(record, - RecordWithGenericField(value), apacheCompatibleValue = GenericData.Record(record).also { it.put(0, apacheCompatibleValue) }) - testEncodeDecode(record, - RecordWithGenericField(ValueClassWithGenericField(value)), apacheCompatibleValue = GenericData.Record(record).also { it.put(0, apacheCompatibleValue) }) + testEncodeDecode( + record, + RecordWithGenericField(value), + apacheCompatibleValue = GenericData.Record(record).also { it.put(0, apacheCompatibleValue) } + ) + testEncodeDecode( + record, + RecordWithGenericField(ValueClassWithGenericField(value)), + apacheCompatibleValue = GenericData.Record(record).also { it.put(0, apacheCompatibleValue) } + ) } - "scalar type ${schema.type} in record as nullable field" { + "support runtime ${expectedSchema.type} type ${value::class.qualifiedName} in record as nullable field" { val expectedRecordSchemaNullable = SchemaBuilder.record("RecordWithGenericField").fields() - .name("field").type(schema.nullable).withDefault(null) + .name("field").type(expectedSchema.nullable).withDefault(null) .endRecord() Avro.schema>() shouldBe expectedRecordSchemaNullable Avro.schema>>() shouldBe expectedRecordSchemaNullable @@ -265,22 +274,31 @@ internal inline fun StringSpecRootScope.basicScalarEncodeDecodeTests val recordNullable = SchemaBuilder.record("RecordWithGenericField").fields() - .name("field").type(schema.nullable).noDefault() + .name("field").type(expectedSchema.nullable).noDefault() .endRecord() - testEncodeDecode(recordNullable, - RecordWithGenericField(value), apacheCompatibleValue = GenericData.Record(recordNullable).also { it.put(0, apacheCompatibleValue) }) - testEncodeDecode(recordNullable, - RecordWithGenericField(null), apacheCompatibleValue = GenericData.Record(recordNullable).also { it.put(0, null) }) + testEncodeDecode( + recordNullable, + RecordWithGenericField(value), + apacheCompatibleValue = GenericData.Record(recordNullable).also { it.put(0, apacheCompatibleValue) } + ) + testEncodeDecode( + recordNullable, + RecordWithGenericField(null), + apacheCompatibleValue = GenericData.Record(recordNullable).also { it.put(0, null) } + ) testEncodeDecode( recordNullable, RecordWithGenericField(ValueClassWithGenericField(value)), apacheCompatibleValue = GenericData.Record(recordNullable).also { it.put(0, apacheCompatibleValue) } ) - testEncodeDecode(recordNullable, - RecordWithGenericField(ValueClassWithGenericField(null)), apacheCompatibleValue = GenericData.Record(recordNullable).also { it.put(0, null) }) + testEncodeDecode( + recordNullable, + RecordWithGenericField(ValueClassWithGenericField(null)), + apacheCompatibleValue = GenericData.Record(recordNullable).also { it.put(0, null) } + ) } - "scalar type ${schema.type} in map" { - val map = SchemaBuilder.map().values(schema) + "support runtime ${expectedSchema.type} type ${value::class.qualifiedName} in map" { + val map = SchemaBuilder.map().values(expectedSchema) Avro.schema>() shouldBe map Avro.schema>>() shouldBe map Avro.schema>>() shouldBe map @@ -290,7 +308,7 @@ internal inline fun StringSpecRootScope.basicScalarEncodeDecodeTests testEncodeDecode(map, mapOf("key" to value), apacheCompatibleValue = mapOf("key" to apacheCompatibleValue)) testEncodeDecode(map, mapOf("key" to ValueClassWithGenericField(value)), apacheCompatibleValue = 
mapOf("key" to apacheCompatibleValue)) - val mapNullable = SchemaBuilder.map().values(schema.nullable) + val mapNullable = SchemaBuilder.map().values(expectedSchema.nullable) Avro.schema>() shouldBe mapNullable Avro.schema>>() shouldBe mapNullable Avro.schema?>>() shouldBe mapNullable @@ -306,8 +324,8 @@ internal inline fun StringSpecRootScope.basicScalarEncodeDecodeTests testEncodeDecode(mapNullable, mapOf("key" to ValueClassWithGenericField(value)), apacheCompatibleValue = mapOf("key" to apacheCompatibleValue)) testEncodeDecode(mapNullable, mapOf("key" to ValueClassWithGenericField(null)), apacheCompatibleValue = mapOf("key" to null)) } - "scalar type ${schema.type} in array" { - val array = SchemaBuilder.array().items(schema) + "support runtime ${expectedSchema.type} type ${value::class.qualifiedName} in array" { + val array = SchemaBuilder.array().items(expectedSchema) Avro.schema>() shouldBe array Avro.schema>>() shouldBe array Avro.schema>() shouldBe array @@ -317,7 +335,7 @@ internal inline fun StringSpecRootScope.basicScalarEncodeDecodeTests testEncodeDecode(array, listOf(value), apacheCompatibleValue = listOf(apacheCompatibleValue)) testEncodeDecode(array, listOf(ValueClassWithGenericField(value)), apacheCompatibleValue = listOf(apacheCompatibleValue)) - val arrayNullable = SchemaBuilder.array().items(schema.nullable) + val arrayNullable = SchemaBuilder.array().items(expectedSchema.nullable) Avro.schema>() shouldBe arrayNullable Avro.schema>>() shouldBe arrayNullable Avro.schema?>>() shouldBe arrayNullable @@ -335,16 +353,85 @@ internal inline fun StringSpecRootScope.basicScalarEncodeDecodeTests } } -internal inline fun StringSpecRootScope.testSerializationTypeCompatibility(logicalValue: T, encodedAsValue: R) { - val schema = when { - encodedAsValue is GenericContainer -> encodedAsValue.schema - else -> Avro.schema() +internal inline fun StringSpecRootScope.testSerializationTypeCompatibility( + logicalValue: T, + apacheCompatibleValue: R, + writerSchema: Schema, +) { + val originalSchema = Avro.schema() + "support coercion from ${originalSchema.type} runtime type ${logicalValue::class.qualifiedName} to type ${writerSchema.type}" { + testEncodeDecode(writerSchema, logicalValue, apacheCompatibleValue = apacheCompatibleValue) + testEncodeDecode(writerSchema, ValueClassWithGenericField(logicalValue), apacheCompatibleValue = apacheCompatibleValue) + } + "support coercion from ${originalSchema.type} runtime type ${logicalValue::class.qualifiedName} to nullable type ${writerSchema.type}" { + testEncodeDecode(writerSchema.nullable, logicalValue, apacheCompatibleValue = apacheCompatibleValue) + testEncodeDecode(writerSchema.nullable, null) + + testEncodeDecode(writerSchema.nullable, ValueClassWithGenericField(logicalValue), apacheCompatibleValue = apacheCompatibleValue) + testEncodeDecode(writerSchema.nullable, ValueClassWithGenericField(null), apacheCompatibleValue = null) + + testEncodeDecode?>(writerSchema.nullable, null) + testEncodeDecode?>(writerSchema.nullable, null) + } + "support coercion from ${originalSchema.type} runtime type ${logicalValue::class.qualifiedName} inside a record's field with type ${writerSchema.type}" { + val record = + SchemaBuilder.record("RecordWithGenericField").fields() + .name("field").type(writerSchema).noDefault() + .endRecord() + testEncodeDecode( + record, + RecordWithGenericField(logicalValue), + apacheCompatibleValue = GenericData.Record(record).also { it.put(0, apacheCompatibleValue) } + ) + testEncodeDecode( + record, + 
RecordWithGenericField(ValueClassWithGenericField(logicalValue)), + apacheCompatibleValue = GenericData.Record(record).also { it.put(0, apacheCompatibleValue) } + ) + } + "support coercion from ${originalSchema.type} runtime type ${logicalValue::class.qualifiedName} inside a record's field with nullable type ${writerSchema.type}" { + val recordNullable = + SchemaBuilder.record("RecordWithGenericField").fields() + .name("field").type(writerSchema.nullable).noDefault() + .endRecord() + testEncodeDecode( + recordNullable, + RecordWithGenericField(logicalValue), + apacheCompatibleValue = GenericData.Record(recordNullable).also { it.put(0, apacheCompatibleValue) } + ) + testEncodeDecode( + recordNullable, + RecordWithGenericField(null), + apacheCompatibleValue = GenericData.Record(recordNullable).also { it.put(0, null) } + ) + testEncodeDecode( + recordNullable, + RecordWithGenericField(ValueClassWithGenericField(logicalValue)), + apacheCompatibleValue = GenericData.Record(recordNullable).also { it.put(0, apacheCompatibleValue) } + ) + testEncodeDecode( + recordNullable, + RecordWithGenericField(ValueClassWithGenericField(null)), + apacheCompatibleValue = GenericData.Record(recordNullable).also { it.put(0, null) } + ) } - "Support ${logicalValue::class.simpleName} serialization as ${schema.type}" { - testEncodeDecode(schema, logicalValue, apacheCompatibleValue = encodedAsValue) + "support coercion from ${originalSchema.type} runtime type ${logicalValue::class.qualifiedName} inside a map of ${writerSchema.type} values" { + val map = SchemaBuilder.map().values(writerSchema) + testEncodeDecode(map, mapOf("key" to logicalValue), apacheCompatibleValue = mapOf("key" to apacheCompatibleValue)) + testEncodeDecode(map, mapOf("key" to ValueClassWithGenericField(logicalValue)), apacheCompatibleValue = mapOf("key" to apacheCompatibleValue)) + + val mapNullable = SchemaBuilder.map().values(writerSchema.nullable) + testEncodeDecode(mapNullable, mapOf("key" to ValueClassWithGenericField(logicalValue)), apacheCompatibleValue = mapOf("key" to apacheCompatibleValue)) + testEncodeDecode(mapNullable, mapOf("key" to ValueClassWithGenericField(null)), apacheCompatibleValue = mapOf("key" to null)) } - "Support ${logicalValue::class.simpleName} serialization as nullable ${schema.type}" { - testEncodeDecode(schema.nullable, logicalValue, apacheCompatibleValue = encodedAsValue) + "support coercion from ${originalSchema.type} runtime type ${logicalValue::class.qualifiedName} inside an array of ${writerSchema.type} items" { + val array = SchemaBuilder.array().items(writerSchema) + testEncodeDecode(array, listOf(logicalValue), apacheCompatibleValue = listOf(apacheCompatibleValue)) + testEncodeDecode(array, listOf(ValueClassWithGenericField(logicalValue)), apacheCompatibleValue = listOf(apacheCompatibleValue)) + + val arrayNullable = SchemaBuilder.array().items(writerSchema.nullable) + testEncodeDecode(arrayNullable, listOf(ValueClassWithGenericField(logicalValue)), apacheCompatibleValue = listOf(apacheCompatibleValue)) + testEncodeDecode(arrayNullable, listOf(ValueClassWithGenericField(null)), apacheCompatibleValue = listOf(null)) } } @@ -357,5 +444,42 @@ inline fun testEncodeDecode( expectedBytes: ByteArray = encodeToBytesUsingApacheLib(schema, apacheCompatibleValue), ) { Avro.encodeToByteArray(schema, serializer, toEncode) shouldBe expectedBytes - Avro.decodeFromByteArray(schema, serializer, expectedBytes) shouldBe decoded + val decodedValue = Avro.decodeFromByteArray(schema, serializer, expectedBytes) as Any? 
+ try { + decodedValue shouldBe decoded + } catch (originalError: Throwable) { + if (!deepEquals(decodedValue, decoded)) { + throw originalError + } + } +} + +/** + * kotest doesn't handle deep equals for value classes, so we need to implement it ourselves. + */ +fun deepEquals( + a: Any?, + b: Any?, +): Boolean { + if (a === b) return true + if (a == b) return true + if (a == null || b == null) return false + if (a is ByteArray && b is ByteArray) return a.contentEquals(b) + if (a is ValueClassWithGenericField<*> && b is ValueClassWithGenericField<*>) return deepEquals(a.value, b.value) + if (a is Collection<*> && b is Collection<*>) { + if (a.size != b.size) return false + return a.zip(b).all { (a, b) -> deepEquals(a, b) } + } + if (a is Map<*, *> && b is Map<*, *>) { + if (a.size != b.size) return false + return a.all { (key, value) -> deepEquals(value, b[key]) } + } + if (a::class.isValue && b::class.isValue) { + return deepEquals(unboxValueClass(a), unboxValueClass(b)) + } + return a == b +} + +private fun unboxValueClass(value: Any): Any? { + return value::class.java.getDeclaredMethod("unbox-impl").apply { isAccessible = true }.invoke(value) } \ No newline at end of file diff --git a/src/test/kotlin/com/github/avrokotlin/avro4k/dataClassesForTests.kt b/src/test/kotlin/com/github/avrokotlin/avro4k/dataClassesForTests.kt index 6a6a2d62..dfb59ad0 100644 --- a/src/test/kotlin/com/github/avrokotlin/avro4k/dataClassesForTests.kt +++ b/src/test/kotlin/com/github/avrokotlin/avro4k/dataClassesForTests.kt @@ -12,11 +12,38 @@ internal enum class SomeEnum { @Serializable @SerialName("RecordWithGenericField") -internal data class RecordWithGenericField(val field: T) +internal data class RecordWithGenericField(val field: T) { + override fun equals(other: Any?): Boolean { + if (this === other) return true + if (javaClass != other?.javaClass) return false + + other as RecordWithGenericField<*> + + return deepEquals(field, other.field) + } + + override fun hashCode(): Int { + return field?.hashCode() ?: 0 + } + + override fun toString(): String { + if (field is ByteArray) { + return "RecordWithGenericField(field=${field.contentToString()})" + } + return "RecordWithGenericField(field=$field)" + } +} @JvmInline @Serializable -internal value class ValueClassWithGenericField(val value: T) +internal value class ValueClassWithGenericField(val value: T) { + override fun toString(): String { + if (value is ByteArray) { + return "ValueClassWithGenericField(value=${value.contentToString()})" + } + return "ValueClassWithGenericField(value=$value)" + } +} @Serializable @JvmInline diff --git a/src/test/kotlin/com/github/avrokotlin/avro4k/encoding/EnumTest.kt b/src/test/kotlin/com/github/avrokotlin/avro4k/encoding/EnumTest.kt index 2280233e..32e3ed69 100644 --- a/src/test/kotlin/com/github/avrokotlin/avro4k/encoding/EnumTest.kt +++ b/src/test/kotlin/com/github/avrokotlin/avro4k/encoding/EnumTest.kt @@ -31,10 +31,7 @@ import org.apache.avro.generic.GenericData internal class EnumTest : StringSpec({ val expectedEnumSchema = SchemaBuilder.enumeration(Cream::class.qualifiedName).aliases("TheCream").doc("documentation").symbols("Bruce", "Baker", "Clapton") basicScalarEncodeDecodeTests(Cream.Bruce, expectedEnumSchema, apacheCompatibleValue = GenericData.EnumSymbol(expectedEnumSchema, "Bruce")) - testSerializationTypeCompatibility(Cream.Baker, "Baker") - - // TODO test alias decoding - // TODO test decoding from union (name resolution) + testSerializationTypeCompatibility(Cream.Baker, "Baker", 
Schema.create(Schema.Type.STRING))

     "Only allow 1 @AvroEnumDefault at max" {
         shouldThrow {
@@ -93,23 +90,30 @@ internal class EnumTest : StringSpec({
             .isEncodedAs(record(GenericData.EnumSymbol(writerSchema.fields[0].schema(), "A")), writerSchema = writerSchema)
     }

-    "support alias on enum inside an union" {
-        val writerSchema =
-            SchemaBuilder.record("EnumWrapperRecord").fields()
-                .name("value")
-                .type(
-                    Schema.createUnion(
-                        SchemaBuilder.enumeration("OtherEnum").symbols("OTHER"),
-                        SchemaBuilder.record("UnknownRecord").aliases("RecordA")
-                            .fields().name("field").type().stringType().noDefault()
-                            .endRecord(),
-                        SchemaBuilder.enumeration("UnknownEnum").aliases("com.github.avrokotlin.avro4k.SomeEnum").symbols("A", "B", "C")
-                    )
-                )
-                .noDefault()
-                .endRecord()
-        AvroAssertions.assertThat(EnumWrapperRecord(SomeEnum.A))
-            .isEncodedAs(record(GenericData.EnumSymbol(writerSchema.fields[0].schema().types[2], "A")), writerSchema = writerSchema)
+    "support decoding enum inside a union with an alias on the reader schema" {
+        val writerSchema = Schema.createUnion(
+            SchemaBuilder.enumeration("OtherEnum").symbols("OTHER"),
+            SchemaBuilder.record("UnknownRecord").aliases("RecordA")
+                .fields().name("field").type().stringType().noDefault()
+                .endRecord(),
+            SchemaBuilder.enumeration("TheEnum").symbols("A", "B", "C")
+        )
+
+        AvroAssertions.assertThat(EnumWithAlias.A)
+            .isEncodedAs(GenericData.EnumSymbol(writerSchema.types[2], "A"), writerSchema = writerSchema)
+    }
+
+    "support decoding enum inside a union with an alias on the writer schema" {
+        val writerSchema = Schema.createUnion(
+            SchemaBuilder.enumeration("OtherEnum").symbols("OTHER"),
+            SchemaBuilder.record("UnknownRecord").aliases("RecordA")
+                .fields().name("field").type().stringType().noDefault()
+                .endRecord(),
+            SchemaBuilder.enumeration("UnknownEnum").aliases("com.github.avrokotlin.avro4k.SomeEnum").symbols("A", "B", "C")
+        )
+
+        AvroAssertions.assertThat(SomeEnum.A)
+            .isEncodedAs(GenericData.EnumSymbol(writerSchema.types[2], "A"), writerSchema = writerSchema)
     }
 }) {
     @Serializable
@@ -171,4 +175,13 @@ internal class EnumTest : StringSpec({
         Baker,
         Clapton,
     }
+
+    @Serializable
+    @AvroAlias("TheEnum")
+    @SerialName("EnumWithAlias")
+    private enum class EnumWithAlias {
+        A,
+        B,
+        C,
+    }
 }
\ No newline at end of file
diff --git a/src/test/kotlin/com/github/avrokotlin/avro4k/encoding/PrimitiveEncodingTest.kt b/src/test/kotlin/com/github/avrokotlin/avro4k/encoding/PrimitiveEncodingTest.kt
index 2b8f7eda..b4c609e5 100644
--- a/src/test/kotlin/com/github/avrokotlin/avro4k/encoding/PrimitiveEncodingTest.kt
+++ b/src/test/kotlin/com/github/avrokotlin/avro4k/encoding/PrimitiveEncodingTest.kt
@@ -1,184 +1,87 @@
 package com.github.avrokotlin.avro4k.encoding

-import com.github.avrokotlin.avro4k.AvroAssertions
-import com.github.avrokotlin.avro4k.SomeEnum
-import com.github.avrokotlin.avro4k.WrappedBoolean
-import com.github.avrokotlin.avro4k.WrappedByte
-import com.github.avrokotlin.avro4k.WrappedChar
-import com.github.avrokotlin.avro4k.WrappedDouble
-import com.github.avrokotlin.avro4k.WrappedFloat
-import com.github.avrokotlin.avro4k.WrappedInt
-import com.github.avrokotlin.avro4k.WrappedLong
-import com.github.avrokotlin.avro4k.WrappedShort
-import com.github.avrokotlin.avro4k.WrappedString
-import com.github.avrokotlin.avro4k.internal.nullable
-import com.github.avrokotlin.avro4k.record
+import com.github.avrokotlin.avro4k.AvroFixed
+import com.github.avrokotlin.avro4k.basicScalarEncodeDecodeTests
+import
com.github.avrokotlin.avro4k.internal.copy
+import com.github.avrokotlin.avro4k.testSerializationTypeCompatibility
 import io.kotest.core.spec.style.StringSpec
-import kotlinx.serialization.InternalSerializationApi
 import kotlinx.serialization.Serializable
-import kotlinx.serialization.serializer
+import org.apache.avro.LogicalType
 import org.apache.avro.Schema
+import org.apache.avro.generic.GenericData
 import java.nio.ByteBuffer

 internal class PrimitiveEncodingTest : StringSpec({
-    "read write out booleans" {
-        AvroAssertions.assertThat(BooleanTest(true))
-            .isEncodedAs(record(true))
-        AvroAssertions.assertThat(BooleanTest(false))
-            .isEncodedAs(record(false))
-        AvroAssertions.assertThat(true)
-            .isEncodedAs(true)
-        AvroAssertions.assertThat(false)
-            .isEncodedAs(false)
-        AvroAssertions.assertThat(WrappedBoolean(true))
-            .isEncodedAs(true)
-        AvroAssertions.assertThat(WrappedBoolean(false))
-            .isEncodedAs(false)
-    }
-
-    @OptIn(InternalSerializationApi::class)
-    listOf(
-        true,
-        false,
-        1.toByte(),
-        2.toShort(),
-        3,
-        4L,
-        5.0F,
-        6.0,
-        'A',
-        SomeEnum.B
-    ).forEach {
-        "coerce ${it::class.simpleName} $it to string" {
-            AvroAssertions.assertThat(it, it::class.serializer())
-                .isEncodedAs(it.toString(), writerSchema = Schema.create(Schema.Type.STRING))
-        }
-
-        "coerce ${it::class.simpleName} $it to nullable string" {
-            AvroAssertions.assertThat(it, it::class.serializer())
-                .isEncodedAs(it.toString(), writerSchema = Schema.create(Schema.Type.STRING).nullable)
-        }
-    }
-
-    "read write out bytes" {
-        AvroAssertions.assertThat(ByteTest(3))
-            .isEncodedAs(record(3))
-        AvroAssertions.assertThat(3.toByte())
-            .isEncodedAs(3)
-        AvroAssertions.assertThat(WrappedByte(3))
-            .isEncodedAs(3)
-    }
-
-    "read write out shorts" {
-        AvroAssertions.assertThat(ShortTest(3))
-            .isEncodedAs(record(3))
-        AvroAssertions.assertThat(3.toShort())
-            .isEncodedAs(3)
-        AvroAssertions.assertThat(WrappedShort(3))
-            .isEncodedAs(3)
-    }
-
-    "read write out chars" {
-        AvroAssertions.assertThat(CharTest('A'))
-            .isEncodedAs(record('A'.code))
-        AvroAssertions.assertThat('A')
-            .isEncodedAs('A'.code)
-        AvroAssertions.assertThat(WrappedChar('A'))
-            .isEncodedAs('A'.code)
-    }
-
-    "read write out strings" {
-        AvroAssertions.assertThat(StringTest("Hello world"))
-            .isEncodedAs(record("Hello world"))
-        AvroAssertions.assertThat("Hello world")
-            .isEncodedAs("Hello world")
-        AvroAssertions.assertThat(WrappedString("Hello world"))
-            .isEncodedAs("Hello world")
-    }
-
-    "read write out longs" {
-        AvroAssertions.assertThat(LongTest(65653L))
-            .isEncodedAs(record(65653L))
-        AvroAssertions.assertThat(65653L)
-            .isEncodedAs(65653L)
-        AvroAssertions.assertThat(WrappedLong(65653))
-            .isEncodedAs(65653L)
-    }
-
-    "read write out ints" {
-        AvroAssertions.assertThat(IntTest(44))
-            .isEncodedAs(record(44))
-        AvroAssertions.assertThat(44)
-            .isEncodedAs(44)
-        AvroAssertions.assertThat(WrappedInt(44))
-            .isEncodedAs(44)
-    }
-
-    "read write out doubles" {
-        AvroAssertions.assertThat(DoubleTest(3.235))
-            .isEncodedAs(record(3.235))
-        AvroAssertions.assertThat(3.235)
-            .isEncodedAs(3.235)
-        AvroAssertions.assertThat(WrappedDouble(3.235))
-            .isEncodedAs(3.235)
-    }
-
-    "read write out floats" {
-        AvroAssertions.assertThat(FloatTest(3.4F))
-            .isEncodedAs(record(3.4F))
-        AvroAssertions.assertThat(3.4F)
-            .isEncodedAs(3.4F)
-        AvroAssertions.assertThat(WrappedFloat(3.4F))
-            .isEncodedAs(3.4F)
-    }
-
-    "read write out byte arrays" {
-        AvroAssertions.assertThat(ByteArrayTest("ABC".toByteArray()))
-            .isEncodedAs(record(ByteBuffer.wrap("ABC".toByteArray())))
- AvroAssertions.assertThat("ABC".toByteArray()) - .isEncodedAs(ByteBuffer.wrap("ABC".toByteArray())) - } + // boolean can be encoded to boolean or string + basicScalarEncodeDecodeTests(true, Schema.create(Schema.Type.BOOLEAN)) + testSerializationTypeCompatibility(true, "true", Schema.create(Schema.Type.STRING)) + basicScalarEncodeDecodeTests(false, Schema.create(Schema.Type.BOOLEAN)) + testSerializationTypeCompatibility(false, "false", Schema.create(Schema.Type.STRING)) + + // byte can be encoded to int, long or string + basicScalarEncodeDecodeTests(1.toByte(), Schema.create(Schema.Type.INT), apacheCompatibleValue = 1) + testSerializationTypeCompatibility(1.toByte(), "1", Schema.create(Schema.Type.STRING)) + testSerializationTypeCompatibility(1.toByte(), 1L, Schema.create(Schema.Type.LONG)) + + // short can be encoded to int, long or string + basicScalarEncodeDecodeTests(2.toShort(), Schema.create(Schema.Type.INT), apacheCompatibleValue = 2) + testSerializationTypeCompatibility(2.toShort(), "2", Schema.create(Schema.Type.STRING)) + testSerializationTypeCompatibility(2.toShort(), 2L, Schema.create(Schema.Type.LONG)) + + // int can be encoded to int, long or string + basicScalarEncodeDecodeTests(3, Schema.create(Schema.Type.INT)) + testSerializationTypeCompatibility(3, "3", Schema.create(Schema.Type.STRING)) + testSerializationTypeCompatibility(3, 3L, Schema.create(Schema.Type.LONG)) + + // long can be encoded to int, long or string + basicScalarEncodeDecodeTests(4L, Schema.create(Schema.Type.LONG)) + testSerializationTypeCompatibility(4L, "4", Schema.create(Schema.Type.STRING)) + testSerializationTypeCompatibility(4L, 4, Schema.create(Schema.Type.INT)) + + // float can be encoded to double, float or string + basicScalarEncodeDecodeTests(5.0F, Schema.create(Schema.Type.FLOAT)) + testSerializationTypeCompatibility(5.0F, "5.0", Schema.create(Schema.Type.STRING)) + testSerializationTypeCompatibility(5.0F, 5.0, Schema.create(Schema.Type.DOUBLE)) + + // double can be encoded to double, float or string + basicScalarEncodeDecodeTests(6.0, Schema.create(Schema.Type.DOUBLE)) + testSerializationTypeCompatibility(6.0, "6.0", Schema.create(Schema.Type.STRING)) + testSerializationTypeCompatibility(6.0, 6.0F, Schema.create(Schema.Type.FLOAT)) + + // char can be encoded to int or string + basicScalarEncodeDecodeTests('A', Schema.create(Schema.Type.INT).copy(logicalType = LogicalType("char")), apacheCompatibleValue = 'A'.code) + testSerializationTypeCompatibility('A', "A", Schema.create(Schema.Type.STRING)) + + // bytes can be encoded to bytes, string, or fixed + val bytesValue = "test".encodeToByteArray() + basicScalarEncodeDecodeTests(bytesValue, Schema.create(Schema.Type.BYTES), apacheCompatibleValue = ByteBuffer.wrap(bytesValue)) + testSerializationTypeCompatibility(bytesValue, "test", Schema.create(Schema.Type.STRING)) + val fixedBytesSchema = Schema.createFixed("fixed", null, null, bytesValue.size) + testSerializationTypeCompatibility(bytesValue, GenericData.Fixed(fixedBytesSchema, bytesValue), fixedBytesSchema) + + // fixed can be encoded to bytes, string, or fixed + val fixedValue = FixedValue("fixed".encodeToByteArray()) + val fixedSchema = Schema.createFixed("fixed", null, null, fixedValue.value.size) + // not able to directly serialize a fixed, so we need to wrap it in a value class to indicates it's a fixed type + testSerializationTypeCompatibility(fixedValue, ByteBuffer.wrap(fixedValue.value), Schema.create(Schema.Type.BYTES)) + testSerializationTypeCompatibility(fixedValue, "fixed", 
+    testSerializationTypeCompatibility(fixedValue, GenericData.Fixed(fixedSchema, fixedValue.value), fixedSchema)
+
+    // string can be encoded to bytes, string, or fixed
+    val stringValue = "the string content"
+    basicScalarEncodeDecodeTests(stringValue, Schema.create(Schema.Type.STRING))
+    testSerializationTypeCompatibility(stringValue, ByteBuffer.wrap(stringValue.encodeToByteArray()), Schema.create(Schema.Type.BYTES))
+    val fixedStringSchema = Schema.createFixed("fixed", null, null, stringValue.length)
+    testSerializationTypeCompatibility(stringValue, GenericData.Fixed(fixedStringSchema, stringValue.encodeToByteArray()), fixedStringSchema)
+
+    testSerializationTypeCompatibility("true", true, Schema.create(Schema.Type.BOOLEAN))
+    testSerializationTypeCompatibility("false", false, Schema.create(Schema.Type.BOOLEAN))
+    testSerializationTypeCompatibility("23", 23, Schema.create(Schema.Type.INT))
+    testSerializationTypeCompatibility("55", 55L, Schema.create(Schema.Type.LONG))
+    testSerializationTypeCompatibility("5.3", 5.3F, Schema.create(Schema.Type.FLOAT))
+    testSerializationTypeCompatibility("5.3", 5.3, Schema.create(Schema.Type.DOUBLE))
 }) {
+    @JvmInline
     @Serializable
-    data class BooleanTest(val z: Boolean)
-
-    @Serializable
-    data class ByteTest(val z: Byte)
-
-    @Serializable
-    data class ShortTest(val z: Short)
-
-    @Serializable
-    data class CharTest(val z: Char)
-
-    @Serializable
-    data class StringTest(val z: String)
-
-    @Serializable
-    data class FloatTest(val z: Float)
-
-    @Serializable
-    data class DoubleTest(val z: Double)
-
-    @Serializable
-    data class IntTest(val z: Int)
-
-    @Serializable
-    data class LongTest(val z: Long)
-
-    @Serializable
-    data class ByteArrayTest(val z: ByteArray) {
-        override fun equals(other: Any?): Boolean {
-            if (this === other) return true
-            if (javaClass != other?.javaClass) return false
-
-            other as ByteArrayTest
-
-            return z.contentEquals(other.z)
-        }
-
-        override fun hashCode(): Int {
-            return z.contentHashCode()
-        }
-    }
+    private value class FixedValue(@AvroFixed(5) val value: ByteArray)
 }
\ No newline at end of file

From beb31aa271c83843d0fbedc90948b49e61be0520 Mon Sep 17 00:00:00 2001
From: Chuckame
Date: Thu, 10 Oct 2024 22:31:52 +0200
Subject: [PATCH 08/13] tests: improve test coverage for primitive & enum tests

---
 .../avrokotlin/avro4k/internal/NumberUtils.kt | 36 ++++++++++---------
 .../internal/encoder/AbstractAvroEncoder.kt   |  6 ++--
 .../avrokotlin/avro4k/encoding/EnumTest.kt    | 30 ++++++++--------
 .../avro4k/encoding/PrimitiveEncodingTest.kt  |  4 ++-
 4 files changed, 42 insertions(+), 34 deletions(-)

diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/internal/NumberUtils.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/internal/NumberUtils.kt
index f1e67803..b5839eee 100644
--- a/src/main/kotlin/com/github/avrokotlin/avro4k/internal/NumberUtils.kt
+++ b/src/main/kotlin/com/github/avrokotlin/avro4k/internal/NumberUtils.kt
@@ -5,86 +5,88 @@ import java.math.BigDecimal
 
 internal fun Int.toByteExact(): Byte {
     if (this.toByte().toInt() != this) {
-        throw SerializationException("Value $this is not a valid Byte")
+        return invalidTypeError()
     }
     return this.toByte()
 }
 
 internal fun Long.toByteExact(): Byte {
     if (this.toByte().toLong() != this) {
-        throw SerializationException("Value $this is not a valid Byte")
+        return invalidTypeError()
     }
     return this.toByte()
 }
 
 internal fun BigDecimal.toByteExact(): Byte {
-    if (this.toInt().toByte().toInt().toBigDecimal() != this) {
-        throw SerializationException("Value $this is not a valid Byte")
is not a valid Byte") + if (this.toByte().toInt().toBigDecimal() != this) { + return invalidTypeError() } - return this.toInt().toByte() + return this.toByte() } internal fun Int.toShortExact(): Short { if (this.toShort().toInt() != this) { - throw SerializationException("Value $this is not a valid Short") + return invalidTypeError() } return this.toShort() } internal fun Long.toShortExact(): Short { if (this.toShort().toLong() != this) { - throw SerializationException("Value $this is not a valid Short") + return invalidTypeError() } return this.toShort() } internal fun BigDecimal.toShortExact(): Short { - if (this.toInt().toShort().toInt().toBigDecimal() != this) { - throw SerializationException("Value $this is not a valid Short") + if (this.toShort().toInt().toBigDecimal() != this) { + return invalidTypeError() } - return this.toInt().toShort() + return this.toShort() } internal fun Long.toIntExact(): Int { if (this.toInt().toLong() != this) { - throw invalidType() + return invalidTypeError() } return this.toInt() } internal fun BigDecimal.toIntExact(): Int { if (this.toInt().toBigDecimal() != this) { - throw invalidType() + return invalidTypeError() } return this.toInt() } internal fun BigDecimal.toLongExact(): Long { if (this.toLong().toBigDecimal() != this) { - throw invalidType() + return invalidTypeError() } return this.toLong() } internal fun Double.toFloatExact(): Float { if (this.toFloat().toDouble() != this) { - throw invalidType() + return invalidTypeError() } return this.toFloat() } internal fun BigDecimal.toDoubleExact(): Double { if (this.toDouble().toBigDecimal() != this) { - throw invalidType() + return invalidTypeError() } return this.toDouble() } internal fun BigDecimal.toFloatExact(): Float { if (this.toFloat().toBigDecimal() != this) { - throw invalidType() + return invalidTypeError() } return this.toFloat() } -private inline fun Any.invalidType() = SerializationException("Value $this is not a valid ${T::class.simpleName}") \ No newline at end of file +private inline fun Any.invalidTypeError(): T { + throw SerializationException("Value $this is not a valid ${T::class.simpleName}") +} \ No newline at end of file diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/internal/encoder/AbstractAvroEncoder.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/internal/encoder/AbstractAvroEncoder.kt index 9757a97b..ebd0ddbf 100644 --- a/src/main/kotlin/com/github/avrokotlin/avro4k/internal/encoder/AbstractAvroEncoder.kt +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/internal/encoder/AbstractAvroEncoder.kt @@ -324,7 +324,8 @@ internal abstract class AbstractAvroEncoder : AbstractEncoder(), AvroEncoder { Schema.Type.INT, Schema.Type.LONG, Schema.Type.FLOAT, - Schema.Type.DOUBLE) || + Schema.Type.DOUBLE + ) || throw typeNotFoundInUnionError( Schema.Type.BOOLEAN, Schema.Type.INT, @@ -334,7 +335,8 @@ internal abstract class AbstractAvroEncoder : AbstractEncoder(), AvroEncoder { Schema.Type.STRING, Schema.Type.BYTES, Schema.Type.FIXED, - Schema.Type.ENUM) + Schema.Type.ENUM + ) } when (currentWriterSchema.type) { Schema.Type.BOOLEAN -> encodeBooleanUnchecked(value.toBooleanStrict()) diff --git a/src/test/kotlin/com/github/avrokotlin/avro4k/encoding/EnumTest.kt b/src/test/kotlin/com/github/avrokotlin/avro4k/encoding/EnumTest.kt index 32e3ed69..501b3dba 100644 --- a/src/test/kotlin/com/github/avrokotlin/avro4k/encoding/EnumTest.kt +++ b/src/test/kotlin/com/github/avrokotlin/avro4k/encoding/EnumTest.kt @@ -91,26 +91,28 @@ internal class EnumTest : StringSpec({ } "support decoding 
-        val writerSchema = Schema.createUnion(
-            SchemaBuilder.enumeration("OtherEnum").symbols("OTHER"),
-            SchemaBuilder.record("UnknownRecord").aliases("RecordA")
-                .fields().name("field").type().stringType().noDefault()
-                .endRecord(),
-            SchemaBuilder.enumeration("TheEnum").symbols("A", "B", "C")
-        )
+        val writerSchema =
+            Schema.createUnion(
+                SchemaBuilder.enumeration("OtherEnum").symbols("OTHER"),
+                SchemaBuilder.record("UnknownRecord").aliases("RecordA")
+                    .fields().name("field").type().stringType().noDefault()
+                    .endRecord(),
+                SchemaBuilder.enumeration("TheEnum").symbols("A", "B", "C")
+            )
 
         AvroAssertions.assertThat(EnumWithAlias.A)
             .isEncodedAs(GenericData.EnumSymbol(writerSchema.types[2], "A"), writerSchema = writerSchema)
     }
 
     "support decoding enum inside an union with an alias on the writer schema" {
-        val writerSchema = Schema.createUnion(
-            SchemaBuilder.enumeration("OtherEnum").symbols("OTHER"),
-            SchemaBuilder.record("UnknownRecord").aliases("RecordA")
-                .fields().name("field").type().stringType().noDefault()
-                .endRecord(),
-            SchemaBuilder.enumeration("UnknownEnum").aliases("com.github.avrokotlin.avro4k.SomeEnum").symbols("A", "B", "C")
-        )
+        val writerSchema =
+            Schema.createUnion(
+                SchemaBuilder.enumeration("OtherEnum").symbols("OTHER"),
+                SchemaBuilder.record("UnknownRecord").aliases("RecordA")
+                    .fields().name("field").type().stringType().noDefault()
+                    .endRecord(),
+                SchemaBuilder.enumeration("UnknownEnum").aliases("com.github.avrokotlin.avro4k.SomeEnum").symbols("A", "B", "C")
+            )
 
         AvroAssertions.assertThat(SomeEnum.A)
             .isEncodedAs(GenericData.EnumSymbol(writerSchema.types[2], "A"), writerSchema = writerSchema)
diff --git a/src/test/kotlin/com/github/avrokotlin/avro4k/encoding/PrimitiveEncodingTest.kt b/src/test/kotlin/com/github/avrokotlin/avro4k/encoding/PrimitiveEncodingTest.kt
index b4c609e5..dc5ec064 100644
--- a/src/test/kotlin/com/github/avrokotlin/avro4k/encoding/PrimitiveEncodingTest.kt
+++ b/src/test/kotlin/com/github/avrokotlin/avro4k/encoding/PrimitiveEncodingTest.kt
@@ -83,5 +83,7 @@ internal class PrimitiveEncodingTest : StringSpec({
 }) {
     @JvmInline
     @Serializable
-    private value class FixedValue(@AvroFixed(5) val value: ByteArray)
+    private value class FixedValue(
+        @AvroFixed(5) val value: ByteArray,
+    )
 }
\ No newline at end of file

From 5f4782ecc62d3ba4919352959f89bf39d55fd264 Mon Sep 17 00:00:00 2001
From: Chuckame
Date: Thu, 10 Oct 2024 22:31:56 +0200
Subject: [PATCH 09/13] fix: do not decode int as boolean in generic decoder

---
 .../decoder/generic/AbstractAvroGenericDecoder.kt | 9 +++------
 1 file changed, 3 insertions(+), 6 deletions(-)

diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/internal/decoder/generic/AbstractAvroGenericDecoder.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/internal/decoder/generic/AbstractAvroGenericDecoder.kt
index c9f91a41..4e2fd67f 100644
--- a/src/main/kotlin/com/github/avrokotlin/avro4k/internal/decoder/generic/AbstractAvroGenericDecoder.kt
+++ b/src/main/kotlin/com/github/avrokotlin/avro4k/internal/decoder/generic/AbstractAvroGenericDecoder.kt
@@ -96,8 +96,6 @@ internal abstract class AbstractAvroGenericDecoder : AbstractDecoder(), AvroDeco
     }
 
     override fun decodeBoolean(): Boolean {
         return when (val value = decodeValue()) {
             is Boolean -> value
-            1 -> true
-            0 -> false
             is CharSequence -> value.toString().toBoolean()
             else -> throw BadDecodedValueError(value, PrimitiveKind.BOOLEAN, Boolean::class, Int::class, CharSequence::class)
         }
     }
@@ -164,10 +162,9 @@ internal abstract class AbstractAvroGenericDecoder : AbstractDecoder(), AvroDeco
     }
 
     override fun decodeChar(): Char {
-        val value = decodeValue()
-        return when {
-            value is Int -> value.toChar()
-            value is CharSequence && value.length == 1 -> value[0]
+        return when (val value = decodeValue()) {
+            is Int -> value.toChar()
+            is CharSequence -> value.single()
             else -> throw BadDecodedValueError(value, PrimitiveKind.CHAR, Int::class, CharSequence::class)
         }
     }

From b4bbfe1844150bfec50384447300b6de3ea36ae1 Mon Sep 17 00:00:00 2001
From: Chuckame
Date: Thu, 10 Oct 2024 22:39:42 +0200
Subject: [PATCH 10/13] tests: improve test coverage for enum

---
 .../com/github/avrokotlin/avro4k/encoding/EnumTest.kt | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/src/test/kotlin/com/github/avrokotlin/avro4k/encoding/EnumTest.kt b/src/test/kotlin/com/github/avrokotlin/avro4k/encoding/EnumTest.kt
index 501b3dba..b384b1d7 100644
--- a/src/test/kotlin/com/github/avrokotlin/avro4k/encoding/EnumTest.kt
+++ b/src/test/kotlin/com/github/avrokotlin/avro4k/encoding/EnumTest.kt
@@ -77,6 +77,14 @@ internal class EnumTest : StringSpec({
         }
     }
 
+    "Encoding enum with the wrong name fails" {
+        val schema = SchemaBuilder.enumeration("WrongName").symbols("A")
+
+        shouldThrow<SerializationException> {
+            Avro.encodeToByteArray(schema, EnumV1WithoutDefault.A)
+        }
+    }
+
     "support alias on enum" {
         val writerSchema =
             SchemaBuilder.record("EnumWrapperRecord").fields()

From 562a4070c1d8a82e8529c147a9065ae48cd59144 Mon Sep 17 00:00:00 2001
From: Chuckame
Date: Mon, 3 Feb 2025 23:00:37 +0100
Subject: [PATCH 11/13] tests: should throw error when encoding null to non-nullable schema

---
 .../avro4k/encoding/PrimitiveEncodingTest.kt | 22 +++++++++++++++++++
 1 file changed, 22 insertions(+)

diff --git a/src/test/kotlin/com/github/avrokotlin/avro4k/encoding/PrimitiveEncodingTest.kt b/src/test/kotlin/com/github/avrokotlin/avro4k/encoding/PrimitiveEncodingTest.kt
index dc5ec064..9caad976 100644
--- a/src/test/kotlin/com/github/avrokotlin/avro4k/encoding/PrimitiveEncodingTest.kt
+++ b/src/test/kotlin/com/github/avrokotlin/avro4k/encoding/PrimitiveEncodingTest.kt
@@ -1,11 +1,17 @@
 package com.github.avrokotlin.avro4k.encoding
 
+import com.github.avrokotlin.avro4k.Avro
 import com.github.avrokotlin.avro4k.AvroFixed
 import com.github.avrokotlin.avro4k.basicScalarEncodeDecodeTests
 import com.github.avrokotlin.avro4k.internal.copy
 import com.github.avrokotlin.avro4k.testSerializationTypeCompatibility
+import io.kotest.assertions.throwables.shouldThrow
 import io.kotest.core.spec.style.StringSpec
 import kotlinx.serialization.Serializable
+import kotlinx.serialization.SerializationException
+import kotlinx.serialization.builtins.ByteArraySerializer
+import kotlinx.serialization.builtins.nullable
+import kotlinx.serialization.builtins.serializer
 import org.apache.avro.LogicalType
 import org.apache.avro.Schema
 import org.apache.avro.generic.GenericData
@@ -80,6 +86,22 @@ internal class PrimitiveEncodingTest : StringSpec({
     testSerializationTypeCompatibility("55", 55L, Schema.create(Schema.Type.LONG))
     testSerializationTypeCompatibility("5.3", 5.3F, Schema.create(Schema.Type.FLOAT))
     testSerializationTypeCompatibility("5.3", 5.3, Schema.create(Schema.Type.DOUBLE))
+
+    listOf(
+        Schema.Type.STRING to String.serializer(),
+        Schema.Type.BYTES to ByteArraySerializer(),
+        Schema.Type.INT to Int.serializer(),
+        Schema.Type.LONG to Long.serializer(),
+        Schema.Type.FLOAT to Float.serializer(),
+        Schema.Type.DOUBLE to Double.serializer(),
+        Schema.Type.BOOLEAN to Boolean.serializer()
+    ).forEach { (type, serializer) ->
+        "null cannot be encoded to non-nullable avro type $type" {
+            shouldThrow<SerializationException> {
+                Avro.encodeToByteArray(Schema.create(type), serializer.nullable, null)
+            }
+        }
+    }
 }) {
     @JvmInline
     @Serializable

From d619bf79468dffca93b8ac6343bcbeea5abfb466 Mon Sep 17 00:00:00 2001
From: Chuckame
Date: Tue, 4 Feb 2025 00:02:03 +0100
Subject: [PATCH 12/13] feat: allow decoding string from enum

---
 .../internal/decoder/direct/AbstractAvroDirectDecoder.kt | 4 +++-
 .../avro4k/internal/encoder/direct/CollectionsEncoder.kt | 5 +----
 .../avrokotlin/avro4k/encoding/PrimitiveEncodingTest.kt  | 4 +++-
 3 files changed, 7 insertions(+), 6 deletions(-)

diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/internal/decoder/direct/AbstractAvroDirectDecoder.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/internal/decoder/direct/AbstractAvroDirectDecoder.kt
index 76421907..5e508bab 100644
--- a/src/main/kotlin/com/github/avrokotlin/avro4k/internal/decoder/direct/AbstractAvroDirectDecoder.kt
+++ b/src/main/kotlin/com/github/avrokotlin/avro4k/internal/decoder/direct/AbstractAvroDirectDecoder.kt
@@ -189,6 +189,7 @@ internal abstract class AbstractAvroDirectDecoder(
             Schema.Type.LONG -> binaryDecoder.readLong().toString()
             Schema.Type.FLOAT -> binaryDecoder.readFloat().toString()
             Schema.Type.DOUBLE -> binaryDecoder.readDouble().toString()
+            Schema.Type.ENUM -> currentWriterSchema.enumSymbols[binaryDecoder.readEnum()]
             else -> throw unsupportedWriterTypeError(
                 Schema.Type.STRING,
                 Schema.Type.BYTES,
@@ -197,7 +198,8 @@ internal abstract class AbstractAvroDirectDecoder(
                 Schema.Type.INT,
                 Schema.Type.LONG,
                 Schema.Type.FLOAT,
-                Schema.Type.DOUBLE
+                Schema.Type.DOUBLE,
+                Schema.Type.ENUM
             )
         }
     }
diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/internal/encoder/direct/CollectionsEncoder.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/internal/encoder/direct/CollectionsEncoder.kt
index bf3addf7..5c6b436f 100644
--- a/src/main/kotlin/com/github/avrokotlin/avro4k/internal/encoder/direct/CollectionsEncoder.kt
+++ b/src/main/kotlin/com/github/avrokotlin/avro4k/internal/encoder/direct/CollectionsEncoder.kt
@@ -6,8 +6,6 @@ import org.apache.avro.Schema
 
 internal class MapDirectEncoder(private val schema: Schema, mapSize: Int, avro: Avro, binaryEncoder: org.apache.avro.io.Encoder) :
     AbstractAvroDirectEncoder(avro, binaryEncoder) {
-    private var isKey: Boolean = true
-
     init {
         binaryEncoder.writeMapStart()
         binaryEncoder.setItemCount(mapSize.toLong())
@@ -28,9 +26,8 @@ internal class MapDirectEncoder(private val schema: Schema, mapSize: Int, avro:
         index: Int,
     ): Boolean {
         super.encodeElement(descriptor, index)
-        isKey = index % 2 == 0
         currentWriterSchema =
-            if (isKey) {
+            if (index % 2 == 0) {
                 binaryEncoder.startItem()
                 STRING_SCHEMA
             } else {
diff --git a/src/test/kotlin/com/github/avrokotlin/avro4k/encoding/PrimitiveEncodingTest.kt b/src/test/kotlin/com/github/avrokotlin/avro4k/encoding/PrimitiveEncodingTest.kt
index 9caad976..1c79776d 100644
--- a/src/test/kotlin/com/github/avrokotlin/avro4k/encoding/PrimitiveEncodingTest.kt
+++ b/src/test/kotlin/com/github/avrokotlin/avro4k/encoding/PrimitiveEncodingTest.kt
@@ -79,13 +79,15 @@ internal class PrimitiveEncodingTest : StringSpec({
     testSerializationTypeCompatibility(stringValue, ByteBuffer.wrap(stringValue.encodeToByteArray()), Schema.create(Schema.Type.BYTES))
     val fixedStringSchema = Schema.createFixed("fixed", null, null, stringValue.length)
     testSerializationTypeCompatibility(stringValue, GenericData.Fixed(fixedStringSchema, stringValue.encodeToByteArray()), fixedStringSchema)
-
+    // string can be encoded to boolean, int, long, float, double, enum
     testSerializationTypeCompatibility("true", true, Schema.create(Schema.Type.BOOLEAN))
     testSerializationTypeCompatibility("false", false, Schema.create(Schema.Type.BOOLEAN))
     testSerializationTypeCompatibility("23", 23, Schema.create(Schema.Type.INT))
     testSerializationTypeCompatibility("55", 55L, Schema.create(Schema.Type.LONG))
     testSerializationTypeCompatibility("5.3", 5.3F, Schema.create(Schema.Type.FLOAT))
     testSerializationTypeCompatibility("5.3", 5.3, Schema.create(Schema.Type.DOUBLE))
+    val enumSchema = Schema.createEnum("enum", null, null, listOf("A", "B"))
+    testSerializationTypeCompatibility("B", GenericData.EnumSymbol(enumSchema, "B"), enumSchema)
 
     listOf(

From fc0f0eb99ffc7d787c9ba2cd6c521b953b490b7f Mon Sep 17 00:00:00 2001
From: Chuckame
Date: Tue, 4 Feb 2025 00:46:08 +0100
Subject: [PATCH 13/13] chore(benchmark): Update results after refactor

MUCH better results in simple writes bench (+26%)
---
 benchmark/README.md | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/benchmark/README.md b/benchmark/README.md
index 98548df5..fc887d92 100644
--- a/benchmark/README.md
+++ b/benchmark/README.md
@@ -26,23 +26,23 @@ Computer: Macbook air M2
 
 ```
 Benchmark Mode Cnt Score Error Units
-c.g.a.b.complex.Avro4kBenchmark.read thrpt 5 27482.418 ± 1162.064 ops/s
+c.g.a.b.complex.Avro4kBenchmark.read thrpt 5 27070.886 ± 2562.606 ops/s
 c.g.a.b.complex.ApacheAvroReflectBenchmark.read thrpt 5 26239.615 ± 15033.290 ops/s
-c.g.a.b.complex.Avro4kGenericWithApacheAvroBenchmark.read thrpt 5 15862.270 ± 1139.036 ops/s
+c.g.a.b.complex.Avro4kGenericWithApacheAvroBenchmark.read thrpt 5 16140.138 ± 90.523 ops/s
 
-c.g.a.b.complex.Avro4kBenchmark.write thrpt 5 54335.043 ± 2481.196 ops/s
+c.g.a.b.complex.Avro4kBenchmark.write thrpt 5 54488.317 ± 719.412 ops/s
 c.g.a.b.complex.ApacheAvroReflectBenchmark.write thrpt 5 47510.885 ± 2467.348 ops/s
 c.g.a.b.complex.JacksonAvroBenchmark.write thrpt 5 33936.765 ± 2139.528 ops/s
-c.g.a.b.complex.Avro4kGenericWithApacheAvroBenchmark.write thrpt 5 27124.366 ± 753.406 ops/s
+c.g.a.b.complex.Avro4kGenericWithApacheAvroBenchmark.write thrpt 5 24072.277 ± 697.493 ops/s
 
-c.g.a.b.simple.Avro4kSimpleBenchmark.read thrpt 5 215140.198 ± 9182.259 ops/s
+c.g.a.b.simple.Avro4kSimpleBenchmark.read thrpt 5 221277.895 ± 3945.928 ops/s
 c.g.a.b.simple.ApacheAvroReflectSimpleBenchmark.read thrpt 5 230744.377 ± 30164.628 ops/s
-c.g.a.b.simple.Avro4kGenericWithApacheAvroSimpleBenchmark.read thrpt 5 136913.851 ± 8302.833 ops/s
+c.g.a.b.simple.Avro4kGenericWithApacheAvroSimpleBenchmark.read thrpt 5 138394.796 ± 5130.421 ops/s
 c.g.a.b.simple.JacksonAvroSimpleBenchmark.read thrpt 5 69615.099 ± 4047.717 ops/s
 
-c.g.a.b.simple.Avro4kSimpleBenchmark.write thrpt 5 354497.179 ± 8342.002 ops/s
+c.g.a.b.simple.Avro4kSimpleBenchmark.write thrpt 5 446673.090 ± 15520.264 ops/s
 c.g.a.b.simple.ApacheAvroReflectSimpleBenchmark.write thrpt 5 320367.673 ± 33394.537 ops/s
-c.g.a.b.simple.Avro4kGenericWithApacheAvroSimpleBenchmark.write thrpt 5 142525.233 ± 2796.318 ops/s
+c.g.a.b.simple.Avro4kGenericWithApacheAvroSimpleBenchmark.write thrpt 5 168702.542 ± 5553.797 ops/s
 c.g.a.b.simple.JacksonAvroSimpleBenchmark.write thrpt 5 138898.312 ± 9156.715 ops/s
 ```