diff --git a/src/java/org/apache/cassandra/config/CassandraRelevantProperties.java b/src/java/org/apache/cassandra/config/CassandraRelevantProperties.java index 79b2bfb4e29e..0481ce2f4b1d 100644 --- a/src/java/org/apache/cassandra/config/CassandraRelevantProperties.java +++ b/src/java/org/apache/cassandra/config/CassandraRelevantProperties.java @@ -581,7 +581,7 @@ public enum CassandraRelevantProperties * The current messaging version. This is used when we add new messaging versions without adopting them immediately, * or to force the node to use a specific version for testing purposes. */ - DS_CURRENT_MESSAGING_VERSION("ds.current_messaging_version", Integer.toString(MessagingService.VERSION_DS_10)); + DS_CURRENT_MESSAGING_VERSION("ds.current_messaging_version", Integer.toString(MessagingService.VERSION_DS_11)); CassandraRelevantProperties(String key, String defaultVal) { diff --git a/src/java/org/apache/cassandra/cql3/restrictions/StatementRestrictions.java b/src/java/org/apache/cassandra/cql3/restrictions/StatementRestrictions.java index b20d658867e8..e1f29a20d0b9 100644 --- a/src/java/org/apache/cassandra/cql3/restrictions/StatementRestrictions.java +++ b/src/java/org/apache/cassandra/cql3/restrictions/StatementRestrictions.java @@ -1000,7 +1000,7 @@ public RowFilter getRowFilter(IndexRegistry indexManager, QueryOptions options, return RowFilter.NONE; } - ANNOptions annOptions = selectOptions.parseANNOptions(); + ANNOptions annOptions = selectOptions.parseANNOptions(table.keyspace); RowFilter rowFilter = RowFilter.builder(indexManager) .buildFromRestrictions(this, table, options, queryState, annOptions); diff --git a/src/java/org/apache/cassandra/cql3/statements/SelectOptions.java b/src/java/org/apache/cassandra/cql3/statements/SelectOptions.java index 6c7c8f2e631d..58ba92f377e7 100644 --- a/src/java/org/apache/cassandra/cql3/statements/SelectOptions.java +++ b/src/java/org/apache/cassandra/cql3/statements/SelectOptions.java @@ -22,6 +22,7 @@ import org.apache.cassandra.db.filter.ANNOptions; import org.apache.cassandra.exceptions.InvalidRequestException; import org.apache.cassandra.exceptions.RequestValidationException; +import org.apache.cassandra.service.QueryState; /** * {@code WITH option1=... AND option2=...} options for SELECT statements. @@ -36,26 +37,27 @@ public class SelectOptions extends PropertyDefinitions /** * Validates all the {@code SELECT} options. * + * @param state the query state * @param limit the {@code SELECT} query user-provided limit * @throws InvalidRequestException if any of the options are invalid */ - public void validate(int limit) throws RequestValidationException + public void validate(QueryState state, String keyspace, int limit) throws RequestValidationException { validate(keywords, Collections.emptySet()); - parseANNOptions().validate(limit); + parseANNOptions(keyspace).validate(state, limit); } /** * @return the ANN options within these options, or {@link ANNOptions#NONE} if no options are present * @throws InvalidRequestException if the ANN options are invalid */ - public ANNOptions parseANNOptions() throws RequestValidationException + public ANNOptions parseANNOptions(String keyspace) throws RequestValidationException { Map options = getMap(ANN_OPTIONS); return options == null ? 
ANNOptions.NONE - : ANNOptions.fromMap(options); + : ANNOptions.fromMap(keyspace, options); } /** diff --git a/src/java/org/apache/cassandra/cql3/statements/SelectStatement.java b/src/java/org/apache/cassandra/cql3/statements/SelectStatement.java index 58b75aacab5a..7b63d6402bba 100644 --- a/src/java/org/apache/cassandra/cql3/statements/SelectStatement.java +++ b/src/java/org/apache/cassandra/cql3/statements/SelectStatement.java @@ -426,7 +426,7 @@ public ReadQuery getQuery(QueryState queryState, checkFalse(userOffset != NO_OFFSET, String.format(TOPK_OFFSET_ERROR, userOffset)); } - selectOptions.validate(userLimit); + selectOptions.validate(queryState, table.keyspace, userLimit); return query; } diff --git a/src/java/org/apache/cassandra/db/filter/ANNOptions.java b/src/java/org/apache/cassandra/db/filter/ANNOptions.java index ea74ff6f1136..0ffaf951b958 100644 --- a/src/java/org/apache/cassandra/db/filter/ANNOptions.java +++ b/src/java/org/apache/cassandra/db/filter/ANNOptions.java @@ -24,10 +24,12 @@ import org.apache.cassandra.db.TypeSizes; import org.apache.cassandra.exceptions.InvalidRequestException; +import org.apache.cassandra.guardrails.Guardrails; import org.apache.cassandra.io.util.DataInputPlus; import org.apache.cassandra.io.util.DataOutputPlus; import org.apache.cassandra.locator.InetAddressAndPort; import org.apache.cassandra.net.MessagingService; +import org.apache.cassandra.service.QueryState; import org.apache.cassandra.utils.FBUtilities; /** @@ -60,10 +62,15 @@ public static ANNOptions create(@Nullable Integer rerankK) return rerankK == null ? NONE : new ANNOptions(rerankK); } - public void validate(int limit) + public void validate(QueryState state, int limit) { - if (rerankK != null && rerankK > 0 && rerankK < limit) + if (rerankK == null) + return; + + if (rerankK < limit) throw new InvalidRequestException(String.format("Invalid rerank_k value %d lesser than limit %d", rerankK, limit)); + + Guardrails.annRerankKMaxValue.guard(rerankK, "ANN options", false, state); } /** @@ -72,10 +79,11 @@ public void validate(int limit) * @param map the map of options in the {@code WITH ANN_OPTION} of a {@code SELECT} query * @return the ANN options in the specified {@code SELECT} options, or {@link #NONE} if no options are present */ - public static ANNOptions fromMap(Map map) + public static ANNOptions fromMap(String keyspace, Map map) { + assert keyspace != null; // ensure that all nodes in the cluster are in a version that supports ANN options, including this one - Set badNodes = MessagingService.instance().endpointsWithVersionBelow(MessagingService.VERSION_DS_11); + Set badNodes = MessagingService.instance().endpointsWithVersionBelow(keyspace, MessagingService.VERSION_DS_11); if (MessagingService.current_version < MessagingService.VERSION_DS_11) badNodes.add(FBUtilities.getBroadcastAddressAndPort()); if (!badNodes.isEmpty()) diff --git a/src/java/org/apache/cassandra/guardrails/Guardrails.java b/src/java/org/apache/cassandra/guardrails/Guardrails.java index d02943c74385..845a00c2fd86 100644 --- a/src/java/org/apache/cassandra/guardrails/Guardrails.java +++ b/src/java/org/apache/cassandra/guardrails/Guardrails.java @@ -120,6 +120,17 @@ what, formatSize(v), formatSize(t))) format("%s has a vector of %s dimensions, this exceeds the %s threshold of %s.", what, value, isWarning ? "warning" : "failure", threshold)); + /** + * Guardrail on the maximum value for the rerank_k parameter, an ANN query option. 
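+ * By default the warning threshold is disabled and the failure threshold is four times the SAI maximum top K (see the sai_ann_rerank_k defaults applied in GuardrailsConfig).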
+ */ + public static final Threshold annRerankKMaxValue = + factory.threshold("sai_ann_rerank_k_max_value", + () -> config.sai_ann_rerank_k_warn_threshold, + () -> config.sai_ann_rerank_k_fail_threshold, + (isWarning, what, value, threshold) -> + format("%s specifies rerank_k=%s, this exceeds the %s threshold of %s.", + what, value, isWarning ? "warning" : "failure", threshold)); + public static final DisableFlag readBeforeWriteListOperationsEnabled = factory.disableFlag("read_before_write_list_operations", () -> !config.read_before_write_list_operations_enabled, diff --git a/src/java/org/apache/cassandra/guardrails/GuardrailsConfig.java b/src/java/org/apache/cassandra/guardrails/GuardrailsConfig.java index 015dad51ee82..2338de2ca3a3 100644 --- a/src/java/org/apache/cassandra/guardrails/GuardrailsConfig.java +++ b/src/java/org/apache/cassandra/guardrails/GuardrailsConfig.java @@ -29,6 +29,7 @@ import com.google.common.collect.ImmutableSet; import com.google.common.collect.Sets; +import org.apache.cassandra.config.CassandraRelevantProperties; import org.apache.cassandra.config.Config; import org.apache.cassandra.config.DatabaseDescriptor; import org.apache.cassandra.cql3.statements.schema.TableAttributes; @@ -81,6 +82,8 @@ public class GuardrailsConfig public volatile Boolean read_before_write_list_operations_enabled; public volatile Integer vector_dimensions_warn_threshold; public volatile Integer vector_dimensions_failure_threshold; + public volatile Integer sai_ann_rerank_k_warn_threshold; + public volatile Integer sai_ann_rerank_k_fail_threshold; // Legacy 2i guardrail public volatile Integer secondary_index_per_table_failure_threshold; @@ -165,6 +168,11 @@ public void applyConfig() enforceDefault(tombstone_warn_threshold, v -> tombstone_warn_threshold = v, 1000, 1000); enforceDefault(tombstone_failure_threshold, v -> tombstone_failure_threshold = v, 100000, 100000); + // Default to no warning and failure at 4 times the maxTopK value + int maxTopK = CassandraRelevantProperties.SAI_VECTOR_SEARCH_MAX_TOP_K.getInt(); + enforceDefault(sai_ann_rerank_k_warn_threshold, v -> sai_ann_rerank_k_warn_threshold = v, -1, -1); + enforceDefault(sai_ann_rerank_k_fail_threshold, v -> sai_ann_rerank_k_fail_threshold = v, 4 * maxTopK, 4 * maxTopK); + // for write requests enforceDefault(logged_batch_enabled, v -> logged_batch_enabled = v, true, true); enforceDefault(batch_size_warn_threshold_in_kb, v -> batch_size_warn_threshold_in_kb = v, 64, 64); @@ -269,6 +277,10 @@ public void validate() validateStrictlyPositiveInteger(vector_dimensions_failure_threshold, "vector_dimensions_failure_threshold"); validateWarnLowerThanFail(vector_dimensions_warn_threshold, vector_dimensions_failure_threshold, "vector_dimensions"); + validateStrictlyPositiveInteger(sai_ann_rerank_k_warn_threshold, "sai_ann_rerank_k_warn_threshold"); + validateStrictlyPositiveInteger(sai_ann_rerank_k_fail_threshold, "sai_ann_rerank_k_fail_threshold"); + validateWarnLowerThanFail(sai_ann_rerank_k_warn_threshold, sai_ann_rerank_k_fail_threshold, "sai_ann_rerank_k"); + validateStrictlyPositiveInteger(tables_warn_threshold, "tables_warn_threshold"); validateStrictlyPositiveInteger(tables_failure_threshold, "tables_failure_threshold"); validateWarnLowerThanFail(tables_warn_threshold, tables_failure_threshold, "tables"); diff --git a/src/java/org/apache/cassandra/index/sai/StorageAttachedIndex.java b/src/java/org/apache/cassandra/index/sai/StorageAttachedIndex.java index 28cb42ab8dc8..6728ed7f23c7 100644 --- 
a/src/java/org/apache/cassandra/index/sai/StorageAttachedIndex.java +++ b/src/java/org/apache/cassandra/index/sai/StorageAttachedIndex.java @@ -711,10 +711,6 @@ public void validate(ReadCommand command) throws InvalidRequestException throw new InvalidRequestException(String.format("SAI based ORDER BY clause requires a LIMIT that is not greater than %s. LIMIT was %s", MAX_TOP_K, command.limits().isUnlimited() ? "NO LIMIT" : command.limits().count())); - ANNOptions annOptions = command.rowFilter().annOptions(); - if (annOptions != ANNOptions.NONE) - throw new InvalidRequestException("SAI doesn't support ANN options yet."); - indexContext.validate(command.rowFilter()); } diff --git a/src/java/org/apache/cassandra/index/sai/disk/v1/Segment.java b/src/java/org/apache/cassandra/index/sai/disk/v1/Segment.java index 87fe0a998e5f..532c2c38f2b5 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/v1/Segment.java +++ b/src/java/org/apache/cassandra/index/sai/disk/v1/Segment.java @@ -205,10 +205,11 @@ public String toString() * the number of candidates, the more nodes we expect to visit just to find * results that are in that set.) */ - public double estimateAnnSearchCost(int limit, int candidates) + public double estimateAnnSearchCost(Orderer orderer, int limit, int candidates) { - IndexSearcher searcher = getIndexSearcher(); - return ((V2VectorIndexSearcher) searcher).estimateAnnSearchCost(limit, candidates); + V2VectorIndexSearcher searcher = (V2VectorIndexSearcher) getIndexSearcher(); + int rerankK = orderer.rerankKFor(limit, searcher.getCompression()); + return searcher.estimateAnnSearchCost(rerankK, candidates); } /** diff --git a/src/java/org/apache/cassandra/index/sai/disk/v1/V1SearchableIndex.java b/src/java/org/apache/cassandra/index/sai/disk/v1/V1SearchableIndex.java index b7b0ef0222f3..28f90d590f8c 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/v1/V1SearchableIndex.java +++ b/src/java/org/apache/cassandra/index/sai/disk/v1/V1SearchableIndex.java @@ -207,6 +207,8 @@ public List> orderBy(Orderer orderer, E { if (segment.intersects(keyRange)) { + // Note that the proportionality is not used when the user supplies a rerank_k value in the + // ANN_OPTIONS map. 
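+ // For example, a query such as
+ //   SELECT ... ORDER BY v ANN OF [...] LIMIT 10 WITH ann_options = {'rerank_k': 100}
+ // searches every segment with rerank_k = 100, whereas without ANN_OPTIONS the rerank_k used per segment
+ // is derived from the proportional limit computed below.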
var segmentLimit = segment.proportionalAnnLimit(limit, totalRows); iterators.add(segment.orderBy(orderer, slice, keyRange, context, segmentLimit)); } diff --git a/src/java/org/apache/cassandra/index/sai/disk/v2/V2VectorIndexSearcher.java b/src/java/org/apache/cassandra/index/sai/disk/v2/V2VectorIndexSearcher.java index 80870dab36bf..7b9057f56834 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/v2/V2VectorIndexSearcher.java +++ b/src/java/org/apache/cassandra/index/sai/disk/v2/V2VectorIndexSearcher.java @@ -163,7 +163,7 @@ public CloseableIterator orderBy(Orderer orderer, Express if (orderer.vector == null) throw new IllegalArgumentException(indexContext.logMessage("Unsupported expression during ANN index query: " + orderer)); - int rerankK = indexContext.getIndexWriterConfig().getSourceModel().rerankKFor(limit, graph.getCompression()); + int rerankK = orderer.rerankKFor(limit, graph.getCompression()); var queryVector = vts.createFloatVector(orderer.vector); var result = searchInternal(keyRange, context, queryVector, limit, rerankK, 0); @@ -428,9 +428,8 @@ public double cost() } } - public double estimateAnnSearchCost(int limit, int candidates) + public double estimateAnnSearchCost(int rerankK, int candidates) { - int rerankK = indexContext.getIndexWriterConfig().getSourceModel().rerankKFor(limit, graph.getCompression()); var estimate = estimateCost(rerankK, candidates); return estimate.cost(); } @@ -472,7 +471,7 @@ public CloseableIterator orderResultsBy(SSTableReader rea if (keys.isEmpty()) return CloseableIterator.emptyIterator(); - int rerankK = indexContext.getIndexWriterConfig().getSourceModel().rerankKFor(limit, graph.getCompression()); + int rerankK = orderer.rerankKFor(limit, graph.getCompression()); // Convert PKs to segment row ids and map to ordinals, skipping any that don't exist in this segment var segmentOrdinalPairs = flatmapPrimaryKeysToBitsAndRows(keys); var numRows = segmentOrdinalPairs.size(); @@ -611,9 +610,9 @@ public static double logBase2(double number) { return Math.log(number) / Math.log(2); } - private int getRawExpectedNodes(int limit, int nPermittedOrdinals) + private int getRawExpectedNodes(int rerankK, int nPermittedOrdinals) { - return VectorMemtableIndex.expectedNodesVisited(limit, nPermittedOrdinals, graph.size()); + return VectorMemtableIndex.expectedNodesVisited(rerankK, nPermittedOrdinals, graph.size()); } @Override diff --git a/src/java/org/apache/cassandra/index/sai/disk/vector/CassandraOnHeapGraph.java b/src/java/org/apache/cassandra/index/sai/disk/vector/CassandraOnHeapGraph.java index 4a55991135e2..367c1ff5fc02 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/vector/CassandraOnHeapGraph.java +++ b/src/java/org/apache/cassandra/index/sai/disk/vector/CassandraOnHeapGraph.java @@ -312,7 +312,7 @@ public void remove(ByteBuffer term, T key) /** * @return an itererator over {@link PrimaryKeyWithSortKey} in the graph's {@link SearchResult} order */ - public CloseableIterator search(QueryContext context, VectorFloat queryVector, int limit, float threshold, Bits toAccept) + public CloseableIterator search(QueryContext context, VectorFloat queryVector, int limit, int rerankK, float threshold, Bits toAccept) { VectorValidation.validateIndexable(queryVector, similarityFunction); @@ -326,7 +326,6 @@ public CloseableIterator search(QueryContext context, Ve try { var ssf = SearchScoreProvider.exact(queryVector, similarityFunction, vectorValues); - var rerankK = sourceModel.rerankKFor(limit, VectorCompression.NO_COMPRESSION); var result = 
searcher.search(ssf, limit, rerankK, threshold, 0.0f, bits); Tracing.trace("ANN search for {}/{} visited {} nodes, reranked {} to return {} results from {}", limit, rerankK, result.getVisitedCount(), result.getRerankedCount(), result.getNodes().length, source); diff --git a/src/java/org/apache/cassandra/index/sai/disk/vector/VectorMemtableIndex.java b/src/java/org/apache/cassandra/index/sai/disk/vector/VectorMemtableIndex.java index 25231748149f..a8d85325e4a8 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/vector/VectorMemtableIndex.java +++ b/src/java/org/apache/cassandra/index/sai/disk/vector/VectorMemtableIndex.java @@ -193,7 +193,7 @@ public KeyRangeIterator search(QueryContext context, Expression expr, AbstractBo float threshold = expr.getEuclideanSearchThreshold(); SortingIterator.Builder keyQueue; - try (var pkIterator = searchInternal(context, qv, keyRange, graph.size(), threshold)) + try (var pkIterator = searchInternal(context, qv, keyRange, graph.size(), graph.size(), threshold)) { keyQueue = new SortingIterator.Builder<>(); while (pkIterator.hasNext()) @@ -223,14 +223,16 @@ public List> orderBy(QueryContext conte assert orderer.isANN() : "Only ANN is supported for vector search, received " + orderer.operator; var qv = vts.createFloatVector(orderer.vector); + var rerankK = orderer.rerankKFor(limit, VectorCompression.NO_COMPRESSION); - return List.of(searchInternal(context, qv, keyRange, limit, 0)); + return List.of(searchInternal(context, qv, keyRange, limit, rerankK, 0)); } private CloseableIterator searchInternal(QueryContext context, VectorFloat queryVector, AbstractBounds keyRange, int limit, + int rerankK, float threshold) { Bits bits; @@ -256,11 +258,11 @@ private CloseableIterator searchInternal(QueryContext con if (resultKeys.isEmpty()) return CloseableIterator.emptyIterator(); - int bruteForceRows = maxBruteForceRows(limit, resultKeys.size(), graph.size()); - logger.trace("Search range covers {} rows; max brute force rows is {} for memtable index with {} nodes, LIMIT {}", - resultKeys.size(), bruteForceRows, graph.size(), limit); - Tracing.trace("Search range covers {} rows; max brute force rows is {} for memtable index with {} nodes, LIMIT {}", - resultKeys.size(), bruteForceRows, graph.size(), limit); + int bruteForceRows = maxBruteForceRows(rerankK, resultKeys.size(), graph.size()); + logger.trace("Search range covers {} rows; max brute force rows is {} for memtable index with {} nodes, rerankK {}, LIMIT {}", + resultKeys.size(), bruteForceRows, graph.size(), rerankK, limit); + Tracing.trace("Search range covers {} rows; max brute force rows is {} for memtable index with {} nodes, rerankK {}, LIMIT {}", + resultKeys.size(), bruteForceRows, graph.size(), rerankK, limit); if (resultKeys.size() <= bruteForceRows) // When we have a threshold, we only need to filter the results, not order them, because it means we're // evaluating a boolean predicate in the SAI pipeline that wants to collate by PK @@ -272,7 +274,7 @@ private CloseableIterator searchInternal(QueryContext con bits = new KeyRangeFilteringBits(keyRange); } - var nodeScoreIterator = graph.search(context, queryVector, limit, threshold, bits); + var nodeScoreIterator = graph.search(context, queryVector, limit, rerankK, threshold, bits); return new NodeScoreToScoredPrimaryKeyIterator(nodeScoreIterator); } @@ -305,9 +307,10 @@ public CloseableIterator orderResultsBy(QueryContext cont relevantOrdinals.add(i); }); - int maxBruteForceRows = maxBruteForceRows(limit, relevantOrdinals.size(), graph.size()); - 
Tracing.logAndTrace(logger, "{} rows relevant to current memtable out of {} materialized by SAI; max brute force rows is {} for memtable index with {} nodes, LIMIT {}", - relevantOrdinals.size(), keys.size(), maxBruteForceRows, graph.size(), limit); + int rerankK = orderer.rerankKFor(limit, VectorCompression.NO_COMPRESSION); + int maxBruteForceRows = maxBruteForceRows(rerankK, relevantOrdinals.size(), graph.size()); + Tracing.logAndTrace(logger, "{} rows relevant to current memtable out of {} materialized by SAI; max brute force rows is {} for memtable index with {} nodes, rerankK {}", + relevantOrdinals.size(), keys.size(), maxBruteForceRows, graph.size(), rerankK); // convert the expression value to query vector var qv = vts.createFloatVector(orderer.vector); @@ -319,7 +322,7 @@ public CloseableIterator orderResultsBy(QueryContext cont return orderByBruteForce(qv, keysInGraph); } // indexed path - var nodeScoreIterator = graph.search(context, qv, limit, 0, relevantOrdinals::contains); + var nodeScoreIterator = graph.search(context, qv, limit, rerankK, 0, relevantOrdinals::contains); return new NodeScoreToScoredPrimaryKeyIterator(nodeScoreIterator); } @@ -375,15 +378,15 @@ private PrimaryKeyWithScore scoreKey(VectorSimilarityFunction similarityFunction return new PrimaryKeyWithScore(indexContext, mt, key, score); } - private int maxBruteForceRows(int limit, int nPermittedOrdinals, int graphSize) + private int maxBruteForceRows(int rerankK, int nPermittedOrdinals, int graphSize) { - int expectedNodesVisited = expectedNodesVisited(limit, nPermittedOrdinals, graphSize); - return min(max(limit, expectedNodesVisited), GLOBAL_BRUTE_FORCE_ROWS); + int expectedNodesVisited = expectedNodesVisited(rerankK, nPermittedOrdinals, graphSize); + return min(max(rerankK, expectedNodesVisited), GLOBAL_BRUTE_FORCE_ROWS); } - public int estimateAnnNodesVisited(int limit, int nPermittedOrdinals) + public int estimateAnnNodesVisited(int rerankK, int nPermittedOrdinals) { - return expectedNodesVisited(limit, nPermittedOrdinals, graph.size()); + return expectedNodesVisited(rerankK, nPermittedOrdinals, graph.size()); } /** @@ -395,9 +398,9 @@ public int estimateAnnNodesVisited(int limit, int nPermittedOrdinals) * !!! roughly `degree` times larger than the number of nodes whose edge lists we load! * !!! */ - public static int expectedNodesVisited(int limit, int nPermittedOrdinals, int graphSize) + public static int expectedNodesVisited(int rerankK, int nPermittedOrdinals, int graphSize) { - var K = limit; + var K = rerankK; var B = min(nPermittedOrdinals, graphSize); var N = graphSize; // These constants come from running many searches on a variety of datasets and graph sizes. @@ -412,7 +415,7 @@ public static int expectedNodesVisited(int limit, int nPermittedOrdinals, int gr // If we need to make this even more accurate, the relationship to B and to log(N) may be the best // places to start. 
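// As a rough worked example (taking log as the natural logarithm): with rerankK K = 100 on an
// unconstrained graph where B = N = 1M, this evaluates to about 100 + 0.025 * ln(1e6)^2 * 100^0.95 ≈ 480
// nodes; restricting the candidates to B = 100k (N/B = 10) raises the estimate to roughly 3,900 nodes.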
var raw = (int) (100 + 0.025 * pow(log(N), 2) * pow(K, 0.95) * ((double) N / B)); - return ensureSaneEstimate(raw, limit, graphSize); + return ensureSaneEstimate(raw, rerankK, graphSize); } public static int ensureSaneEstimate(int rawEstimate, int rerankK, int graphSize) diff --git a/src/java/org/apache/cassandra/index/sai/plan/Orderer.java b/src/java/org/apache/cassandra/index/sai/plan/Orderer.java index e23c83ed91e2..76d7d0e92da4 100644 --- a/src/java/org/apache/cassandra/index/sai/plan/Orderer.java +++ b/src/java/org/apache/cassandra/index/sai/plan/Orderer.java @@ -31,6 +31,7 @@ import org.apache.cassandra.index.SecondaryIndexManager; import org.apache.cassandra.index.sai.IndexContext; import org.apache.cassandra.index.sai.StorageAttachedIndex; +import org.apache.cassandra.index.sai.disk.vector.VectorCompression; import org.apache.cassandra.index.sai.utils.PrimaryKeyWithSortKey; import org.apache.cassandra.index.sai.utils.TypeUtil; @@ -46,20 +47,25 @@ public class Orderer public final IndexContext context; public final Operator operator; + + // Vector search parameters public final float[] vector; + private final Integer rerankK; /** * Create an orderer for the given index context, operator, and term. * @param context the index context, used to build the view of memtables and sstables for query execution. * @param operator the operator for the order by clause. * @param term the term to order by (not always relevant) + * @param rerankK optional rerank K parameter for ANN queries */ - public Orderer(IndexContext context, Operator operator, ByteBuffer term) + public Orderer(IndexContext context, Operator operator, ByteBuffer term, @Nullable Integer rerankK) { this.context = context; assert ORDER_BY_OPERATORS.contains(operator) : "Invalid operator for order by clause " + operator; this.operator = operator; this.vector = context.getValidator().isVector() ? TypeUtil.decomposeVector(context.getValidator(), term) : null; + this.rerankK = rerankK; } public String getIndexName() @@ -89,16 +95,36 @@ public boolean isANN() return operator == Operator.ANN; } + /** + * Provide rerankK for ANN queries. Use the user provided rerankK if available, otherwise use the model's default + * based on the limit and compression type. + * + * @param limit the query limit or the proportional segment limit to use when calculating a reasonable rerankK + * default value + * @param vc the compression type of the vectors in the index + * @return the rerankK value to use in ANN search + */ + public int rerankKFor(int limit, VectorCompression vc) + { + assert isANN() : "rerankK is only valid for ANN queries"; + return rerankK != null + ? rerankK + : context.getIndexWriterConfig().getSourceModel().rerankKFor(limit, vc); + } + @Nullable public static Orderer from(SecondaryIndexManager indexManager, RowFilter filter) { var expressions = filter.root().expressions().stream().filter(Orderer::isFilterExpressionOrderer).collect(Collectors.toList()); if (expressions.isEmpty()) return null; - var orderRowFilter = expressions.get(0); - var index = indexManager.getBestIndexFor(orderRowFilter, StorageAttachedIndex.class) + var orderExpression = expressions.get(0); + var index = indexManager.getBestIndexFor(orderExpression, StorageAttachedIndex.class) .orElseThrow(() -> new IllegalStateException("No index found for order by clause")); - return new Orderer(index.getIndexContext(), orderRowFilter.operator(), orderRowFilter.getIndexValue()); + + // Null if not specified explicitly in the CQL query. 
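+ // For example, a query with ann_options = {'rerank_k': 100} produces rerankK = 100 here, while a plain
+ // ANN query leaves it null and rerankKFor() falls back to the source model's default.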
+ Integer rerankK = filter.annOptions().rerankK; + return new Orderer(index.getIndexContext(), orderExpression.operator(), orderExpression.getIndexValue(), rerankK); } public static boolean isFilterExpressionOrderer(RowFilter.Expression expression) @@ -110,8 +136,9 @@ public static boolean isFilterExpressionOrderer(RowFilter.Expression expression) public String toString() { String direction = isAscending() ? "ASC" : "DESC"; + String rerankInfo = rerankK != null ? String.format(" (rerank_k=%d)", rerankK) : ""; return isANN() - ? context.getColumnName() + " ANN OF " + Arrays.toString(vector) + ' ' + direction + ? context.getColumnName() + " ANN OF " + Arrays.toString(vector) + ' ' + direction + rerankInfo : context.getColumnName() + ' ' + direction; } } diff --git a/src/java/org/apache/cassandra/index/sai/plan/QueryController.java b/src/java/org/apache/cassandra/index/sai/plan/QueryController.java index 15b0965fe82e..dd70b4296b15 100644 --- a/src/java/org/apache/cassandra/index/sai/plan/QueryController.java +++ b/src/java/org/apache/cassandra/index/sai/plan/QueryController.java @@ -64,6 +64,7 @@ import org.apache.cassandra.index.sai.StorageAttachedIndex; import org.apache.cassandra.index.sai.disk.format.IndexFeatureSet; import org.apache.cassandra.index.sai.disk.v1.Segment; +import org.apache.cassandra.index.sai.disk.vector.VectorCompression; import org.apache.cassandra.index.sai.disk.vector.VectorMemtableIndex; import org.apache.cassandra.index.sai.iterators.KeyRangeIntersectionIterator; import org.apache.cassandra.index.sai.iterators.KeyRangeIterator; @@ -937,20 +938,21 @@ private long estimateMatchingRowCountUsingIndex(Expression predicate) @Override - public double estimateAnnSearchCost(Orderer ordering, int limit, long candidates) + public double estimateAnnSearchCost(Orderer orderer, int limit, long candidates) { Preconditions.checkArgument(limit > 0, "limit must be > 0"); - IndexContext context = ordering.context; + IndexContext context = orderer.context; Collection memtables = context.getLiveMemtables().values(); View queryView = context.getView(); + int memoryRerankK = orderer.rerankKFor(limit, VectorCompression.NO_COMPRESSION); double cost = 0; for (MemtableIndex index : memtables) { // FIXME convert nodes visited to search cost int memtableCandidates = (int) Math.min(Integer.MAX_VALUE, candidates); - cost += ((VectorMemtableIndex) index).estimateAnnNodesVisited(limit, memtableCandidates); + cost += ((VectorMemtableIndex) index).estimateAnnNodesVisited(memoryRerankK, memtableCandidates); } long totalRows = 0; @@ -965,7 +967,7 @@ public double estimateAnnSearchCost(Orderer ordering, int limit, long candidates continue; int segmentLimit = segment.proportionalAnnLimit(limit, totalRows); int segmentCandidates = max(1, (int) (candidates * (double) segment.metadata.numRows / totalRows)); - cost += segment.estimateAnnSearchCost(segmentLimit, segmentCandidates); + cost += segment.estimateAnnSearchCost(orderer, segmentLimit, segmentCandidates); } } return cost; diff --git a/src/java/org/apache/cassandra/net/MessagingService.java b/src/java/org/apache/cassandra/net/MessagingService.java index c3cd867884d2..1865ef63d4ae 100644 --- a/src/java/org/apache/cassandra/net/MessagingService.java +++ b/src/java/org/apache/cassandra/net/MessagingService.java @@ -645,15 +645,17 @@ public void waitUntilListening() throws InterruptedException } /** - * Returns the endpoints that are known to be alive and are using a messaging version older than the given version. 
+ * Returns the endpoints for the given keyspace that are known to be alive and are using a messaging version older + * than the given version. * + * @param keyspace a keyspace * @param version a messaging version - * @return a set of alive endpoints with messaging version below the given version + * @return a set of alive endpoints in the given keyspace with messaging version below the given version */ - public Set endpointsWithVersionBelow(int version) + public Set endpointsWithVersionBelow(String keyspace, int version) { Set nodes = new HashSet<>(); - for (InetAddressAndPort node : StorageService.instance.getTokenMetadata().getAllEndpoints()) + for (InetAddressAndPort node : StorageService.instance.getTokenMetadataForKeyspace(keyspace).getAllEndpoints()) { if (versions.knows(node) && versions.getRaw(node) < version) nodes.add(node); diff --git a/test/distributed/org/apache/cassandra/distributed/test/sai/ANNOptionsDistributedTest.java b/test/distributed/org/apache/cassandra/distributed/test/sai/ANNOptionsDistributedTest.java index 26052606a5e6..68de1aa12e3d 100644 --- a/test/distributed/org/apache/cassandra/distributed/test/sai/ANNOptionsDistributedTest.java +++ b/test/distributed/org/apache/cassandra/distributed/test/sai/ANNOptionsDistributedTest.java @@ -54,7 +54,8 @@ public void testANNOptionsWithAllDS11() throws Throwable .withConfig(config -> config.with(GOSSIP).with(NETWORK)) .start(), RF)) { - test(cluster, "SAI doesn't support ANN options yet."); + // null indicates that the query should succeed + testSelectWithAnnOptions(cluster, null); } } @@ -70,7 +71,7 @@ public void testANNOptionsWithAllDS10() throws Throwable .withConfig(config -> config.with(GOSSIP).with(NETWORK)) .start(), RF)) { - test(cluster, "ANN options are not supported in clusters below DS 11."); + testSelectWithAnnOptions(cluster, "ANN options are not supported in clusters below DS 11."); } } @@ -80,16 +81,18 @@ public void testANNOptionsWithAllDS10() throws Throwable @Test public void testANNOptionsWithMixedDS10AndDS11() throws Throwable { + assert CassandraRelevantProperties.DS_CURRENT_MESSAGING_VERSION.getInt() >= MessagingService.VERSION_DS_11; + try (Cluster cluster = init(Cluster.build(NUM_REPLICAS) .withInstanceInitializer(BB::install) .withConfig(config -> config.with(GOSSIP).with(NETWORK).with(NATIVE_PROTOCOL)) .start(), RF)) { - test(cluster, "ANN options are not supported in clusters below DS 11."); + testSelectWithAnnOptions(cluster, "ANN options are not supported in clusters below DS 11."); } } - private static void test(Cluster cluster, String expectedErrorMessage) + private static void testSelectWithAnnOptions(Cluster cluster, String expectedErrorMessage) { cluster.schemaChange(withKeyspace("CREATE TABLE %s.t (k int PRIMARY KEY, n int, v vector)")); cluster.schemaChange(withKeyspace("CREATE CUSTOM INDEX ON %s.t(v) USING 'StorageAttachedIndex'")); @@ -100,13 +103,16 @@ private static void test(Cluster cluster, String expectedErrorMessage) for (int i = 1; i <= cluster.size(); i++) { ICoordinator coordinator = cluster.coordinator(i); - Assertions.assertThatThrownBy(() -> coordinator.execute(select, ConsistencyLevel.ONE)) - .hasMessageContaining(expectedErrorMessage); + if (expectedErrorMessage == null) + coordinator.execute(select, ConsistencyLevel.ONE); + else + Assertions.assertThatThrownBy(() -> coordinator.execute(select, ConsistencyLevel.ONE)) + .hasMessageContaining(expectedErrorMessage); } } /** - * Injection to set the current version of the first cluster node to DS 11. 
+ * Injection to set the current version of the first cluster node to DS 10. */ public static class BB { @@ -126,7 +132,7 @@ public static void install(ClassLoader classLoader, int node) @SuppressWarnings("unused") public static int currentVersion() { - return MessagingService.VERSION_DS_11; + return MessagingService.VERSION_DS_10; } } } diff --git a/test/unit/org/apache/cassandra/db/filter/ANNOptionsTest.java b/test/unit/org/apache/cassandra/db/filter/ANNOptionsTest.java index 791068fab39c..8b414f06bca0 100644 --- a/test/unit/org/apache/cassandra/db/filter/ANNOptionsTest.java +++ b/test/unit/org/apache/cassandra/db/filter/ANNOptionsTest.java @@ -17,12 +17,14 @@ import java.io.IOException; import java.util.Objects; +import java.util.Optional; import javax.annotation.Nullable; import org.junit.BeforeClass; import org.junit.Test; +import com.datastax.driver.core.exceptions.InvalidQueryException; import org.apache.cassandra.config.CassandraRelevantProperties; import org.apache.cassandra.cql3.CQLTester; import org.apache.cassandra.cql3.Operator; @@ -43,6 +45,7 @@ import org.apache.cassandra.net.MessagingService; import org.apache.cassandra.schema.ColumnMetadata; import org.apache.cassandra.schema.IndexMetadata; +import org.apache.cassandra.transport.ProtocolVersion; import org.apache.cassandra.utils.Pair; import org.assertj.core.api.Assertions; import org.quicktheories.QuickTheory; @@ -88,12 +91,25 @@ public void testParseAndValidate() execute("SELECT * FROM %s WHERE k=0 ORDER BY v ANN OF [1, 1] WITH ann_options = {}"); // correct queries with specific ANN options - execute("SELECT * FROM %s ORDER BY v ANN OF [1, 1] LIMIT 10 WITH ann_options = {'rerank_k': 0}"); + execute("SELECT * FROM %s ORDER BY v ANN OF [1, 1] LIMIT 10 WITH ann_options = {'rerank_k': 10}"); + execute("SELECT * FROM %s ORDER BY v ANN OF [1, 1] LIMIT 10 WITH ann_options = {'rerank_k': 11}"); execute("SELECT * FROM %s ORDER BY v ANN OF [1, 1] LIMIT 10 WITH ann_options = {'rerank_k': 1000}"); - execute("SELECT * FROM %s ORDER BY v ANN OF [1, 1] LIMIT 10 WITH ann_options = {'rerank_k': '0'}"); execute("SELECT * FROM %s ORDER BY v ANN OF [1, 1] LIMIT 10 WITH ann_options = {'rerank_k': '1000'}"); - execute("SELECT * FROM %s ORDER BY v ANN OF [1, 1] LIMIT 10 WITH ann_options = {'rerank_k': '-1'}"); - execute("SELECT * FROM %s ORDER BY v ANN OF [1, 1] LIMIT 10 WITH ann_options = {'rerank_k': '-1000'}"); + + // Queries with invalid ann options that will eventually be valid when we support disabling reranking + assertInvalidThrowMessage("Invalid rerank_k value -1 lesser than limit 100", + InvalidRequestException.class, + "SELECT * FROM %s ORDER BY v ANN OF [1, 1] LIMIT 100 WITH ann_options = {'rerank_k': -1}"); + assertInvalidThrowMessage("Invalid rerank_k value 0 lesser than limit 100", + InvalidRequestException.class, + "SELECT * FROM %s ORDER BY v ANN OF [1, 1] LIMIT 100 WITH ann_options = {'rerank_k': 0}"); + + // Queries that exceed the failure threshold for the guardrail. Specifies a protocol version to trigger + // validation in the coordinator. 
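+ // With default guardrail settings the failure threshold is 4 * SAI_VECTOR_SEARCH_MAX_TOP_K (4000 in this
+ // test environment), so rerank_k = 5000 is rejected.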
+ assertInvalidThrowMessage(Optional.of(ProtocolVersion.V5), + "ANN options specifies rerank_k=5000, this exceeds the failure threshold of 4000.", + InvalidQueryException.class, + "SELECT * FROM %s ORDER BY v ANN OF [1, 1] LIMIT 10 WITH ann_options = {'rerank_k': 5000}"); String baseQuery = "SELECT * FROM %s ORDER BY v ANN OF [1, 1]"; @@ -182,12 +198,13 @@ public void testTransport() testTransport("SELECT * FROM %s ORDER BY v ANN OF [1, 1]", ANNOptions.NONE); testTransport("SELECT * FROM %s ORDER BY v ANN OF [1, 1] WITH ann_options = {}", ANNOptions.NONE); + // TODO re-enable this test when we support negative rerank_k values // some random negative values, all should be accepted and not be mapped to NONE - String negativeQuery = "SELECT * FROM %%s ORDER BY v ANN OF [1, 1] LIMIT 10 WITH ann_options = {'rerank_k': %d}"; - QuickTheory.qt() - .withExamples(100) - .forAll(integers().allPositive()) - .checkAssert(i -> testTransport(String.format(negativeQuery, -i), ANNOptions.create(-i))); +// String negativeQuery = "SELECT * FROM %%s ORDER BY v ANN OF [1, 1] LIMIT 10 WITH ann_options = {'rerank_k': %d}"; +// QuickTheory.qt() +// .withExamples(100) +// .forAll(integers().allPositive()) +// .checkAssert(i -> testTransport(String.format(negativeQuery, -i), ANNOptions.create(-i))); // some random positive values, all should be accepted String positiveQuery = "SELECT * FROM %%s ORDER BY v ANN OF [1, 1] LIMIT %d WITH ann_options = {'rerank_k': % c.sai_ann_rerank_k_warn_threshold = v.intValue(), + "sai_ann_rerank_k_warn_threshold"); + + DatabaseDescriptor.getGuardrailsConfig().sai_ann_rerank_k_warn_threshold = -1; + testValidationOfStrictlyPositiveProperty((c, v) -> c.sai_ann_rerank_k_fail_threshold = v.intValue(), + "sai_ann_rerank_k_fail_threshold"); + } + + @Test + public void testDefaultValues() + { + // Reset to defaults + DatabaseDescriptor.getGuardrailsConfig().sai_ann_rerank_k_warn_threshold = defaultWarnThreshold; + DatabaseDescriptor.getGuardrailsConfig().sai_ann_rerank_k_fail_threshold = defaultFailThreshold; + + // Test that default failure threshold is 4 times the max top K + int maxTopK = CassandraRelevantProperties.SAI_VECTOR_SEARCH_MAX_TOP_K.getInt(); + assertEquals(-1, (int) DatabaseDescriptor.getGuardrailsConfig().sai_ann_rerank_k_warn_threshold); + assertEquals(4 * maxTopK, (int) DatabaseDescriptor.getGuardrailsConfig().sai_ann_rerank_k_fail_threshold); + } + + @Test + public void testSAIAnnRerankKThresholds() throws Throwable + { + createTable("CREATE TABLE %s (k int PRIMARY KEY, v vector)"); + createIndex("CREATE CUSTOM INDEX ON %s(v) USING 'StorageAttachedIndex'"); + + // Test values below and at warning threshold + assertValid("SELECT * FROM %s ORDER BY v ANN OF [1.0, 1.0, 1.0] LIMIT 10 WITH ann_options = {'rerank_k': 10}"); + assertValid("SELECT * FROM %s ORDER BY v ANN OF [1.0, 1.0, 1.0] LIMIT 10 WITH ann_options = {'rerank_k': " + (WARN_THRESHOLD - 1) + '}'); + assertValid("SELECT * FROM %s ORDER BY v ANN OF [1.0, 1.0, 1.0] LIMIT 10 WITH ann_options = {'rerank_k': " + WARN_THRESHOLD + '}'); + + // Test values between warning and failure thresholds + assertWarns(String.format("ANN options specifies rerank_k=%d, this exceeds the warning threshold of %d.", + WARN_THRESHOLD + 1, WARN_THRESHOLD), + "SELECT * FROM %s ORDER BY v ANN OF [1.0, 1.0, 1.0] LIMIT 10 WITH ann_options = {'rerank_k': " + (WARN_THRESHOLD + 1) + '}'); + + // Test values at failure threshold (should still warn) + assertWarns(String.format("ANN options specifies rerank_k=%d, this exceeds the warning threshold 
of %d.", + FAIL_THRESHOLD, WARN_THRESHOLD), + "SELECT * FROM %s ORDER BY v ANN OF [1.0, 1.0, 1.0] LIMIT 10 WITH ann_options = {'rerank_k': " + FAIL_THRESHOLD + '}'); + + // Test values above failure threshold + assertFails(String.format("ANN options specifies rerank_k=%d, this exceeds the failure threshold of %d.", + FAIL_THRESHOLD + 1, FAIL_THRESHOLD), + "SELECT * FROM %s ORDER BY v ANN OF [1.0, 1.0, 1.0] LIMIT 10 WITH ann_options = {'rerank_k': " + (FAIL_THRESHOLD + 1) + '}'); + } + + @Test + public void testDisabledThresholds() throws Throwable + { + createTable("CREATE TABLE %s (k int PRIMARY KEY, v vector)"); + createIndex("CREATE CUSTOM INDEX ON %s(v) USING 'StorageAttachedIndex'"); + + // Test with warning threshold disabled + DatabaseDescriptor.getGuardrailsConfig().sai_ann_rerank_k_warn_threshold = -1; + assertValid("SELECT * FROM %s ORDER BY v ANN OF [1.0, 1.0, 1.0] LIMIT 10 WITH ann_options = {'rerank_k': " + (WARN_THRESHOLD + 1) + '}'); + + // Test with failure threshold disabled + DatabaseDescriptor.getGuardrailsConfig().sai_ann_rerank_k_fail_threshold = -1; + assertValid("SELECT * FROM %s ORDER BY v ANN OF [1.0, 1.0, 1.0] LIMIT 10 WITH ann_options = {'rerank_k': " + (FAIL_THRESHOLD + 1) + '}'); + } + + @Ignore // TODO: re-enable this test when we support negative rerank_k values + public void testNegativeRerankK() throws Throwable + { + createTable("CREATE TABLE %s (k int PRIMARY KEY, v vector)"); + createIndex("CREATE CUSTOM INDEX ON %s(v) USING 'StorageAttachedIndex'"); + + // Negative rerank_k values should be valid and not trigger warnings + assertValid("SELECT * FROM %s ORDER BY v ANN OF [1.0, 1.0, 1.0] LIMIT 10 WITH ann_options = {'rerank_k': -1}"); + assertValid("SELECT * FROM %s ORDER BY v ANN OF [1.0, 1.0, 1.0] LIMIT 10 WITH ann_options = {'rerank_k': -1000}"); + } + + @Test + public void testMissingRerankK() throws Throwable + { + createTable("CREATE TABLE %s (k int PRIMARY KEY, v vector)"); + createIndex("CREATE CUSTOM INDEX ON %s(v) USING 'StorageAttachedIndex'"); + + // Queries without rerank_k should be valid and not trigger warnings.
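+ // Without an explicit value, Orderer.rerankKFor() falls back to the source model's default for the given
+ // limit and compression, so no guardrail is triggered.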
+ assertValid("SELECT * FROM %s ORDER BY v ANN OF [1.0, 1.0, 1.0] LIMIT 10"); + assertValid("SELECT * FROM %s ORDER BY v ANN OF [1.0, 1.0, 1.0] LIMIT 10 WITH ann_options = {}"); + } +} \ No newline at end of file diff --git a/test/unit/org/apache/cassandra/index/sai/cql/VectorSiftSmallTest.java b/test/unit/org/apache/cassandra/index/sai/cql/VectorSiftSmallTest.java index 34f3d6b6a321..c66f6889a7a3 100644 --- a/test/unit/org/apache/cassandra/index/sai/cql/VectorSiftSmallTest.java +++ b/test/unit/org/apache/cassandra/index/sai/cql/VectorSiftSmallTest.java @@ -68,9 +68,38 @@ public void testSiftSmall() throws Throwable double memoryRecall = testRecall(100, queryVectors, groundTruth); assertTrue("Memory recall is " + memoryRecall, memoryRecall > 0.975); + // Run a few queries with increasing rerank_k to validate that recall increases + ensureIncreasingRerankKIncreasesRecall(queryVectors, groundTruth); + flush(); var diskRecall = testRecall(100, queryVectors, groundTruth); assertTrue("Disk recall is " + diskRecall, diskRecall > 0.975); + + // Run a few queries with increasing rerank_k to validate that recall increases + ensureIncreasingRerankKIncreasesRecall(queryVectors, groundTruth); + } + + private void ensureIncreasingRerankKIncreasesRecall(List queryVectors, List> groundTruth) + { + // Validate that the recall increases as we increase the rerank_k parameter + double previousRecall = 0; + int limit = 10; + int strictlyIncreasedCount = 0; + // Testing shows that we achieve 100% recall at about rerank_k = 45, so no need to go higher + for (int rerankK = limit; rerankK <= 50; rerankK += 5) + { + var recall = testRecall(limit, queryVectors, groundTruth, rerankK); + // Recall varies, so we can only assert that it does not get worse on a per-run basis. However, it should + // get better strictly at least some of the time + assertTrue("Recall for rerank_k = " + rerankK + " is " + recall, recall >= previousRecall); + if (recall > previousRecall) + strictlyIncreasedCount++; + previousRecall = recall; + } + // This is a conservative assertion to prevent it from being too fragile. At the time of writing this test, + // we observed a strict increase 6 times in memory and 5 times on disk. + assertTrue("Recall should have strictly increased at least 4 times but only increased " + strictlyIncreasedCount + " times", + strictlyIncreasedCount > 3); } @Test @@ -198,6 +227,11 @@ private static ArrayList> readIvecs(String filename) } public double testRecall(int topK, List queryVectors, List> groundTruth) + { + return testRecall(topK, queryVectors, groundTruth, null); + } + + public double testRecall(int topK, List queryVectors, List> groundTruth, Integer rerankK) { AtomicInteger topKfound = new AtomicInteger(0); @@ -208,7 +242,11 @@ public double testRecall(int topK, List queryVectors, List