diff --git a/examples/scala-examples/build.sbt b/examples/scala-examples/build.sbt index 90aa6c3..1b5a53e 100644 --- a/examples/scala-examples/build.sbt +++ b/examples/scala-examples/build.sbt @@ -3,7 +3,7 @@ ThisBuild / version := "0.1.0-SNAPSHOT" ThisBuild / scalaVersion := "3.1.3" libraryDependencies ++= Seq( - "com.edgedb" % "driver" % "0.2.3" from "file:///" + System.getProperty("user.dir") + "/lib/com.edgedb.driver-0.2.3.jar", + "com.edgedb" % "driver" % "0.2.4" from "file:///" + System.getProperty("user.dir") + "/lib/com.edgedb.driver-0.2.4-SNAPSHOT.jar", "ch.qos.logback" % "logback-classic" % "1.4.7", "ch.qos.logback" % "logback-core" % "1.4.7", "com.fasterxml.jackson.core" % "jackson-databind" % "2.15.1", diff --git a/src/driver/src/main/java/com/edgedb/driver/binary/builders/types/TypeDeserializerInfo.java b/src/driver/src/main/java/com/edgedb/driver/binary/builders/types/TypeDeserializerInfo.java index 13d0eca..5a390e4 100644 --- a/src/driver/src/main/java/com/edgedb/driver/binary/builders/types/TypeDeserializerInfo.java +++ b/src/driver/src/main/java/com/edgedb/driver/binary/builders/types/TypeDeserializerInfo.java @@ -179,59 +179,89 @@ private boolean isValidField(@NotNull Field field) { } @SuppressWarnings("unchecked") - private @NotNull TypeDeserializerFactory createFactory() throws ReflectiveOperationException { - // check for constructor deserializer - var constructors = this.type.getDeclaredConstructors(); + private @Nullable TypeDeserializerFactory createFactoryFromBestConstructor(Constructor[] constructors) { + var fields = getFields(); - var ctorDeserializer = Arrays.stream(constructors).filter(x -> x.getAnnotation(EdgeDBDeserializer.class) != null).findFirst(); - - if(ctorDeserializer.isPresent()) { - var ctor = ctorDeserializer.get(); + for(var i = 0; i != constructors.length; i++) { + var ctor = constructors[i]; var ctorParams = ctor.getParameters(); + if(ctorParams.length == 1 && ctorParams[0].getType().equals(ObjectEnumerator.class)) { return (enumerator, parent) -> (T)ctor.newInstance(enumerator); } - return (enumerator, parent) -> { - var namingStrategyEntry = constructorNamingMap.computeIfAbsent( - ((ObjectEnumeratorImpl)enumerator).getClient().getConfig().getNamingStrategy(), - (n) -> new NamingStrategyMap<>(n, (v) -> getNameOrAnnotated(v, Parameter::getName), ctor.getParameters()) - ); + if(fields.size() == ctorParams.length) { + // assert that all types match + var valid = true; + for(var j = 0; j != fields.size(); j++) { + if(!ctorParams[j].getType().isAssignableFrom(fields.get(j).fieldType)) { + valid = false; + break; + } + } - var params = new Object[namingStrategyEntry.nameIndexMap.size()]; - var inverseIndexer = new FastInverseIndexer(params.length); + if(!valid) continue; - ObjectEnumerator.ObjectElement element; + // update each field with annotations from the parameters + for(var j = 0; j != fields.size(); j++) { + fields.get(j).updateFromDeserializerParameter(ctorParams[j]); + } - var unhandled = new Vector(params.length); + return (enumerator, parent) -> { + var namingStrategyEntry = constructorNamingMap.computeIfAbsent( + ((ObjectEnumeratorImpl)enumerator).getClient().getConfig().getNamingStrategy(), + (n) -> new NamingStrategyMap<>(n, (v) -> getNameOrAnnotated(n, v, v, Parameter::getName), ctor.getParameters()) + ); + + var params = new Object[namingStrategyEntry.nameIndexMap.size()]; + var inverseIndexer = new FastInverseIndexer(params.length); + + ObjectEnumerator.ObjectElement element; - while(enumerator.hasRemaining() && (element = enumerator.next()) != null) { - if(namingStrategyEntry.map.containsKey(element.getName())) { - var i = namingStrategyEntry.nameIndexMap.get(element.getName()); - inverseIndexer.set(i); - params[i] = element.getValue(); - } else { - unhandled.add(element); + var unhandled = new Vector(params.length); + + while(enumerator.hasRemaining() && (element = enumerator.next()) != null) { + if(namingStrategyEntry.map.containsKey(element.getName())) { + var namingIndex = namingStrategyEntry.nameIndexMap.get(element.getName()); + var field = fields.get(namingIndex); + inverseIndexer.set(namingIndex); + params[namingIndex] = field.convertToType(element.getValue()); + } else { + unhandled.add(element); + } } - } - var missed = inverseIndexer.getInverseIndexes(); + var missed = inverseIndexer.getInverseIndexes(); - for(int i = 0; i != missed.length; i++) { - params[missed[i]] = TypeUtils.getDefaultValue(namingStrategyEntry.values.get(i).getType()); - } + for(int j = 0; j != missed.length; j++) { + params[missed[j]] = TypeUtils.getDefaultValue(namingStrategyEntry.values.get(j).getType()); + } - var instance = (T)ctor.newInstance(params); + var instance = (T)ctor.newInstance(params); - if(parent != null) { - for (var unhandledElement : unhandled) { - parent.accept(instance, unhandledElement); + if(parent != null) { + for (var unhandledElement : unhandled) { + parent.accept(instance, unhandledElement); + } } - } - return instance; - }; + return instance; + }; + } + } + + return null; + } + + @SuppressWarnings("unchecked") + private @NotNull TypeDeserializerFactory createFactory() throws ReflectiveOperationException { + // check for constructor deserializer + var constructors = this.type.getDeclaredConstructors(); + var deserializer = createFactoryFromBestConstructor(constructors); + + if(deserializer != null) { + return deserializer; } // abstract or interface @@ -246,7 +276,7 @@ private boolean isValidField(@NotNull Field field) { var namingStrategyEntry = fieldNamingMap.computeIfAbsent( ((ObjectEnumeratorImpl)enumerator).getClient().getConfig().getNamingStrategy(), - (v) -> new NamingStrategyMap<>(v, (u) -> getNameOrAnnotated(u.field, Field::getName), getFields()) + (v) -> new NamingStrategyMap<>(v, (u) -> getNameOrAnnotated(v, u, u.field, Field::getName), getFields()) ); var element = enumerator.next(); @@ -293,7 +323,7 @@ private boolean isValidField(@NotNull Field field) { return (enumerator, parent) -> { var namingStrategyEntry = fieldNamingMap.computeIfAbsent( ((ObjectEnumeratorImpl)enumerator).getClient().getConfig().getNamingStrategy(), - (v) -> new NamingStrategyMap<>(v, (u) -> getNameOrAnnotated(u.field, Field::getName), getFields()) + (v) -> new NamingStrategyMap<>(v, (u) -> getNameOrAnnotated(v, u, u.field, Field::getName), getFields()) ); var instance = (T)ctor.newInstance(); @@ -315,11 +345,15 @@ private boolean isValidField(@NotNull Field field) { public @NotNull NamingStrategyMap getFieldMap(NamingStrategy strategy) { return fieldNamingMap.computeIfAbsent( strategy, - (v) -> new NamingStrategyMap<>(v, (u) -> getNameOrAnnotated(u.field, Field::getName), getFields()) + (v) -> new NamingStrategyMap<>(v, (u) -> getNameOrAnnotated(v, u, u.field, Field::getName), getFields()) ); } - private String getNameOrAnnotated(@NotNull U value, @NotNull Function getName) { + private String getNameOrAnnotated(NamingStrategy strategy, V root, @NotNull U value, @NotNull Function getName) { + if(root instanceof FieldInfo) { + return ((FieldInfo)root).getEdgeDBName(strategy); + } + var anno = value.getAnnotation(EdgeDBName.class); if(anno != null && anno.value() != null) { return anno.value(); @@ -329,12 +363,12 @@ private String getNameOrAnnotated(@NotNull U value, } public static class FieldInfo { - public final EdgeDBName edgedbNameAnno; + public EdgeDBName edgedbNameAnno; public final @NotNull Class fieldType; public final @NotNull Field field; private final @Nullable Method setMethod; - private final @Nullable EdgeDBLinkType linkType; + private @Nullable EdgeDBLinkType linkType; public FieldInfo(@NotNull Field field, @NotNull Map setters) { this.field = field; @@ -394,8 +428,12 @@ private Class extractCollectionInnerType(@NotNull Class cls) throws EdgeDB throw new EdgeDBException("Cannot find element type of the collection " + cls.getName()); } - public @NotNull String getFieldName() { - return this.field.getName(); + public @NotNull String getEdgeDBName(NamingStrategy strategy) { + if(this.edgedbNameAnno != null && this.edgedbNameAnno.value() != null) { + return this.edgedbNameAnno.value(); + } + + return strategy.convert(this.field.getName()); } public void convertAndSet(boolean useMethodSetter, Object instance, Object value) throws EdgeDBException, ReflectiveOperationException { @@ -426,6 +464,16 @@ public void convertAndSet(boolean useMethodSetter, Object instance, Object value return ObjectBuilder.convertTo(fieldType, value); } + + private void updateFromDeserializerParameter(Parameter parameter) { + if(this.edgedbNameAnno == null) { + this.edgedbNameAnno = parameter.getAnnotation(EdgeDBName.class); + } + + if(this.linkType == null) { + this.linkType = parameter.getAnnotation(EdgeDBLinkType.class); + } + } } public static class NamingStrategyMap { diff --git a/src/driver/src/test/java/CustomDeserializerTests.java b/src/driver/src/test/java/CustomDeserializerTests.java new file mode 100644 index 0000000..ced96f7 --- /dev/null +++ b/src/driver/src/test/java/CustomDeserializerTests.java @@ -0,0 +1,53 @@ +import com.edgedb.driver.EdgeDBClient; +import com.edgedb.driver.annotations.EdgeDBDeserializer; +import com.edgedb.driver.annotations.EdgeDBLinkType; +import com.edgedb.driver.annotations.EdgeDBName; +import com.edgedb.driver.annotations.EdgeDBType; +import org.junit.jupiter.api.Test; + +import java.util.Collection; + +import static org.assertj.core.api.Assertions.assertThat; + +public class CustomDeserializerTests { + + @EdgeDBType + public static final class Links { + private String namesHereAreIrrelevant; + private Links sinceTheCustomDeserializer; + private Collection shouldMapNames; + + @EdgeDBDeserializer + public Links( + @EdgeDBName("a") String a, + @EdgeDBName("b") Links b, + @EdgeDBName("c") @EdgeDBLinkType(Links.class) Collection c + ) { + this.namesHereAreIrrelevant = a; + this.sinceTheCustomDeserializer = b; + this.shouldMapNames = c; + } + } + + @Test + public void testCustomDeserializerParameterAnnotations() throws Exception { + try(var client = new EdgeDBClient().withModule("tests")) { + var result = client.queryRequiredSingle( + Links.class, + "with test1 := (insert Links { a := '123' } unless conflict on .a else (select Links))," + + "test2 := (insert Links { a := '456', b := test1 } unless conflict on .a else (select Links))," + + "test3 := (insert Links { a := '789', b := test2, c := { test1, test2 }} unless conflict on .a else (select Links)) " + + "select test3 {a, b: {a, b, c }, c: {a, b, c}}" + ).toCompletableFuture().get(); + + assertThat(result.namesHereAreIrrelevant).isEqualTo("789"); + assertThat(result.shouldMapNames.size()).isEqualTo(2); + + for(var link : result.shouldMapNames) { + assertThat(link.getClass()).isEqualTo(Links.class); + } + + assertThat(result.sinceTheCustomDeserializer).isNotNull(); + } + } +}