Skip to content

Commit

Permalink
MongoDB: Support DBRef pushdown
Browse files Browse the repository at this point in the history
  • Loading branch information
nsaje committed May 20, 2024
1 parent 46c9472 commit bf30264
Show file tree
Hide file tree
Showing 6 changed files with 51 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,12 @@
import java.util.List;
import java.util.Optional;

import static io.trino.plugin.mongodb.MongoSession.COLLECTION_NAME;
import static io.trino.plugin.mongodb.MongoSession.COLLECTION_NAME_NATIVE;
import static io.trino.plugin.mongodb.MongoSession.DATABASE_NAME;
import static io.trino.plugin.mongodb.MongoSession.DATABASE_NAME_NATIVE;
import static io.trino.plugin.mongodb.MongoSession.ID;
import static io.trino.plugin.mongodb.MongoSession.ID_NATIVE;
import static java.util.Objects.requireNonNull;

/**
Expand All @@ -35,9 +41,26 @@ public record MongoColumnHandle(String baseName, List<String> dereferenceNames,
public MongoColumnHandle
{
requireNonNull(baseName, "baseName is null");
dereferenceNames = ImmutableList.copyOf(requireNonNull(dereferenceNames, "dereferenceNames is null"));
requireNonNull(dereferenceNames, "dereferenceNames is null");
requireNonNull(type, "type is null");
requireNonNull(comment, "comment is null");

if (dbRefField) {
String leafColumnName = dereferenceNames.getLast();
String leafDBRefNativeName = switch (leafColumnName) {
case DATABASE_NAME -> DATABASE_NAME_NATIVE;
case COLLECTION_NAME -> COLLECTION_NAME_NATIVE;
case ID -> ID_NATIVE;
default -> leafColumnName;
};
dereferenceNames = ImmutableList.<String>builder()
.addAll(dereferenceNames.subList(0, dereferenceNames.size() - 1))
.add(leafDBRefNativeName)
.build();
}
else {
dereferenceNames = ImmutableList.copyOf(dereferenceNames);
}
}

public ColumnMetadata toColumnMetadata()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -723,9 +723,6 @@ private static boolean isSupportedForPushdown(ConnectorExpression connectorExpre
}
if (connectorExpression instanceof FieldDereference fieldDereference) {
RowType rowType = (RowType) fieldDereference.getTarget().getType();
if (isDBRefField(rowType)) {
return false;
}
Field field = rowType.getFields().get(fieldDereference.getField());
if (field.getName().isEmpty()) {
return false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,11 @@
import static io.airlift.slice.Slices.wrappedBuffer;
import static io.trino.plugin.base.util.JsonTypeUtil.jsonParse;
import static io.trino.plugin.mongodb.MongoSession.COLLECTION_NAME;
import static io.trino.plugin.mongodb.MongoSession.COLLECTION_NAME_NATIVE;
import static io.trino.plugin.mongodb.MongoSession.DATABASE_NAME;
import static io.trino.plugin.mongodb.MongoSession.DATABASE_NAME_NATIVE;
import static io.trino.plugin.mongodb.MongoSession.ID;
import static io.trino.plugin.mongodb.MongoSession.ID_NATIVE;
import static io.trino.plugin.mongodb.ObjectIdType.OBJECT_ID;
import static io.trino.plugin.mongodb.TypeUtils.isJsonType;
import static io.trino.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR;
Expand Down Expand Up @@ -416,9 +419,9 @@ private static Object getDbRefValue(DBRef dbRefValue, MongoColumnHandle columnHa
checkState(!dereferenceNames.isEmpty(), "dereferenceNames is empty");
String leafColumnName = dereferenceNames.getLast();
return switch (leafColumnName) {
case DATABASE_NAME -> dbRefValue.getDatabaseName();
case COLLECTION_NAME -> dbRefValue.getCollectionName();
case ID -> dbRefValue.getId();
case DATABASE_NAME_NATIVE -> dbRefValue.getDatabaseName();
case COLLECTION_NAME_NATIVE -> dbRefValue.getCollectionName();
case ID_NATIVE -> dbRefValue.getId();
default -> throw new IllegalStateException("Unsupported DBRef column name: " + leafColumnName);
};
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,9 @@ public class MongoSession
public static final String DATABASE_NAME = "databaseName";
public static final String COLLECTION_NAME = "collectionName";
public static final String ID = "id";
public static final String DATABASE_NAME_NATIVE = "$db";
public static final String COLLECTION_NAME_NATIVE = "$ref";
public static final String ID_NATIVE = "$id";

// The 'simple' locale is the default collection in MongoDB. The locale doesn't allow specifying other fields (e.g. numericOrdering)
// https://www.mongodb.com/docs/manual/reference/collation/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1536,7 +1536,7 @@ private void testProjectionPushdownWithDBRef(Object objectId, String expectedVal

assertThat(query("SELECT parent.child, creator.databaseName, creator.collectionName, creator.id FROM test." + tableName))
.matches("SELECT " + expectedValue + ", varchar 'test', varchar 'creators', " + expectedValue)
.isNotFullyPushedDown(ProjectNode.class);
.isFullyPushedDown();
assertQuery(
"SELECT typeof(creator) FROM test." + tableName,
"SELECT 'row(databaseName varchar, collectionName varchar, id " + expectedType + ")'");
Expand Down Expand Up @@ -1573,7 +1573,7 @@ private void testProjectionPushdownWithNestedDBRef(Object objectId, String expec

assertThat(query("SELECT parent.child, parent.creator.databaseName, parent.creator.collectionName, parent.creator.id FROM test." + tableName))
.matches("SELECT " + expectedValue + ", varchar 'test', varchar 'creators', " + expectedValue)
.isNotFullyPushedDown(ProjectNode.class);
.isFullyPushedDown();
assertQuery(
"SELECT typeof(parent.creator) FROM test." + tableName,
"SELECT 'row(databaseName varchar, collectionName varchar, id " + expectedType + ")'");
Expand Down Expand Up @@ -1609,7 +1609,7 @@ private void testProjectionPushdownWithPredefinedDBRefKeyword(Object objectId, S
assertThat(query("SELECT parent.id, parent.id.id FROM test." + tableName))
.skippingTypesCheck()
.matches("SELECT row('test', 'creators', %1$s), %1$s".formatted(expectedValue))
.isNotFullyPushedDown(ProjectNode.class);
.isFullyPushedDown();
assertQuery(
"SELECT typeof(parent.id), typeof(parent.id.id) FROM test." + tableName,
"SELECT 'row(databaseName varchar, collectionName varchar, id %1$s)', '%1$s'".formatted(expectedType));
Expand Down Expand Up @@ -1679,12 +1679,12 @@ private void testDBRefLikeDocument(Document document1, Document document2, Strin
assertThat(query("SELECT creator.id FROM test." + tableName))
.skippingTypesCheck()
.matches("VALUES (%1$s), (%1$s)".formatted(expectedValue))
.isNotFullyPushedDown(ProjectNode.class);
.isFullyPushedDown();

assertThat(query("SELECT creator.databasename, creator.collectionname, creator.id FROM test." + tableName))
.skippingTypesCheck()
.matches("VALUES ('doc_test', 'doc_creators', %1$s), ('dbref_test', 'dbref_creators', %1$s)".formatted(expectedValue))
.isNotFullyPushedDown(ProjectNode.class);
.isFullyPushedDown();

assertUpdate("DROP TABLE test." + tableName);
}
Expand All @@ -1693,14 +1693,14 @@ private static Document getDocumentWithDifferentDbRefFieldOrder(Object objectId)
{
return new Document()
.append("_id", new ObjectId("5126bbf64aed4daf9e2ab771"))
.append("creator", new Document().append("collectionName", "doc_creators").append("id", objectId).append("databaseName", "doc_test"));
.append("creator", new Document().append("$ref", "doc_creators").append("$id", objectId).append("$db", "doc_test"));
}

private static Document documentWithSameDbRefFieldOrder(Object objectId)
{
return new Document()
.append("_id", new ObjectId("5126bbf64aed4daf9e2ab771"))
.append("creator", new Document().append("databaseName", "doc_test").append("collectionName", "doc_creators").append("id", objectId));
.append("creator", new Document().append("$db", "doc_test").append("$ref", "doc_creators").append("$id", objectId));
}

private static Document dbRefDocument(Object objectId)
Expand All @@ -1717,9 +1717,9 @@ private void testDBRefLikeDocument(Object objectId, String expectedValue)
Document documentWithDifferentDbRefFieldOrder = new Document()
.append("_id", new ObjectId("5126bbf64aed4daf9e2ab771"))
.append("creator", new Document()
.append("databaseName", "doc_test")
.append("collectionName", "doc_creators")
.append("id", objectId));
.append("$db", "doc_test")
.append("$ref", "doc_creators")
.append("$id", objectId));
Document dbRefDocument = new Document()
.append("_id", new ObjectId("5126bbf64aed4daf9e2ab772"))
.append("creator", new DBRef("dbref_test", "dbref_creators", objectId));
Expand Down Expand Up @@ -1766,12 +1766,12 @@ private void testPredicateOnDBRefField(Object objectId, String expectedValue)
assertThat(query("SELECT * FROM test." + tableName + " WHERE creator.id = " + expectedValue))
.skippingTypesCheck()
.matches("SELECT ROW(varchar 'test', varchar 'creators', " + expectedValue + ")")
.isNotFullyPushedDown(FilterNode.class);
.isFullyPushedDown();

assertThat(query("SELECT creator.id FROM test." + tableName + " WHERE creator.id = " + expectedValue))
.skippingTypesCheck()
.matches("SELECT " + expectedValue)
.isNotFullyPushedDown(FilterNode.class);
.isFullyPushedDown();

assertUpdate("DROP TABLE test." + tableName);
}
Expand All @@ -1793,21 +1793,21 @@ private void testPredicateOnDBRefLikeDocument(Object objectId, String expectedVa
Document document = new Document()
.append("_id", new ObjectId("5126bbf64aed4daf9e2ab771"))
.append("creator", new Document()
.append("databaseName", "test")
.append("collectionName", "creators")
.append("id", objectId));
.append("$db", "test")
.append("$ref", "creators")
.append("$id", objectId));

client.getDatabase("test").getCollection(tableName).insertOne(document);

assertThat(query("SELECT * FROM test." + tableName + " WHERE creator.id = " + expectedValue))
.skippingTypesCheck()
.matches("SELECT ROW(varchar 'test', varchar 'creators', " + expectedValue + ")")
.isNotFullyPushedDown(FilterNode.class);
.isFullyPushedDown();

assertThat(query("SELECT creator.id FROM test." + tableName + " WHERE creator.id = " + expectedValue))
.skippingTypesCheck()
.matches("SELECT " + expectedValue)
.isNotFullyPushedDown(FilterNode.class);
.isFullyPushedDown();

assertUpdate("DROP TABLE test." + tableName);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ public void testRoundTripWithProjectedColumns()
false,
false,
Optional.empty()),
new MongoColumnHandle("creator", ImmutableList.of("databasename"), VARCHAR, false, true, Optional.empty()));
new MongoColumnHandle("creator", ImmutableList.of("databaseName"), VARCHAR, false, true, Optional.empty()));

MongoTableHandle expected = new MongoTableHandle(
schemaTableName,
Expand Down

0 comments on commit bf30264

Please sign in to comment.