diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/internal/decoder/direct/AbstractAvroDirectDecoder.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/internal/decoder/direct/AbstractAvroDirectDecoder.kt index 31f1b6f..3b055ee 100644 --- a/src/main/kotlin/com/github/avrokotlin/avro4k/internal/decoder/direct/AbstractAvroDirectDecoder.kt +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/internal/decoder/direct/AbstractAvroDirectDecoder.kt @@ -184,8 +184,8 @@ internal abstract class AbstractAvroDirectDecoder( return when (currentWriterSchema.type) { Schema.Type.STRING -> binaryDecoder.readString(null).toString() - Schema.Type.BYTES -> binaryDecoder.readBytes(null).array().decodeToString() - Schema.Type.FIXED -> ByteArray(currentWriterSchema.fixedSize).also { buf -> binaryDecoder.readFixed(buf) }.decodeToString() + Schema.Type.BYTES -> binaryDecoder.readBytes().decodeToString() + Schema.Type.FIXED -> binaryDecoder.readFixedBytes(currentWriterSchema.fixedSize).decodeToString() else -> throw unsupportedWriterTypeError(Schema.Type.STRING, Schema.Type.BYTES, Schema.Type.FIXED) } } @@ -194,41 +194,40 @@ internal abstract class AbstractAvroDirectDecoder( decodeAndResolveUnion() return when (currentWriterSchema.type) { - Schema.Type.ENUM -> + Schema.Type.ENUM -> { if (currentWriterSchema.isFullNameOrAliasMatch(enumDescriptor)) { - val enumName = currentWriterSchema.enumSymbols[binaryDecoder.readEnum()] - val idx = enumDescriptor.getElementIndex(enumName) - if (idx >= 0) { - idx - } else { - avro.enumResolver.getDefaultValueIndex(enumDescriptor) - ?: throw SerializationException("Unknown enum symbol name '$enumName' for Enum '${enumDescriptor.serialName}' for writer schema $currentWriterSchema") - } + val enumSymbol = currentWriterSchema.enumSymbols[binaryDecoder.readEnum()] + enumDescriptor.getEnumIndex(enumSymbol) } else { throw unsupportedWriterTypeError(Schema.Type.ENUM, Schema.Type.STRING) } + } Schema.Type.STRING -> { val enumSymbol = binaryDecoder.readString() - val idx = enumDescriptor.getElementIndex(enumSymbol) - if (idx >= 0) { - idx - } else { - avro.enumResolver.getDefaultValueIndex(enumDescriptor) - ?: throw SerializationException("Unknown enum symbol '$enumSymbol' for Enum '${enumDescriptor.serialName}' for writer schema $currentWriterSchema") - } + enumDescriptor.getEnumIndex(enumSymbol) } else -> throw unsupportedWriterTypeError(Schema.Type.ENUM, Schema.Type.STRING) } } + private fun SerialDescriptor.getEnumIndex(enumName: String): Int { + val idx = getElementIndex(enumName) + return if (idx >= 0) { + idx + } else { + avro.enumResolver.getDefaultValueIndex(this) + ?: throw SerializationException("Unknown enum symbol name '$enumName' for Enum '${this.serialName}' for writer schema $currentWriterSchema") + } + } + override fun decodeBytes(): ByteArray { decodeAndResolveUnion() return when (currentWriterSchema.type) { - Schema.Type.BYTES -> binaryDecoder.readBytes(null).array() - Schema.Type.FIXED -> ByteArray(currentWriterSchema.fixedSize).also { buf -> binaryDecoder.readFixed(buf) } + Schema.Type.BYTES -> binaryDecoder.readBytes() + Schema.Type.FIXED -> binaryDecoder.readFixedBytes(currentWriterSchema.fixedSize) Schema.Type.STRING -> binaryDecoder.readString(null).bytes else -> throw unsupportedWriterTypeError(Schema.Type.BYTES, Schema.Type.FIXED, Schema.Type.STRING) } @@ -238,13 +237,21 @@ internal abstract class AbstractAvroDirectDecoder( decodeAndResolveUnion() return when (currentWriterSchema.type) { - Schema.Type.BYTES -> GenericData.Fixed(currentWriterSchema, binaryDecoder.readBytes(null).array()) - Schema.Type.FIXED -> GenericData.Fixed(currentWriterSchema, ByteArray(currentWriterSchema.fixedSize).also { buf -> binaryDecoder.readFixed(buf) }) + Schema.Type.BYTES -> GenericData.Fixed(currentWriterSchema, binaryDecoder.readBytes()) + Schema.Type.FIXED -> GenericData.Fixed(currentWriterSchema, binaryDecoder.readFixedBytes(currentWriterSchema.fixedSize)) else -> throw unsupportedWriterTypeError(Schema.Type.BYTES, Schema.Type.FIXED) } } } +private fun org.apache.avro.io.Decoder.readFixedBytes(size: Int): ByteArray { + return ByteArray(size).also { buf -> readFixed(buf) } +} + +private fun org.apache.avro.io.Decoder.readBytes(): ByteArray { + return readBytes(null).array() +} + private class PolymorphicDecoder( avro: Avro, descriptor: SerialDescriptor, diff --git a/src/test/kotlin/com/github/avrokotlin/avro4k/AvroAssertions.kt b/src/test/kotlin/com/github/avrokotlin/avro4k/AvroAssertions.kt index 15f47e7..0648ac3 100644 --- a/src/test/kotlin/com/github/avrokotlin/avro4k/AvroAssertions.kt +++ b/src/test/kotlin/com/github/avrokotlin/avro4k/AvroAssertions.kt @@ -1,15 +1,20 @@ package com.github.avrokotlin.avro4k +import com.github.avrokotlin.avro4k.internal.nullable import io.kotest.assertions.Actual import io.kotest.assertions.Expected import io.kotest.assertions.failure import io.kotest.assertions.print.Printed import io.kotest.assertions.withClue +import io.kotest.core.spec.style.scopes.StringSpecRootScope import io.kotest.matchers.shouldBe import kotlinx.serialization.KSerializer +import kotlinx.serialization.SerialName +import kotlinx.serialization.Serializable import kotlinx.serialization.serializer import org.apache.avro.Conversions import org.apache.avro.Schema +import org.apache.avro.SchemaBuilder import org.apache.avro.generic.GenericData import org.apache.avro.generic.GenericDatumReader import org.apache.avro.generic.GenericDatumWriter @@ -202,4 +207,84 @@ internal object AvroAssertions { ): AvroEncodingAssertions { return AvroEncodingAssertions(value, serializer as KSerializer) } +} + +fun encodeToBytesUsingApacheLib( + schema: Schema, + toEncode: Any?, +): ByteArray { + return ByteArrayOutputStream().use { + GenericData.get().createDatumWriter(schema).write(toEncode, EncoderFactory.get().directBinaryEncoder(it, null)) + it.toByteArray() + } +} + +internal inline fun StringSpecRootScope.basicScalarEncodeDecodeTests(value: T, schema: Schema, apacheCompatibleValue: Any? = value) { + "support scalar type ${schema.type} serialization" { + testEncodeDecode(schema, value, apacheCompatibleValue = apacheCompatibleValue) + testEncodeDecode(schema, TestGenericValueClass(value), apacheCompatibleValue = apacheCompatibleValue) + + testEncodeDecode(schema.nullable, value, apacheCompatibleValue = apacheCompatibleValue) + testEncodeDecode(schema.nullable, null) + + testEncodeDecode(schema.nullable, TestGenericValueClass(value), apacheCompatibleValue = apacheCompatibleValue) + testEncodeDecode(schema.nullable, TestGenericValueClass(null), apacheCompatibleValue = null) + testEncodeDecode?>(schema.nullable, null) + } + "scalar type ${schema.type} in record" { + val record = + SchemaBuilder.record("theRecord").fields() + .name("field").type(schema).noDefault() + .endRecord() + + testEncodeDecode(record, TestGenericRecord(value), apacheCompatibleValue = GenericData.Record(record).also { it.put(0, apacheCompatibleValue) }) + testEncodeDecode(record, TestGenericRecord(TestGenericValueClass(value)), apacheCompatibleValue = GenericData.Record(record).also { it.put(0, apacheCompatibleValue) }) + + val recordNullable = + SchemaBuilder.record("theRecord").fields() + .name("field").type(schema.nullable).noDefault() + .endRecord() + testEncodeDecode(recordNullable, TestGenericRecord(value), apacheCompatibleValue = GenericData.Record(recordNullable).also { it.put(0, apacheCompatibleValue) }) + testEncodeDecode(recordNullable, TestGenericRecord(null), apacheCompatibleValue = GenericData.Record(recordNullable).also { it.put(0, null) }) + testEncodeDecode(recordNullable, TestGenericRecord(TestGenericValueClass(value)), apacheCompatibleValue = GenericData.Record(recordNullable).also { it.put(0, apacheCompatibleValue) }) + testEncodeDecode(recordNullable, TestGenericRecord(TestGenericValueClass(null)), apacheCompatibleValue = GenericData.Record(recordNullable).also { it.put(0, null) }) + } + "scalar type ${schema.type} in map" { + val map = SchemaBuilder.map().values(schema) + testEncodeDecode(map, mapOf("key" to value), apacheCompatibleValue = mapOf("key" to apacheCompatibleValue)) + testEncodeDecode(map, mapOf("key" to TestGenericValueClass(value)), apacheCompatibleValue = mapOf("key" to apacheCompatibleValue)) + + val mapNullable = SchemaBuilder.map().values(schema.nullable) + testEncodeDecode(mapNullable, mapOf("key" to TestGenericValueClass(value)), apacheCompatibleValue = mapOf("key" to apacheCompatibleValue)) + testEncodeDecode(mapNullable, mapOf("key" to TestGenericValueClass(null)), apacheCompatibleValue = mapOf("key" to null)) + } + "scalar type ${schema.type} in array" { + val array = SchemaBuilder.array().items(schema) + testEncodeDecode(array, listOf(value), apacheCompatibleValue = listOf(apacheCompatibleValue)) + testEncodeDecode(array, listOf(TestGenericValueClass(value)), apacheCompatibleValue = listOf(apacheCompatibleValue)) + + val arrayNullable = SchemaBuilder.array().items(schema.nullable) + testEncodeDecode(arrayNullable, listOf(TestGenericValueClass(value)), apacheCompatibleValue = listOf(apacheCompatibleValue)) + testEncodeDecode(arrayNullable, listOf(TestGenericValueClass(null)), apacheCompatibleValue = listOf(null)) + } +} + +@Serializable +@SerialName("theRecord") +internal data class TestGenericRecord(val field: T) + +@JvmInline +@Serializable +internal value class TestGenericValueClass(val value: T) + +inline fun testEncodeDecode( + schema: Schema, + toEncode: T, + decoded: Any? = toEncode, + apacheCompatibleValue: Any? = toEncode, + serializer: KSerializer = Avro.serializersModule.serializer(), + expectedBytes: ByteArray = encodeToBytesUsingApacheLib(schema, apacheCompatibleValue), +) { + Avro.encodeToByteArray(schema, serializer, toEncode) shouldBe expectedBytes + Avro.decodeFromByteArray(schema, serializer, expectedBytes) shouldBe decoded } \ No newline at end of file diff --git a/src/test/kotlin/com/github/avrokotlin/avro4k/encoding/EnumEncodingTest.kt b/src/test/kotlin/com/github/avrokotlin/avro4k/encoding/EnumEncodingTest.kt index a42ad64..4858dff 100644 --- a/src/test/kotlin/com/github/avrokotlin/avro4k/encoding/EnumEncodingTest.kt +++ b/src/test/kotlin/com/github/avrokotlin/avro4k/encoding/EnumEncodingTest.kt @@ -5,55 +5,26 @@ package com.github.avrokotlin.avro4k.encoding import com.github.avrokotlin.avro4k.Avro import com.github.avrokotlin.avro4k.AvroAssertions import com.github.avrokotlin.avro4k.AvroEnumDefault +import com.github.avrokotlin.avro4k.basicScalarEncodeDecodeTests +import com.github.avrokotlin.avro4k.encodeToByteArray +import com.github.avrokotlin.avro4k.encodeToBytesUsingApacheLib import com.github.avrokotlin.avro4k.record import com.github.avrokotlin.avro4k.schema import com.github.avrokotlin.avro4k.serializer.UUIDSerializer +import io.kotest.assertions.throwables.shouldThrow import io.kotest.core.spec.style.StringSpec import kotlinx.serialization.SerialName import kotlinx.serialization.Serializable +import kotlinx.serialization.SerializationException import kotlinx.serialization.UseSerializers +import org.apache.avro.SchemaBuilder import org.apache.avro.generic.GenericData internal class EnumEncodingTest : StringSpec({ - "read / write enums" { - AvroAssertions.assertThat(EnumTest(Cream.Bruce, BBM.Moore)) - .isEncodedAs(record(GenericData.EnumSymbol(Avro.schema(), "Bruce"), GenericData.EnumSymbol(Avro.schema(), "Moore"))) + basicScalarEncodeDecodeTests(Cream.Bruce, Avro.schema(), apacheCompatibleValue = GenericData.EnumSymbol(Avro.schema(), "Bruce")) - AvroAssertions.assertThat(Cream.Bruce) - .isEncodedAs(GenericData.EnumSymbol(Avro.schema(), "Bruce")) - AvroAssertions.assertThat(CreamValueClass(Cream.Bruce)) - .isEncodedAs(GenericData.EnumSymbol(Avro.schema(), "Bruce")) - } - - "read / write list of enums" { - AvroAssertions.assertThat(EnumListTest(listOf(Cream.Bruce, Cream.Clapton))) - .isEncodedAs(record(listOf(GenericData.EnumSymbol(Avro.schema(), "Bruce"), GenericData.EnumSymbol(Avro.schema(), "Clapton")))) - - AvroAssertions.assertThat(listOf(Cream.Bruce, Cream.Clapton)) - .isEncodedAs(listOf(GenericData.EnumSymbol(Avro.schema(), "Bruce"), GenericData.EnumSymbol(Avro.schema(), "Clapton"))) - AvroAssertions.assertThat(listOf(CreamValueClass(Cream.Bruce), CreamValueClass(Cream.Clapton))) - .isEncodedAs(listOf(GenericData.EnumSymbol(Avro.schema(), "Bruce"), GenericData.EnumSymbol(Avro.schema(), "Clapton"))) - } - - "read / write nullable enums" { - AvroAssertions.assertThat(NullableEnumTest(null)) - .isEncodedAs(record(null)) - AvroAssertions.assertThat(NullableEnumTest(Cream.Bruce)) - .isEncodedAs(record(GenericData.EnumSymbol(Avro.schema(), "Bruce"))) - - AvroAssertions.assertThat(Cream.Bruce) - .isEncodedAs(GenericData.EnumSymbol(Avro.schema(), "Bruce")) - AvroAssertions.assertThat(null) - .isEncodedAs(null) - - AvroAssertions.assertThat(CreamValueClass(Cream.Bruce)) - .isEncodedAs(GenericData.EnumSymbol(Avro.schema(), "Bruce")) - AvroAssertions.assertThat(null) - .isEncodedAs(null) - } - - "Decoding enum with an unknown uses @AvroEnumDefault value" { + "Decoding enum with an unknown symbol uses @AvroEnumDefault value" { AvroAssertions.assertThat(EnumV2WrapperRecord(EnumV2.B)) .isEncodedAs(record(GenericData.EnumSymbol(Avro.schema(), "B"))) .isDecodedAs(EnumV1WrapperRecord(EnumV1.UNKNOWN)) @@ -62,6 +33,23 @@ internal class EnumEncodingTest : StringSpec({ .isEncodedAs(GenericData.EnumSymbol(Avro.schema(), "B")) .isDecodedAs(EnumV1.UNKNOWN) } + + "Decoding enum with an unknown symbol fails without @AvroEnumDefault, also ignoring default symbol in writer schema" { + val schema = SchemaBuilder.enumeration("Enum").defaultSymbol("Z").symbols("X", "Z") + + val bytes = encodeToBytesUsingApacheLib(schema, GenericData.EnumSymbol(schema, "X")) + shouldThrow { + Avro.decodeFromByteArray(schema, EnumV1WithoutDefault.serializer(), bytes) + } + } + + "Encoding enum with an unknown symbol fails even with default in writer schema" { + val schema = SchemaBuilder.enumeration("Enum").defaultSymbol("Z").symbols("X", "Z") + + shouldThrow { + Avro.encodeToByteArray(schema, EnumV1WithoutDefault.A) + } + } }) { @Serializable @SerialName("EnumWrapper") @@ -83,6 +71,13 @@ internal class EnumEncodingTest : StringSpec({ A, } + @Serializable + @SerialName("Enum") + private enum class EnumV1WithoutDefault { + UNKNOWN, + A, + } + @Serializable @SerialName("Enum") private enum class EnumV2 { @@ -93,27 +88,9 @@ internal class EnumEncodingTest : StringSpec({ } @Serializable - private data class EnumTest(val a: Cream, val b: BBM) - - @JvmInline - @Serializable - private value class CreamValueClass(val a: Cream) - - @Serializable - private data class EnumListTest(val a: List) - - @Serializable - private data class NullableEnumTest(val a: Cream?) - private enum class Cream { Bruce, Baker, Clapton, } - - private enum class BBM { - Bruce, - Baker, - Moore, - } } \ No newline at end of file