Skip to content

Commit

Permalink
fix: do not coerce between int/long and decimals, also improve enum a…
Browse files Browse the repository at this point in the history
…nd scalar tests
  • Loading branch information
Chuckame committed Oct 9, 2024
1 parent 4bbcad5 commit 65fbfca
Show file tree
Hide file tree
Showing 9 changed files with 386 additions and 326 deletions.
1 change: 1 addition & 0 deletions build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ dependencies {
testImplementation(libs.kotest.core)
testImplementation(libs.kotest.json)
testImplementation(libs.kotest.property)
testImplementation(kotlin("reflect"))
}

kotlin {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,6 @@ package com.github.avrokotlin.avro4k.internal
import kotlinx.serialization.SerializationException
import java.math.BigDecimal

/**
 * Converts this [BigDecimal] to a [Long], failing when the conversion would lose information.
 *
 * The round-trip check uses [BigDecimal.compareTo] instead of `equals`: `equals` also compares the scale,
 * so numerically integral values such as `1.0` or `1E+2` were rejected although they fit in a [Long].
 *
 * @return the exact [Long] value of this decimal.
 * @throws SerializationException if this value has a non-zero fractional part or is out of the [Long] range.
 */
internal fun BigDecimal.toLongExact(): Long {
    val converted = this.toLong()
    // toLong() silently truncates/wraps, so any numeric difference after the round trip means data loss.
    if (converted.toBigDecimal().compareTo(this) != 0) {
        throw SerializationException("Value $this is not a valid Long")
    }
    return converted
}

internal fun Int.toByteExact(): Byte {
if (this.toByte().toInt() != this) {
throw SerializationException("Value $this is not a valid Byte")
Expand Down Expand Up @@ -54,35 +47,44 @@ internal fun BigDecimal.toShortExact(): Short {

/**
 * Converts this [Long] to an [Int], failing when the value does not fit.
 *
 * @return the same numeric value as an [Int].
 * @throws SerializationException (message "Value <value> is not a valid Int") if narrowing changes the value.
 */
internal fun Long.toIntExact(): Int {
    val converted = this.toInt()
    // toInt() keeps only the low 32 bits; widening back reveals whether anything was lost.
    if (converted.toLong() != this) {
        throw invalidType<Int>()
    }
    return converted
}

/**
 * Converts this [BigDecimal] to an [Int], failing when the conversion would lose information.
 *
 * The round-trip check uses [BigDecimal.compareTo] instead of `equals`: `equals` also compares the scale,
 * so numerically integral values such as `1.0` or `1E+2` were rejected although they fit in an [Int].
 *
 * @return the exact [Int] value of this decimal.
 * @throws SerializationException if this value has a non-zero fractional part or is out of the [Int] range.
 */
internal fun BigDecimal.toIntExact(): Int {
    val converted = this.toInt()
    // toInt() silently truncates/wraps, so any numeric difference after the round trip means data loss.
    if (converted.toBigDecimal().compareTo(this) != 0) {
        throw invalidType<Int>()
    }
    return converted
}

internal fun BigDecimal.toFloatExact(): Float {
if (this.toFloat().toBigDecimal() != this) {
throw SerializationException("Value $this is not a valid Float")
internal fun BigDecimal.toLongExact(): Long {
if (this.toLong().toBigDecimal() != this) {
throw invalidType<Long>()
}
return this.toFloat()
return this.toLong()
}

/**
 * Converts this [Double] to a [Float], failing when the narrowing changes the value.
 *
 * NaN is accepted explicitly: it never compares equal to itself, so the plain round-trip comparison
 * would reject it even though [Float] can represent NaN.
 *
 * @return the same value as a [Float].
 * @throws SerializationException if widening the narrowed value back does not give this value.
 */
internal fun Double.toFloatExact(): Float {
    val converted = this.toFloat()
    if (!this.isNaN() && converted.toDouble() != this) {
        throw invalidType<Float>()
    }
    return converted
}

/**
 * Converts this [BigDecimal] to a [Double], failing when the value does not survive the round trip.
 *
 * Two problems of a plain `toDouble().toBigDecimal() != this` check are avoided:
 * - the comparison uses [BigDecimal.compareTo], because `equals` also compares the scale and therefore
 *   rejects numerically equal values written with a different scale;
 * - a magnitude too large for a [Double] converts to an infinity, which cannot be turned back into a
 *   [BigDecimal]; it is rejected up front so the caller gets a [SerializationException].
 *
 * @return the [Double] holding the same numeric value.
 * @throws SerializationException if this value is out of range or changes when converted.
 */
internal fun BigDecimal.toDoubleExact(): Double {
    val converted = this.toDouble()
    if (!converted.isFinite() || converted.toBigDecimal().compareTo(this) != 0) {
        throw invalidType<Double>()
    }
    return converted
}
}

/**
 * Converts this [BigDecimal] to a [Float], failing when the value does not survive the round trip.
 *
 * Two problems of a plain `toFloat().toBigDecimal() != this` check are avoided:
 * - the comparison uses [BigDecimal.compareTo], because `equals` also compares the scale and therefore
 *   rejects numerically equal values written with a different scale;
 * - a magnitude too large for a [Float] converts to an infinity, which cannot be turned back into a
 *   [BigDecimal]; it is rejected up front so the caller gets a [SerializationException].
 *
 * @return the [Float] holding the same numeric value.
 * @throws SerializationException if this value is out of range or changes when converted.
 */
internal fun BigDecimal.toFloatExact(): Float {
    val converted = this.toFloat()
    if (!converted.isFinite() || converted.toBigDecimal().compareTo(this) != 0) {
        throw invalidType<Float>()
    }
    return converted
}

/**
 * Builds (without throwing) the exception reporting that the receiver cannot be represented as [T].
 *
 * The message contains the receiver's string form and the simple class name of [T].
 */
private inline fun <reified T> Any.invalidType(): SerializationException {
    val targetTypeName = T::class.simpleName
    return SerializationException("Value $this is not a valid $targetTypeName")
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import com.github.avrokotlin.avro4k.internal.SerializerLocatorMiddleware
import com.github.avrokotlin.avro4k.internal.decoder.AbstractPolymorphicDecoder
import com.github.avrokotlin.avro4k.internal.isFullNameOrAliasMatch
import com.github.avrokotlin.avro4k.internal.toByteExact
import com.github.avrokotlin.avro4k.internal.toFloatExact
import com.github.avrokotlin.avro4k.internal.toIntExact
import com.github.avrokotlin.avro4k.internal.toShortExact
import com.github.avrokotlin.avro4k.unsupportedWriterTypeError
Expand Down Expand Up @@ -137,8 +138,8 @@ internal abstract class AbstractAvroDirectDecoder(
decodeAndResolveUnion()

return when (currentWriterSchema.type) {
Schema.Type.INT -> binaryDecoder.readInt().toLong()
Schema.Type.LONG -> binaryDecoder.readLong()
Schema.Type.INT -> binaryDecoder.readInt().toLong()
Schema.Type.STRING -> binaryDecoder.readString().toLong()
else -> throw unsupportedWriterTypeError(Schema.Type.LONG, Schema.Type.INT, Schema.Type.STRING)
}
Expand All @@ -148,24 +149,21 @@ internal abstract class AbstractAvroDirectDecoder(
decodeAndResolveUnion()

return when (currentWriterSchema.type) {
Schema.Type.INT -> binaryDecoder.readInt().toFloat()
Schema.Type.LONG -> binaryDecoder.readLong().toFloat()
Schema.Type.FLOAT -> binaryDecoder.readFloat()
Schema.Type.DOUBLE -> binaryDecoder.readDouble().toFloatExact()
Schema.Type.STRING -> binaryDecoder.readString().toFloat()
else -> throw unsupportedWriterTypeError(Schema.Type.FLOAT, Schema.Type.INT, Schema.Type.LONG, Schema.Type.STRING)
else -> throw unsupportedWriterTypeError(Schema.Type.FLOAT, Schema.Type.DOUBLE, Schema.Type.STRING)
}
}

override fun decodeDouble(): Double {
decodeAndResolveUnion()

return when (currentWriterSchema.type) {
Schema.Type.INT -> binaryDecoder.readInt().toDouble()
Schema.Type.LONG -> binaryDecoder.readLong().toDouble()
Schema.Type.FLOAT -> binaryDecoder.readFloat().toDouble()
Schema.Type.DOUBLE -> binaryDecoder.readDouble()
Schema.Type.STRING -> binaryDecoder.readString().toDouble()
else -> throw unsupportedWriterTypeError(Schema.Type.DOUBLE, Schema.Type.INT, Schema.Type.LONG, Schema.Type.FLOAT, Schema.Type.STRING)
else -> throw unsupportedWriterTypeError(Schema.Type.DOUBLE, Schema.Type.FLOAT, Schema.Type.STRING)
}
}

Expand All @@ -183,10 +181,24 @@ internal abstract class AbstractAvroDirectDecoder(
decodeAndResolveUnion()

return when (currentWriterSchema.type) {
Schema.Type.STRING -> binaryDecoder.readString(null).toString()
Schema.Type.STRING -> binaryDecoder.readString()
Schema.Type.BYTES -> binaryDecoder.readBytes().decodeToString()
Schema.Type.FIXED -> binaryDecoder.readFixedBytes(currentWriterSchema.fixedSize).decodeToString()
else -> throw unsupportedWriterTypeError(Schema.Type.STRING, Schema.Type.BYTES, Schema.Type.FIXED)
Schema.Type.BOOLEAN -> binaryDecoder.readBoolean().toString()
Schema.Type.INT -> binaryDecoder.readInt().toString()
Schema.Type.LONG -> binaryDecoder.readLong().toString()
Schema.Type.FLOAT -> binaryDecoder.readFloat().toString()
Schema.Type.DOUBLE -> binaryDecoder.readDouble().toString()
else -> throw unsupportedWriterTypeError(
Schema.Type.STRING,
Schema.Type.BYTES,
Schema.Type.FIXED,
Schema.Type.BOOLEAN,
Schema.Type.INT,
Schema.Type.LONG,
Schema.Type.FLOAT,
Schema.Type.DOUBLE
)
}
}

Expand Down Expand Up @@ -239,7 +251,8 @@ internal abstract class AbstractAvroDirectDecoder(
return when (currentWriterSchema.type) {
Schema.Type.BYTES -> GenericData.Fixed(currentWriterSchema, binaryDecoder.readBytes())
Schema.Type.FIXED -> GenericData.Fixed(currentWriterSchema, binaryDecoder.readFixedBytes(currentWriterSchema.fixedSize))
else -> throw unsupportedWriterTypeError(Schema.Type.BYTES, Schema.Type.FIXED)
Schema.Type.STRING -> GenericData.Fixed(currentWriterSchema, binaryDecoder.readString(null).bytes)
else -> throw unsupportedWriterTypeError(Schema.Type.FIXED, Schema.Type.BYTES, Schema.Type.STRING)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ import com.github.avrokotlin.avro4k.internal.SerializerLocatorMiddleware
import com.github.avrokotlin.avro4k.internal.aliases
import com.github.avrokotlin.avro4k.internal.isFullNameOrAliasMatch
import com.github.avrokotlin.avro4k.internal.nonNullSerialName
import com.github.avrokotlin.avro4k.internal.toFloatExact
import com.github.avrokotlin.avro4k.internal.toIntExact
import com.github.avrokotlin.avro4k.namedSchemaNotFoundInUnionError
import com.github.avrokotlin.avro4k.trySelectEnumSchemaForSymbol
import com.github.avrokotlin.avro4k.trySelectFixedSchemaForSize
Expand Down Expand Up @@ -230,10 +232,8 @@ internal abstract class AbstractAvroEncoder : AbstractEncoder(), AvroEncoder {

override fun encodeBoolean(value: Boolean) {
if (currentWriterSchema.isUnion && !trySelectSingleNonNullTypeFromUnion()) {
trySelectTypeFromUnion(
Schema.Type.BOOLEAN,
Schema.Type.STRING
) || throw typeNotFoundInUnionError(Schema.Type.BOOLEAN, Schema.Type.STRING)
trySelectTypeFromUnion(Schema.Type.BOOLEAN, Schema.Type.STRING) ||
throw typeNotFoundInUnionError(Schema.Type.BOOLEAN, Schema.Type.STRING)
}
when (currentWriterSchema.type) {
Schema.Type.BOOLEAN -> encodeBooleanUnchecked(value)
Expand All @@ -252,58 +252,27 @@ internal abstract class AbstractAvroEncoder : AbstractEncoder(), AvroEncoder {

override fun encodeInt(value: Int) {
if (currentWriterSchema.isUnion && !trySelectSingleNonNullTypeFromUnion()) {
trySelectTypeFromUnion(
Schema.Type.INT,
Schema.Type.LONG,
Schema.Type.FLOAT,
Schema.Type.DOUBLE,
Schema.Type.STRING
) ||
throw typeNotFoundInUnionError(
Schema.Type.INT,
Schema.Type.LONG,
Schema.Type.FLOAT,
Schema.Type.DOUBLE,
Schema.Type.STRING
)
trySelectTypeFromUnion(Schema.Type.INT, Schema.Type.LONG, Schema.Type.STRING) ||
throw typeNotFoundInUnionError(Schema.Type.INT, Schema.Type.LONG, Schema.Type.STRING)
}
when (currentWriterSchema.type) {
Schema.Type.INT -> encodeIntUnchecked(value)
Schema.Type.LONG -> encodeLongUnchecked(value.toLong())
Schema.Type.FLOAT -> encodeFloatUnchecked(value.toFloat())
Schema.Type.DOUBLE -> encodeDoubleUnchecked(value.toDouble())
Schema.Type.STRING -> encodeStringUnchecked(Utf8(value.toString()))
else -> throw unsupportedWriterTypeError(
Schema.Type.INT,
Schema.Type.LONG,
Schema.Type.FLOAT,
Schema.Type.DOUBLE,
Schema.Type.STRING
)
else -> throw unsupportedWriterTypeError(Schema.Type.INT, Schema.Type.LONG, Schema.Type.STRING)
}
}

override fun encodeLong(value: Long) {
if (currentWriterSchema.isUnion && !trySelectSingleNonNullTypeFromUnion()) {
trySelectTypeFromUnion(Schema.Type.LONG, Schema.Type.FLOAT, Schema.Type.DOUBLE, Schema.Type.STRING) ||
throw typeNotFoundInUnionError(
Schema.Type.LONG,
Schema.Type.FLOAT,
Schema.Type.DOUBLE,
Schema.Type.STRING
)
trySelectTypeFromUnion(Schema.Type.LONG, Schema.Type.INT, Schema.Type.STRING) ||
throw typeNotFoundInUnionError(Schema.Type.LONG, Schema.Type.INT, Schema.Type.STRING)
}
when (currentWriterSchema.type) {
Schema.Type.INT -> encodeIntUnchecked(value.toIntExact())
Schema.Type.LONG -> encodeLongUnchecked(value)
Schema.Type.FLOAT -> encodeFloatUnchecked(value.toFloat())
Schema.Type.DOUBLE -> encodeDoubleUnchecked(value.toDouble())
Schema.Type.STRING -> encodeStringUnchecked(Utf8(value.toString()))
else -> throw unsupportedWriterTypeError(
Schema.Type.LONG,
Schema.Type.FLOAT,
Schema.Type.DOUBLE,
Schema.Type.STRING
)
else -> throw unsupportedWriterTypeError(Schema.Type.INT, Schema.Type.LONG, Schema.Type.STRING)
}
}

Expand All @@ -322,13 +291,14 @@ internal abstract class AbstractAvroEncoder : AbstractEncoder(), AvroEncoder {

override fun encodeDouble(value: Double) {
if (currentWriterSchema.isUnion && !trySelectSingleNonNullTypeFromUnion()) {
trySelectTypeFromUnion(Schema.Type.DOUBLE, Schema.Type.STRING) ||
throw typeNotFoundInUnionError(Schema.Type.DOUBLE, Schema.Type.STRING)
trySelectTypeFromUnion(Schema.Type.DOUBLE, Schema.Type.FLOAT, Schema.Type.STRING) ||
throw typeNotFoundInUnionError(Schema.Type.DOUBLE, Schema.Type.FLOAT, Schema.Type.STRING)
}
when (currentWriterSchema.type) {
Schema.Type.FLOAT -> encodeFloatUnchecked(value.toFloatExact())
Schema.Type.DOUBLE -> encodeDoubleUnchecked(value)
Schema.Type.STRING -> encodeStringUnchecked(Utf8(value.toString()))
else -> throw unsupportedWriterTypeError(Schema.Type.DOUBLE, Schema.Type.STRING)
else -> throw unsupportedWriterTypeError(Schema.Type.DOUBLE, Schema.Type.FLOAT, Schema.Type.STRING)
}
}

Expand All @@ -349,21 +319,41 @@ internal abstract class AbstractAvroEncoder : AbstractEncoder(), AvroEncoder {
trySelectTypeFromUnion(Schema.Type.STRING, Schema.Type.BYTES) ||
trySelectFixedSchemaForSize(value.length) ||
trySelectEnumSchemaForSymbol(value) ||
trySelectTypeFromUnion(
Schema.Type.BOOLEAN,
Schema.Type.INT,
Schema.Type.LONG,
Schema.Type.FLOAT,
Schema.Type.DOUBLE) ||
throw typeNotFoundInUnionError(
Schema.Type.BOOLEAN,
Schema.Type.INT,
Schema.Type.LONG,
Schema.Type.FLOAT,
Schema.Type.DOUBLE,
Schema.Type.STRING,
Schema.Type.BYTES,
Schema.Type.FIXED,
Schema.Type.ENUM
)
Schema.Type.ENUM)
}
when (currentWriterSchema.type) {
Schema.Type.BOOLEAN -> encodeBooleanUnchecked(value.toBooleanStrict())
Schema.Type.INT -> encodeIntUnchecked(value.toInt())
Schema.Type.LONG -> encodeLongUnchecked(value.toLong())
Schema.Type.FLOAT -> encodeFloatUnchecked(value.toFloat())
Schema.Type.DOUBLE -> encodeDoubleUnchecked(value.toDouble())
Schema.Type.STRING -> encodeStringUnchecked(Utf8(value))
Schema.Type.BYTES -> encodeBytesUnchecked(value.encodeToByteArray())
Schema.Type.FIXED -> encodeFixedUnchecked(ensureFixedSize(value.encodeToByteArray()))
Schema.Type.ENUM -> encodeEnumUnchecked(value)
else -> throw unsupportedWriterTypeError(
Schema.Type.BYTES,
Schema.Type.BOOLEAN,
Schema.Type.INT,
Schema.Type.LONG,
Schema.Type.FLOAT,
Schema.Type.DOUBLE,
Schema.Type.STRING,
Schema.Type.BYTES,
Schema.Type.FIXED,
Schema.Type.ENUM
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ internal class MapVisitor(
if (it.isNullable()) {
throw AvroSchemaGenerationException("Map key cannot be nullable. Actual generated map key schema: $it")
}
if (!it.isStringable()) {
throw AvroSchemaGenerationException("Map key must be string-able (boolean, number, enum, or string). Actual generated map key schema: $it")
if (!it.isNonNullScalarType()) {
throw AvroSchemaGenerationException("Map key must be a non-null scalar type (e.g. not a record, map or array). Actual generated map key schema: $it")
}
}

Expand All @@ -40,7 +40,7 @@ internal class MapVisitor(
}
}

private fun Schema.isStringable(): Boolean =
internal fun Schema.isNonNullScalarType(): Boolean =
when (type) {
Schema.Type.BOOLEAN,
Schema.Type.INT,
Expand All @@ -49,18 +49,16 @@ private fun Schema.isStringable(): Boolean =
Schema.Type.DOUBLE,
Schema.Type.STRING,
Schema.Type.ENUM,
Schema.Type.BYTES,
Schema.Type.FIXED,
-> true

Schema.Type.NULL,
// bytes could be stringified, but it's not a good idea as it can produce unreadable strings.
Schema.Type.BYTES,
// same, just bytes. Btw, if the user wants to stringify it, he can use @Contextual or custom @Serializable serializer.
Schema.Type.FIXED,
Schema.Type.ARRAY,
Schema.Type.MAP,
Schema.Type.RECORD,
null,
-> false

Schema.Type.UNION -> types.all { it.isStringable() }
Schema.Type.UNION -> types.all { it.isNonNullScalarType() }
}
Loading

0 comments on commit 65fbfca

Please sign in to comment.