Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix custom deserializer annotations being ignored #29

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/scala-examples/build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ ThisBuild / version := "0.1.0-SNAPSHOT"
ThisBuild / scalaVersion := "3.1.3"

libraryDependencies ++= Seq(
"com.edgedb" % "driver" % "0.2.3" from "file:///" + System.getProperty("user.dir") + "/lib/com.edgedb.driver-0.2.3.jar",
"com.edgedb" % "driver" % "0.2.4" from "file:///" + System.getProperty("user.dir") + "/lib/com.edgedb.driver-0.2.4-SNAPSHOT.jar",
"ch.qos.logback" % "logback-classic" % "1.4.7",
"ch.qos.logback" % "logback-core" % "1.4.7",
"com.fasterxml.jackson.core" % "jackson-databind" % "2.15.1",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -179,59 +179,89 @@ private boolean isValidField(@NotNull Field field) {
}

@SuppressWarnings("unchecked")
private @NotNull TypeDeserializerFactory<T> createFactory() throws ReflectiveOperationException {
// check for constructor deserializer
var constructors = this.type.getDeclaredConstructors();
private @Nullable TypeDeserializerFactory<T> createFactoryFromBestConstructor(Constructor<?>[] constructors) {
var fields = getFields();

var ctorDeserializer = Arrays.stream(constructors).filter(x -> x.getAnnotation(EdgeDBDeserializer.class) != null).findFirst();

if(ctorDeserializer.isPresent()) {
var ctor = ctorDeserializer.get();
for(var i = 0; i != constructors.length; i++) {
var ctor = constructors[i];
var ctorParams = ctor.getParameters();

if(ctorParams.length == 1 && ctorParams[0].getType().equals(ObjectEnumerator.class)) {
return (enumerator, parent) -> (T)ctor.newInstance(enumerator);
}

return (enumerator, parent) -> {
var namingStrategyEntry = constructorNamingMap.computeIfAbsent(
((ObjectEnumeratorImpl)enumerator).getClient().getConfig().getNamingStrategy(),
(n) -> new NamingStrategyMap<>(n, (v) -> getNameOrAnnotated(v, Parameter::getName), ctor.getParameters())
);
if(fields.size() == ctorParams.length) {
// assert that all types match
var valid = true;
for(var j = 0; j != fields.size(); j++) {
if(!ctorParams[j].getType().isAssignableFrom(fields.get(j).fieldType)) {
valid = false;
break;
}
}

var params = new Object[namingStrategyEntry.nameIndexMap.size()];
var inverseIndexer = new FastInverseIndexer(params.length);
if(!valid) continue;

ObjectEnumerator.ObjectElement element;
// update each field with annotations from the parameters
for(var j = 0; j != fields.size(); j++) {
fields.get(j).updateFromDeserializerParameter(ctorParams[j]);
}

var unhandled = new Vector<ObjectEnumerator.ObjectElement>(params.length);
return (enumerator, parent) -> {
var namingStrategyEntry = constructorNamingMap.computeIfAbsent(
((ObjectEnumeratorImpl)enumerator).getClient().getConfig().getNamingStrategy(),
(n) -> new NamingStrategyMap<>(n, (v) -> getNameOrAnnotated(n, v, v, Parameter::getName), ctor.getParameters())
);

var params = new Object[namingStrategyEntry.nameIndexMap.size()];
var inverseIndexer = new FastInverseIndexer(params.length);

ObjectEnumerator.ObjectElement element;

while(enumerator.hasRemaining() && (element = enumerator.next()) != null) {
if(namingStrategyEntry.map.containsKey(element.getName())) {
var i = namingStrategyEntry.nameIndexMap.get(element.getName());
inverseIndexer.set(i);
params[i] = element.getValue();
} else {
unhandled.add(element);
var unhandled = new Vector<ObjectEnumerator.ObjectElement>(params.length);

while(enumerator.hasRemaining() && (element = enumerator.next()) != null) {
if(namingStrategyEntry.map.containsKey(element.getName())) {
var namingIndex = namingStrategyEntry.nameIndexMap.get(element.getName());
var field = fields.get(namingIndex);
inverseIndexer.set(namingIndex);
params[namingIndex] = field.convertToType(element.getValue());
} else {
unhandled.add(element);
}
}
}

var missed = inverseIndexer.getInverseIndexes();
var missed = inverseIndexer.getInverseIndexes();

for(int i = 0; i != missed.length; i++) {
params[missed[i]] = TypeUtils.getDefaultValue(namingStrategyEntry.values.get(i).getType());
}
for(int j = 0; j != missed.length; j++) {
params[missed[j]] = TypeUtils.getDefaultValue(namingStrategyEntry.values.get(j).getType());
}

var instance = (T)ctor.newInstance(params);
var instance = (T)ctor.newInstance(params);

if(parent != null) {
for (var unhandledElement : unhandled) {
parent.accept(instance, unhandledElement);
if(parent != null) {
for (var unhandledElement : unhandled) {
parent.accept(instance, unhandledElement);
}
}
}

return instance;
};
return instance;
};
}
}

return null;
}

@SuppressWarnings("unchecked")
private @NotNull TypeDeserializerFactory<T> createFactory() throws ReflectiveOperationException {
// check for constructor deserializer
var constructors = this.type.getDeclaredConstructors();

var deserializer = createFactoryFromBestConstructor(constructors);

if(deserializer != null) {
return deserializer;
}

// abstract or interface
Expand All @@ -246,7 +276,7 @@ private boolean isValidField(@NotNull Field field) {

var namingStrategyEntry = fieldNamingMap.computeIfAbsent(
((ObjectEnumeratorImpl)enumerator).getClient().getConfig().getNamingStrategy(),
(v) -> new NamingStrategyMap<>(v, (u) -> getNameOrAnnotated(u.field, Field::getName), getFields())
(v) -> new NamingStrategyMap<>(v, (u) -> getNameOrAnnotated(v, u, u.field, Field::getName), getFields())
);

var element = enumerator.next();
Expand Down Expand Up @@ -293,7 +323,7 @@ private boolean isValidField(@NotNull Field field) {
return (enumerator, parent) -> {
var namingStrategyEntry = fieldNamingMap.computeIfAbsent(
((ObjectEnumeratorImpl)enumerator).getClient().getConfig().getNamingStrategy(),
(v) -> new NamingStrategyMap<>(v, (u) -> getNameOrAnnotated(u.field, Field::getName), getFields())
(v) -> new NamingStrategyMap<>(v, (u) -> getNameOrAnnotated(v, u, u.field, Field::getName), getFields())
);

var instance = (T)ctor.newInstance();
Expand All @@ -315,11 +345,15 @@ private boolean isValidField(@NotNull Field field) {
public @NotNull NamingStrategyMap<FieldInfo> getFieldMap(NamingStrategy strategy) {
return fieldNamingMap.computeIfAbsent(
strategy,
(v) -> new NamingStrategyMap<>(v, (u) -> getNameOrAnnotated(u.field, Field::getName), getFields())
(v) -> new NamingStrategyMap<>(v, (u) -> getNameOrAnnotated(v, u, u.field, Field::getName), getFields())
);
}

private <U extends AnnotatedElement> String getNameOrAnnotated(@NotNull U value, @NotNull Function<U, String> getName) {
private <U extends AnnotatedElement, V> String getNameOrAnnotated(NamingStrategy strategy, V root, @NotNull U value, @NotNull Function<U, String> getName) {
if(root instanceof FieldInfo) {
return ((FieldInfo)root).getEdgeDBName(strategy);
}

var anno = value.getAnnotation(EdgeDBName.class);
if(anno != null && anno.value() != null) {
return anno.value();
Expand All @@ -329,12 +363,12 @@ private <U extends AnnotatedElement> String getNameOrAnnotated(@NotNull U value,
}

public static class FieldInfo {
public final EdgeDBName edgedbNameAnno;
public EdgeDBName edgedbNameAnno;
public final @NotNull Class<?> fieldType;
public final @NotNull Field field;
private final @Nullable Method setMethod;

private final @Nullable EdgeDBLinkType linkType;
private @Nullable EdgeDBLinkType linkType;

public FieldInfo(@NotNull Field field, @NotNull Map<String, Method> setters) {
this.field = field;
Expand Down Expand Up @@ -394,8 +428,12 @@ private Class<?> extractCollectionInnerType(@NotNull Class<?> cls) throws EdgeDB
throw new EdgeDBException("Cannot find element type of the collection " + cls.getName());
}

public @NotNull String getFieldName() {
return this.field.getName();
public @NotNull String getEdgeDBName(NamingStrategy strategy) {
if(this.edgedbNameAnno != null && this.edgedbNameAnno.value() != null) {
return this.edgedbNameAnno.value();
}

return strategy.convert(this.field.getName());
}

public void convertAndSet(boolean useMethodSetter, Object instance, Object value) throws EdgeDBException, ReflectiveOperationException {
Expand Down Expand Up @@ -426,6 +464,16 @@ public void convertAndSet(boolean useMethodSetter, Object instance, Object value

return ObjectBuilder.convertTo(fieldType, value);
}

private void updateFromDeserializerParameter(Parameter parameter) {
if(this.edgedbNameAnno == null) {
this.edgedbNameAnno = parameter.getAnnotation(EdgeDBName.class);
}

if(this.linkType == null) {
this.linkType = parameter.getAnnotation(EdgeDBLinkType.class);
}
}
}

public static class NamingStrategyMap<T> {
Expand Down
53 changes: 53 additions & 0 deletions src/driver/src/test/java/CustomDeserializerTests.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import com.edgedb.driver.EdgeDBClient;
import com.edgedb.driver.annotations.EdgeDBDeserializer;
import com.edgedb.driver.annotations.EdgeDBLinkType;
import com.edgedb.driver.annotations.EdgeDBName;
import com.edgedb.driver.annotations.EdgeDBType;
import org.junit.jupiter.api.Test;

import java.util.Collection;

import static org.assertj.core.api.Assertions.assertThat;

public class CustomDeserializerTests {

@EdgeDBType
public static final class Links {
private String namesHereAreIrrelevant;
private Links sinceTheCustomDeserializer;
private Collection<Links> shouldMapNames;

@EdgeDBDeserializer
public Links(
@EdgeDBName("a") String a,
@EdgeDBName("b") Links b,
@EdgeDBName("c") @EdgeDBLinkType(Links.class) Collection<Links> c
) {
this.namesHereAreIrrelevant = a;
this.sinceTheCustomDeserializer = b;
this.shouldMapNames = c;
}
}

@Test
public void testCustomDeserializerParameterAnnotations() throws Exception {
try(var client = new EdgeDBClient().withModule("tests")) {
var result = client.queryRequiredSingle(
Links.class,
"with test1 := (insert Links { a := '123' } unless conflict on .a else (select Links))," +
"test2 := (insert Links { a := '456', b := test1 } unless conflict on .a else (select Links))," +
"test3 := (insert Links { a := '789', b := test2, c := { test1, test2 }} unless conflict on .a else (select Links)) " +
"select test3 {a, b: {a, b, c }, c: {a, b, c}}"
).toCompletableFuture().get();

assertThat(result.namesHereAreIrrelevant).isEqualTo("789");
assertThat(result.shouldMapNames.size()).isEqualTo(2);

for(var link : result.shouldMapNames) {
assertThat(link.getClass()).isEqualTo(Links.class);
}

assertThat(result.sinceTheCustomDeserializer).isNotNull();
}
}
}