Skip to content

Commit

Permalink
MongoDB: Filter Enforcer
Browse files Browse the repository at this point in the history
  • Loading branch information
nsaje committed May 22, 2024
1 parent bf30264 commit 47c334d
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 1 deletion.
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.plugin.mongodb;

import io.trino.spi.TrinoException;
import org.bson.Document;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import static io.trino.spi.StandardErrorCode.QUERY_REJECTED;

public class FilterEnforcer
{
private Map<String, List<String>> requiredFilters;

public FilterEnforcer(String requiredFiltersConfig)
{
this.requiredFilters = new HashMap<>();
if (requiredFiltersConfig == null) {
return;
}
for (String entry : requiredFiltersConfig.split(",")) {
String[] tokens = entry.split(":");
List<String> list = this.requiredFilters.computeIfAbsent(tokens[0], k -> new ArrayList<>());
list.add(tokens[1]);
}
}

public void checkAndRaiseIfInvalid(String collectionName, Document filter)
throws TrinoException
{
List<String> requiredFilters = this.requiredFilters.get(collectionName);
if (requiredFilters == null) {
return;
}
for (String requiredFilter : requiredFilters) {
if (requiredFilter != null && !contains(filter, requiredFilter)) {
throw new TrinoException(QUERY_REJECTED, "Collection '%s' requires a filter on '%s'!".formatted(collectionName, requiredFilter));
}
}
}

private static boolean contains(Object object, String requiredFilter)
{
if (object instanceof Document documentEntry) {
if (documentEntry.containsKey(requiredFilter)) {
return true;
}
else {
for (Object item : documentEntry.values()) {
if (contains(item, requiredFilter)) {
return true;
}
}
}
}
else if (object instanceof Iterable iterableEntry) {
for (Object item : iterableEntry) {
if (contains(item, requiredFilter)) {
return true;
}
}
}
return false;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ public class MongoClientConfig
private boolean allowLocalScheduling;
private Duration dynamicFilteringWaitTimeout = new Duration(5, SECONDS);

private String requiredFilters;

@NotNull
public String getSchemaCollection()
{
Expand Down Expand Up @@ -285,4 +287,17 @@ public MongoClientConfig setDynamicFilteringWaitTimeout(Duration dynamicFilterin
this.dynamicFilteringWaitTimeout = dynamicFilteringWaitTimeout;
return this;
}

public @Pattern(message = "Invalid 'required-filters'. Expected a comma-separated list of <key>:<value> pairs.", regexp = "^([^:,]+:[^:,]+,?)*$") String getRequiredFilters()
{
return requiredFilters;
}

@Config("mongodb.required-filters")
@ConfigDescription("Comma-separated list of <collection>:<key> pairs indicating a <collection> requires a filter on key <key>.")
public MongoClientConfig setRequiredFilters(String requiredFilters)
{
this.requiredFilters = requiredFilters;
return this;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,8 @@ public class MongoSession
private final Cache<SchemaTableName, MongoTable> tableCache;
private final String implicitPrefix;

private final FilterEnforcer filterEnforcer;

public MongoSession(TypeManager typeManager, MongoClient client, MongoClientConfig config)
{
this.typeManager = requireNonNull(typeManager, "typeManager is null");
Expand All @@ -190,7 +192,7 @@ public MongoSession(TypeManager typeManager, MongoClient client, MongoClientConf
this.caseInsensitiveNameMatching = config.isCaseInsensitiveNameMatching();
this.cursorBatchSize = config.getCursorBatchSize();
this.implicitPrefix = requireNonNull(config.getImplicitRowFieldPrefix(), "config.getImplicitRowFieldPrefix() is null");

this.filterEnforcer = new FilterEnforcer(config.getRequiredFilters());
this.tableCache = EvictableCacheBuilder.newBuilder()
.expireAfterWrite(1, MINUTES) // TODO: Configure
.build();
Expand Down Expand Up @@ -525,6 +527,8 @@ public MongoCursor<Document> execute(MongoTableHandle tableHandle, List<MongoCol
tableHandle.limit().ifPresent(iterable::limit);
log.debug("Find documents: collection: %s, filter: %s, projection: %s", tableHandle.schemaTableName(), filter, projection);

this.filterEnforcer.checkAndRaiseIfInvalid(collection.getNamespace().getCollectionName(), filter);

if (cursorBatchSize != 0) {
iterable.batchSize(cursorBatchSize);
}
Expand Down

0 comments on commit 47c334d

Please sign in to comment.