diff --git a/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/FilterEnforcer.java b/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/FilterEnforcer.java new file mode 100644 index 000000000000..675419f25033 --- /dev/null +++ b/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/FilterEnforcer.java @@ -0,0 +1,80 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.mongodb; + +import io.trino.spi.TrinoException; +import org.bson.Document; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static io.trino.spi.StandardErrorCode.QUERY_REJECTED; + +public class FilterEnforcer +{ + private Map> requiredFilters; + + public FilterEnforcer(String requiredFiltersConfig) + { + this.requiredFilters = new HashMap<>(); + if (requiredFiltersConfig == null) { + return; + } + for (String entry : requiredFiltersConfig.split(",")) { + String[] tokens = entry.split(":"); + List list = this.requiredFilters.computeIfAbsent(tokens[0], k -> new ArrayList<>()); + list.add(tokens[1]); + } + } + + public void checkAndRaiseIfInvalid(String collectionName, Document filter) + throws TrinoException + { + List requiredFilters = this.requiredFilters.get(collectionName); + if (requiredFilters == null) { + return; + } + for (String requiredFilter : requiredFilters) { + if (requiredFilter != null && !contains(filter, requiredFilter)) { + throw new TrinoException(QUERY_REJECTED, "Collection '%s' requires a filter on '%s'!".formatted(collectionName, requiredFilter)); + } + } + } + + private static boolean contains(Object object, String requiredFilter) + { + if (object instanceof Document documentEntry) { + if (documentEntry.containsKey(requiredFilter)) { + return true; + } + else { + for (Object item : documentEntry.values()) { + if (contains(item, requiredFilter)) { + return true; + } + } + } + } + else if (object instanceof Iterable iterableEntry) { + for (Object item : iterableEntry) { + if (contains(item, requiredFilter)) { + return true; + } + } + } + return false; + } +} diff --git a/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoClientConfig.java b/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoClientConfig.java index 9db568d5a650..3ef9d3998d09 100644 --- a/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoClientConfig.java +++ b/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoClientConfig.java @@ -52,6 +52,8 @@ public class MongoClientConfig private boolean allowLocalScheduling; private Duration dynamicFilteringWaitTimeout = new Duration(5, SECONDS); + private String requiredFilters; + @NotNull public String getSchemaCollection() { @@ -285,4 +287,17 @@ public MongoClientConfig setDynamicFilteringWaitTimeout(Duration dynamicFilterin this.dynamicFilteringWaitTimeout = dynamicFilteringWaitTimeout; return this; } + + public @Pattern(message = "Invalid 'required-filters'. Expected a comma-separated list of : pairs.", regexp = "^([^:,]+:[^:,]+,?)*$") String getRequiredFilters() + { + return requiredFilters; + } + + @Config("mongodb.required-filters") + @ConfigDescription("Comma-separated list of : pairs indicating a requires a filter on key .") + public MongoClientConfig setRequiredFilters(String requiredFilters) + { + this.requiredFilters = requiredFilters; + return this; + } } diff --git a/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoSession.java b/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoSession.java index abe12f937a21..08011b1d2cb9 100644 --- a/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoSession.java +++ b/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoSession.java @@ -182,6 +182,8 @@ public class MongoSession private final Cache tableCache; private final String implicitPrefix; + private final FilterEnforcer filterEnforcer; + public MongoSession(TypeManager typeManager, MongoClient client, MongoClientConfig config) { this.typeManager = requireNonNull(typeManager, "typeManager is null"); @@ -190,7 +192,7 @@ public MongoSession(TypeManager typeManager, MongoClient client, MongoClientConf this.caseInsensitiveNameMatching = config.isCaseInsensitiveNameMatching(); this.cursorBatchSize = config.getCursorBatchSize(); this.implicitPrefix = requireNonNull(config.getImplicitRowFieldPrefix(), "config.getImplicitRowFieldPrefix() is null"); - + this.filterEnforcer = new FilterEnforcer(config.getRequiredFilters()); this.tableCache = EvictableCacheBuilder.newBuilder() .expireAfterWrite(1, MINUTES) // TODO: Configure .build(); @@ -525,6 +527,8 @@ public MongoCursor execute(MongoTableHandle tableHandle, List