From 7fa55fade262c69d2da465d142390d13ea6c4736 Mon Sep 17 00:00:00 2001 From: Quin Lynch <49576606+quinchs@users.noreply.github.com> Date: Tue, 26 Sep 2023 13:28:29 -0300 Subject: [PATCH] Fix primitive deserialization on datatypes (#20) --- .../edgedb/driver/binary/PacketWriter.java | 5 +- .../driver/binary/builders/ObjectBuilder.java | 26 ++++++++-- .../builders/types/TypeDeserializerInfo.java | 12 ++++- .../driver/binary/codecs/ObjectCodec.java | 2 +- .../exceptions/NoTypeConverterException.java | 9 ++++ .../driver/util/BinaryProtocolUtils.java | 31 ----------- .../com/edgedb/driver/util/TypeUtils.java | 51 +++++++++++++++++++ src/driver/src/test/java/QueryTests.java | 33 ++++++++++++ 8 files changed, 130 insertions(+), 39 deletions(-) create mode 100644 src/driver/src/test/java/QueryTests.java diff --git a/src/driver/src/main/java/com/edgedb/driver/binary/PacketWriter.java b/src/driver/src/main/java/com/edgedb/driver/binary/PacketWriter.java index ac4c344..2f0bce7 100644 --- a/src/driver/src/main/java/com/edgedb/driver/binary/PacketWriter.java +++ b/src/driver/src/main/java/com/edgedb/driver/binary/PacketWriter.java @@ -2,6 +2,7 @@ import com.edgedb.driver.exceptions.EdgeDBException; import com.edgedb.driver.util.BinaryProtocolUtils; +import com.edgedb.driver.util.TypeUtils; import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; import org.jetbrains.annotations.NotNull; @@ -209,7 +210,7 @@ public void write(@NotNull T serializable) throws O public void writeArray(T @NotNull [] serializableArray, @NotNull Class lengthPrimitive) throws OperationNotSupportedException { ensureCanWrite(BinaryProtocolUtils.sizeOf(serializableArray, lengthPrimitive)); - var len = BinaryProtocolUtils.castNumber(serializableArray.length, lengthPrimitive); + var len = TypeUtils.castToPrimitiveNumber(serializableArray.length, lengthPrimitive); primitiveNumberWriters.get(lengthPrimitive).write(this, len); @@ -225,7 +226,7 @@ public & BinaryEnum> void writeEnumSet(@ flags |= v.getValue().longValue(); } - primitiveNumberWriters.get(primitive).write(this, BinaryProtocolUtils.castNumber(flags, primitive)); + primitiveNumberWriters.get(primitive).write(this, TypeUtils.castToPrimitiveNumber(flags, primitive)); } @FunctionalInterface diff --git a/src/driver/src/main/java/com/edgedb/driver/binary/builders/ObjectBuilder.java b/src/driver/src/main/java/com/edgedb/driver/binary/builders/ObjectBuilder.java index cc97c45..d7f6623 100644 --- a/src/driver/src/main/java/com/edgedb/driver/binary/builders/ObjectBuilder.java +++ b/src/driver/src/main/java/com/edgedb/driver/binary/builders/ObjectBuilder.java @@ -7,6 +7,8 @@ import com.edgedb.driver.binary.packets.receivable.Data; import com.edgedb.driver.clients.EdgeDBBinaryClient; import com.edgedb.driver.exceptions.EdgeDBException; +import com.edgedb.driver.exceptions.NoTypeConverterException; +import com.edgedb.driver.util.TypeUtils; import org.jetbrains.annotations.NotNull; import org.jetbrains.annotations.Nullable; @@ -15,7 +17,6 @@ import java.util.*; public final class ObjectBuilder { - @FunctionalInterface public interface CollectionConverter> { T convert(Object[] value); @@ -65,9 +66,26 @@ public interface CollectionConverter> { return convertCollection(cls, value); } - return (T)value; - } - catch (Exception x) { + if(valueType.isPrimitive() && TypeUtils.PRIMITIVE_REFERENCE_MAP.get(cls) == valueType) { + return (T) value; // JVM handles the underlying conversions of primitives + } + + if( + cls.isPrimitive() && + TypeUtils.PRIMITIVE_REFERENCE_MAP.get(valueType) == cls + ) { + return (T) value; // JVM handles the underlying conversions of primitives + } + + try { + return cls.cast(value); + } catch (Exception err) { + throw new NoTypeConverterException( + String.format("Cannot use the type %s to represent the value %s", cls.getName(), valueType.getName()), + err + ); + } + } catch (Exception x) { throw new EdgeDBException("Failed to convert type to specified result", x); } } diff --git a/src/driver/src/main/java/com/edgedb/driver/binary/builders/types/TypeDeserializerInfo.java b/src/driver/src/main/java/com/edgedb/driver/binary/builders/types/TypeDeserializerInfo.java index f66365b..8381542 100644 --- a/src/driver/src/main/java/com/edgedb/driver/binary/builders/types/TypeDeserializerInfo.java +++ b/src/driver/src/main/java/com/edgedb/driver/binary/builders/types/TypeDeserializerInfo.java @@ -7,6 +7,7 @@ import com.edgedb.driver.binary.builders.internal.ObjectEnumeratorImpl; import com.edgedb.driver.binary.packets.shared.Cardinality; import com.edgedb.driver.exceptions.EdgeDBException; +import com.edgedb.driver.exceptions.NoTypeConverterException; import com.edgedb.driver.namingstrategies.NamingStrategy; import com.edgedb.driver.util.FastInverseIndexer; import com.edgedb.driver.util.StringsUtil; @@ -398,7 +399,16 @@ private Class extractCollectionInnerType(@NotNull Class cls) throws EdgeDB } public void convertAndSet(boolean useMethodSetter, Object instance, Object value) throws EdgeDBException, ReflectiveOperationException { - var converted = convertToType(value); + Object converted; + + try { + converted = convertToType(value); + } catch (EdgeDBException error) { + var valueType = value == null ? "NULL" : value.getClass().getName(); + throw new NoTypeConverterException( + String.format("The field '%s' with type '%s' cannot be implicitly assigned to the received data type '%s'", field.getName(), field.getType().getName(), valueType) + ); + } if(useMethodSetter && setMethod != null) { setMethod.invoke(instance, converted); diff --git a/src/driver/src/main/java/com/edgedb/driver/binary/codecs/ObjectCodec.java b/src/driver/src/main/java/com/edgedb/driver/binary/codecs/ObjectCodec.java index 284e3cc..e2875db 100644 --- a/src/driver/src/main/java/com/edgedb/driver/binary/codecs/ObjectCodec.java +++ b/src/driver/src/main/java/com/edgedb/driver/binary/codecs/ObjectCodec.java @@ -56,7 +56,7 @@ public TypeInitializedObjectCodec(@NotNull TypeDeserializerInfo info, @NotNul try { return deserializer.factory.deserialize(enumerator); } catch (Exception x) { - throw new EdgeDBException("Failed to deserialize object to " + getConvertingClass().getName(), x); + throw new EdgeDBException("Failed to deserialize " + target.getName(), x); } } diff --git a/src/driver/src/main/java/com/edgedb/driver/exceptions/NoTypeConverterException.java b/src/driver/src/main/java/com/edgedb/driver/exceptions/NoTypeConverterException.java index 604663d..00a2a4c 100644 --- a/src/driver/src/main/java/com/edgedb/driver/exceptions/NoTypeConverterException.java +++ b/src/driver/src/main/java/com/edgedb/driver/exceptions/NoTypeConverterException.java @@ -22,4 +22,13 @@ public NoTypeConverterException(@NotNull Class target) { public NoTypeConverterException(String message) { super(message); } + + /** + * Constructs a new {@linkplain NoTypeConverterException} + * @param message The detailed message describing why the exception was thrown. + * @param inner The inner cause of this exception. + */ + public NoTypeConverterException(String message, Exception inner) { + super(message, inner); + } } diff --git a/src/driver/src/main/java/com/edgedb/driver/util/BinaryProtocolUtils.java b/src/driver/src/main/java/com/edgedb/driver/util/BinaryProtocolUtils.java index acefe72..89781e9 100644 --- a/src/driver/src/main/java/com/edgedb/driver/util/BinaryProtocolUtils.java +++ b/src/driver/src/main/java/com/edgedb/driver/util/BinaryProtocolUtils.java @@ -4,15 +4,9 @@ import io.netty.buffer.ByteBuf; import org.jetbrains.annotations.NotNull; import org.jetbrains.annotations.Nullable; -import org.joou.*; import java.nio.charset.StandardCharsets; import java.util.Arrays; -import java.util.HashMap; -import java.util.Map; -import java.util.function.Function; - -import static org.joou.Unsigned.*; public class BinaryProtocolUtils { public static final int DOUBLE_SIZE = 8; @@ -25,25 +19,6 @@ public class BinaryProtocolUtils { public static final int BOOL_SIZE = 1; public static final int UUID_SIZE = 16; - private static final @NotNull Map, Function> numberCastMap; - - static { - numberCastMap = new HashMap<>() { - { - put(Long.TYPE, Number::longValue); - put(Integer.TYPE, Number::intValue); - put(Short.TYPE, Number::shortValue); - put(Byte.TYPE, Number::byteValue); - put(Double.TYPE, Number::doubleValue); - put(Float.TYPE, Number::floatValue); - put(UByte.class, number -> ubyte(number.longValue())); - put(UShort.class, number -> ushort(number.intValue())); - put(UInteger.class, number -> uint(number.longValue())); - put(ULong.class, number -> ulong(number.longValue())); - } - }; - } - public static int sizeOf(@Nullable String s) { int size = 4; @@ -54,7 +29,6 @@ public static int sizeOf(@Nullable String s) { return size; } - public static int sizeOf(@Nullable ByteBuf buffer) { int size = 4; @@ -88,11 +62,6 @@ else if(primitive == Float.TYPE || primitive == Float.class) { throw new ArithmeticException("Unable to determine the size of " + primitive.getName()); } - @SuppressWarnings("unchecked") - public static U castNumber(T value, Class target) { - return (U) numberCastMap.get(target).apply(value); - } - public static int sizeOf(T @NotNull [] arr, @NotNull Class primitive) { return Arrays.stream(arr).mapToInt(T::getSize).sum() + sizeOf(primitive); } diff --git a/src/driver/src/main/java/com/edgedb/driver/util/TypeUtils.java b/src/driver/src/main/java/com/edgedb/driver/util/TypeUtils.java index dc11c8f..bcfee01 100644 --- a/src/driver/src/main/java/com/edgedb/driver/util/TypeUtils.java +++ b/src/driver/src/main/java/com/edgedb/driver/util/TypeUtils.java @@ -2,10 +2,61 @@ import org.jetbrains.annotations.NotNull; import org.jetbrains.annotations.Nullable; +import org.joou.UByte; +import org.joou.UInteger; +import org.joou.ULong; +import org.joou.UShort; import java.lang.reflect.Array; +import java.util.HashMap; +import java.util.Map; +import java.util.function.Function; + +import static org.joou.Unsigned.*; +import static org.joou.Unsigned.ulong; public class TypeUtils { + public static final Map, Class> PRIMITIVE_REFERENCE_MAP = new HashMap<>(){{ + put(Byte.class, Byte.TYPE); + put(Short.class, Short.TYPE); + put(Integer.class, Integer.TYPE); + put(Long.class, Long.TYPE); + put(Float.class, Float.TYPE); + put(Double.class, Double.TYPE); + put(Character.class, Character.TYPE); + put(Boolean.class, Boolean.TYPE); + put(Byte.TYPE, Byte.class); + put(Short.TYPE, Short.class); + put(Integer.TYPE, Integer.class); + put(Long.TYPE, Long.class); + put(Float.TYPE, Float.class); + put(Double.TYPE, Double.class); + put(Character.TYPE, Character.class); + put(Boolean.TYPE, Boolean.class); + }}; + private static final @NotNull Map, Function> PRIMITIVE_NUMBER_CAST_MAP; + + static { + PRIMITIVE_NUMBER_CAST_MAP = new HashMap<>() { + { + put(Long.TYPE, Number::longValue); + put(Integer.TYPE, Number::intValue); + put(Short.TYPE, Number::shortValue); + put(Byte.TYPE, Number::byteValue); + put(Double.TYPE, Number::doubleValue); + put(Float.TYPE, Number::floatValue); + put(UByte.class, number -> ubyte(number.longValue())); + put(UShort.class, number -> ushort(number.intValue())); + put(UInteger.class, number -> uint(number.longValue())); + put(ULong.class, number -> ulong(number.longValue())); + } + }; + } + + @SuppressWarnings("unchecked") + public static U castToPrimitiveNumber(T value, Class target) { + return (U) PRIMITIVE_NUMBER_CAST_MAP.get(target).apply(value); + } public static Object getDefaultValue(Class cls) { return Array.get(Array.newInstance(cls, 1), 0); diff --git a/src/driver/src/test/java/QueryTests.java b/src/driver/src/test/java/QueryTests.java new file mode 100644 index 0000000..d7ca6bd --- /dev/null +++ b/src/driver/src/test/java/QueryTests.java @@ -0,0 +1,33 @@ +import com.edgedb.driver.EdgeDBClient; +import com.edgedb.driver.annotations.EdgeDBType; +import org.junit.jupiter.api.Test; + +import static org.assertj.core.api.Assertions.assertThat; + +public class QueryTests { + + @EdgeDBType + public static class TestDataContainer { + public long a; + public Long b; + public int c; + public Integer d; + } + + @Test + public void TestPrimitives() { + // primitives (long, int, etc.) differ from the class form (Long, Integer, etc.), + // we test that we can deserialize both in a data structure. + try(var client = new EdgeDBClient()) { + var result = client.queryRequiredSingle(TestDataContainer.class, "select { a := 1, b := 2, c := 3, d := 4}") + .toCompletableFuture().get(); + + assertThat(result.a).isEqualTo(1); + assertThat(result.b).isEqualTo(2); + assertThat(result.c).isEqualTo(3); + assertThat(result.d).isEqualTo(4); + } catch (Exception e) { + throw new RuntimeException(e); + } + } +}