From 41d23867785a3e2a69a31ae5136e0c9767330715 Mon Sep 17 00:00:00 2001 From: Michael Marshall Date: Tue, 11 Feb 2025 15:36:13 -0600 Subject: [PATCH 01/12] CNDB-12922: Implement rerank_k in SAI ANN queries --- .../config/CassandraRelevantProperties.java | 2 +- .../index/sai/StorageAttachedIndex.java | 4 -- .../cassandra/index/sai/disk/v1/Segment.java | 7 ++-- .../index/sai/disk/v1/V1SearchableIndex.java | 2 + .../sai/disk/v2/V2VectorIndexSearcher.java | 7 ++-- .../sai/disk/vector/CassandraOnHeapGraph.java | 3 +- .../sai/disk/vector/VectorMemtableIndex.java | 11 +++-- .../cassandra/index/sai/plan/Orderer.java | 32 +++++++++++++-- .../index/sai/plan/QueryController.java | 6 +-- .../index/sai/cql/VectorSiftSmallTest.java | 40 ++++++++++++++++++- .../sai/memory/VectorMemtableIndexTest.java | 2 +- 11 files changed, 90 insertions(+), 26 deletions(-) diff --git a/src/java/org/apache/cassandra/config/CassandraRelevantProperties.java b/src/java/org/apache/cassandra/config/CassandraRelevantProperties.java index 79b2bfb4e29e..0481ce2f4b1d 100644 --- a/src/java/org/apache/cassandra/config/CassandraRelevantProperties.java +++ b/src/java/org/apache/cassandra/config/CassandraRelevantProperties.java @@ -581,7 +581,7 @@ public enum CassandraRelevantProperties * The current messaging version. This is used when we add new messaging versions without adopting them immediately, * or to force the node to use a specific version for testing purposes. 
*/ - DS_CURRENT_MESSAGING_VERSION("ds.current_messaging_version", Integer.toString(MessagingService.VERSION_DS_10)); + DS_CURRENT_MESSAGING_VERSION("ds.current_messaging_version", Integer.toString(MessagingService.VERSION_DS_11)); CassandraRelevantProperties(String key, String defaultVal) { diff --git a/src/java/org/apache/cassandra/index/sai/StorageAttachedIndex.java b/src/java/org/apache/cassandra/index/sai/StorageAttachedIndex.java index 28cb42ab8dc8..6728ed7f23c7 100644 --- a/src/java/org/apache/cassandra/index/sai/StorageAttachedIndex.java +++ b/src/java/org/apache/cassandra/index/sai/StorageAttachedIndex.java @@ -711,10 +711,6 @@ public void validate(ReadCommand command) throws InvalidRequestException throw new InvalidRequestException(String.format("SAI based ORDER BY clause requires a LIMIT that is not greater than %s. LIMIT was %s", MAX_TOP_K, command.limits().isUnlimited() ? "NO LIMIT" : command.limits().count())); - ANNOptions annOptions = command.rowFilter().annOptions(); - if (annOptions != ANNOptions.NONE) - throw new InvalidRequestException("SAI doesn't support ANN options yet."); - indexContext.validate(command.rowFilter()); } diff --git a/src/java/org/apache/cassandra/index/sai/disk/v1/Segment.java b/src/java/org/apache/cassandra/index/sai/disk/v1/Segment.java index 87fe0a998e5f..532c2c38f2b5 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/v1/Segment.java +++ b/src/java/org/apache/cassandra/index/sai/disk/v1/Segment.java @@ -205,10 +205,11 @@ public String toString() * the number of candidates, the more nodes we expect to visit just to find * results that are in that set.) 
*/ - public double estimateAnnSearchCost(int limit, int candidates) + public double estimateAnnSearchCost(Orderer orderer, int limit, int candidates) { - IndexSearcher searcher = getIndexSearcher(); - return ((V2VectorIndexSearcher) searcher).estimateAnnSearchCost(limit, candidates); + V2VectorIndexSearcher searcher = (V2VectorIndexSearcher) getIndexSearcher(); + int rerankK = orderer.rerankKFor(limit, searcher.getCompression()); + return searcher.estimateAnnSearchCost(rerankK, candidates); } /** diff --git a/src/java/org/apache/cassandra/index/sai/disk/v1/V1SearchableIndex.java b/src/java/org/apache/cassandra/index/sai/disk/v1/V1SearchableIndex.java index b7b0ef0222f3..28f90d590f8c 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/v1/V1SearchableIndex.java +++ b/src/java/org/apache/cassandra/index/sai/disk/v1/V1SearchableIndex.java @@ -207,6 +207,8 @@ public List> orderBy(Orderer orderer, E { if (segment.intersects(keyRange)) { + // Note that the proportionality is not used when the user supplies a rerank_k value in the + // ANN_OPTIONS map. 
var segmentLimit = segment.proportionalAnnLimit(limit, totalRows); iterators.add(segment.orderBy(orderer, slice, keyRange, context, segmentLimit)); } diff --git a/src/java/org/apache/cassandra/index/sai/disk/v2/V2VectorIndexSearcher.java b/src/java/org/apache/cassandra/index/sai/disk/v2/V2VectorIndexSearcher.java index 80870dab36bf..d150fb1ea742 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/v2/V2VectorIndexSearcher.java +++ b/src/java/org/apache/cassandra/index/sai/disk/v2/V2VectorIndexSearcher.java @@ -163,7 +163,7 @@ public CloseableIterator orderBy(Orderer orderer, Express if (orderer.vector == null) throw new IllegalArgumentException(indexContext.logMessage("Unsupported expression during ANN index query: " + orderer)); - int rerankK = indexContext.getIndexWriterConfig().getSourceModel().rerankKFor(limit, graph.getCompression()); + int rerankK = orderer.rerankKFor(limit, graph.getCompression()); var queryVector = vts.createFloatVector(orderer.vector); var result = searchInternal(keyRange, context, queryVector, limit, rerankK, 0); @@ -428,9 +428,8 @@ public double cost() } } - public double estimateAnnSearchCost(int limit, int candidates) + public double estimateAnnSearchCost(int rerankK, int candidates) { - int rerankK = indexContext.getIndexWriterConfig().getSourceModel().rerankKFor(limit, graph.getCompression()); var estimate = estimateCost(rerankK, candidates); return estimate.cost(); } @@ -472,7 +471,7 @@ public CloseableIterator orderResultsBy(SSTableReader rea if (keys.isEmpty()) return CloseableIterator.emptyIterator(); - int rerankK = indexContext.getIndexWriterConfig().getSourceModel().rerankKFor(limit, graph.getCompression()); + int rerankK = orderer.rerankKFor(limit, graph.getCompression()); // Convert PKs to segment row ids and map to ordinals, skipping any that don't exist in this segment var segmentOrdinalPairs = flatmapPrimaryKeysToBitsAndRows(keys); var numRows = segmentOrdinalPairs.size(); diff --git 
a/src/java/org/apache/cassandra/index/sai/disk/vector/CassandraOnHeapGraph.java b/src/java/org/apache/cassandra/index/sai/disk/vector/CassandraOnHeapGraph.java index 4a55991135e2..367c1ff5fc02 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/vector/CassandraOnHeapGraph.java +++ b/src/java/org/apache/cassandra/index/sai/disk/vector/CassandraOnHeapGraph.java @@ -312,7 +312,7 @@ public void remove(ByteBuffer term, T key) /** * @return an itererator over {@link PrimaryKeyWithSortKey} in the graph's {@link SearchResult} order */ - public CloseableIterator search(QueryContext context, VectorFloat queryVector, int limit, float threshold, Bits toAccept) + public CloseableIterator search(QueryContext context, VectorFloat queryVector, int limit, int rerankK, float threshold, Bits toAccept) { VectorValidation.validateIndexable(queryVector, similarityFunction); @@ -326,7 +326,6 @@ public CloseableIterator search(QueryContext context, Ve try { var ssf = SearchScoreProvider.exact(queryVector, similarityFunction, vectorValues); - var rerankK = sourceModel.rerankKFor(limit, VectorCompression.NO_COMPRESSION); var result = searcher.search(ssf, limit, rerankK, threshold, 0.0f, bits); Tracing.trace("ANN search for {}/{} visited {} nodes, reranked {} to return {} results from {}", limit, rerankK, result.getVisitedCount(), result.getRerankedCount(), result.getNodes().length, source); diff --git a/src/java/org/apache/cassandra/index/sai/disk/vector/VectorMemtableIndex.java b/src/java/org/apache/cassandra/index/sai/disk/vector/VectorMemtableIndex.java index 25231748149f..ca84ba3eb6f2 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/vector/VectorMemtableIndex.java +++ b/src/java/org/apache/cassandra/index/sai/disk/vector/VectorMemtableIndex.java @@ -193,7 +193,7 @@ public KeyRangeIterator search(QueryContext context, Expression expr, AbstractBo float threshold = expr.getEuclideanSearchThreshold(); SortingIterator.Builder keyQueue; - try (var pkIterator = 
searchInternal(context, qv, keyRange, graph.size(), threshold)) + try (var pkIterator = searchInternal(context, qv, keyRange, graph.size(), graph.size(), threshold)) { keyQueue = new SortingIterator.Builder<>(); while (pkIterator.hasNext()) @@ -223,14 +223,16 @@ public List> orderBy(QueryContext conte assert orderer.isANN() : "Only ANN is supported for vector search, received " + orderer.operator; var qv = vts.createFloatVector(orderer.vector); + var rerankK = orderer.rerankKFor(limit, VectorCompression.NO_COMPRESSION); - return List.of(searchInternal(context, qv, keyRange, limit, 0)); + return List.of(searchInternal(context, qv, keyRange, limit, rerankK, 0)); } private CloseableIterator searchInternal(QueryContext context, VectorFloat queryVector, AbstractBounds keyRange, int limit, + int rerankK, float threshold) { Bits bits; @@ -272,7 +274,7 @@ private CloseableIterator searchInternal(QueryContext con bits = new KeyRangeFilteringBits(keyRange); } - var nodeScoreIterator = graph.search(context, queryVector, limit, threshold, bits); + var nodeScoreIterator = graph.search(context, queryVector, limit, rerankK, threshold, bits); return new NodeScoreToScoredPrimaryKeyIterator(nodeScoreIterator); } @@ -305,6 +307,7 @@ public CloseableIterator orderResultsBy(QueryContext cont relevantOrdinals.add(i); }); + int rerankK = orderer.rerankKFor(limit, VectorCompression.NO_COMPRESSION); int maxBruteForceRows = maxBruteForceRows(limit, relevantOrdinals.size(), graph.size()); Tracing.logAndTrace(logger, "{} rows relevant to current memtable out of {} materialized by SAI; max brute force rows is {} for memtable index with {} nodes, LIMIT {}", relevantOrdinals.size(), keys.size(), maxBruteForceRows, graph.size(), limit); @@ -319,7 +322,7 @@ public CloseableIterator orderResultsBy(QueryContext cont return orderByBruteForce(qv, keysInGraph); } // indexed path - var nodeScoreIterator = graph.search(context, qv, limit, 0, relevantOrdinals::contains); + var nodeScoreIterator = 
graph.search(context, qv, limit, rerankK, 0, relevantOrdinals::contains); return new NodeScoreToScoredPrimaryKeyIterator(nodeScoreIterator); } diff --git a/src/java/org/apache/cassandra/index/sai/plan/Orderer.java b/src/java/org/apache/cassandra/index/sai/plan/Orderer.java index e23c83ed91e2..8eba8f43b537 100644 --- a/src/java/org/apache/cassandra/index/sai/plan/Orderer.java +++ b/src/java/org/apache/cassandra/index/sai/plan/Orderer.java @@ -31,6 +31,7 @@ import org.apache.cassandra.index.SecondaryIndexManager; import org.apache.cassandra.index.sai.IndexContext; import org.apache.cassandra.index.sai.StorageAttachedIndex; +import org.apache.cassandra.index.sai.disk.vector.VectorCompression; import org.apache.cassandra.index.sai.utils.PrimaryKeyWithSortKey; import org.apache.cassandra.index.sai.utils.TypeUtil; @@ -46,20 +47,25 @@ public class Orderer public final IndexContext context; public final Operator operator; + + // Vector search parameters public final float[] vector; + private final Integer rerankK; /** * Create an orderer for the given index context, operator, and term. * @param context the index context, used to build the view of memtables and sstables for query execution. * @param operator the operator for the order by clause. * @param term the term to order by (not always relevant) + * @param rerankK optional rerank K parameter for ANN queries */ - public Orderer(IndexContext context, Operator operator, ByteBuffer term) + public Orderer(IndexContext context, Operator operator, ByteBuffer term, Integer rerankK) { this.context = context; assert ORDER_BY_OPERATORS.contains(operator) : "Invalid operator for order by clause " + operator; this.operator = operator; this.vector = context.getValidator().isVector() ? TypeUtil.decomposeVector(context.getValidator(), term) : null; + this.rerankK = rerankK; } public String getIndexName() @@ -89,6 +95,22 @@ public boolean isANN() return operator == Operator.ANN; } + /** + * Provide rerankK for ANN queries. 
Use the user provided rerankK if available, otherwise use the model's default + * based on the limit and compression type. + * @param limit the query limit or the proportional segment limit to use when calculating a reasonable rerankK + * default value + * @param vc the compression type of the vectors in the index + * @return the rerankK value to use in ANN search + */ + public int rerankKFor(int limit, VectorCompression vc) + { + assert isANN() : "rerankK is only valid for ANN queries"; + return rerankK != null + ? rerankK + : context.getIndexWriterConfig().getSourceModel().rerankKFor(limit, vc); + } + @Nullable public static Orderer from(SecondaryIndexManager indexManager, RowFilter filter) { @@ -98,7 +120,10 @@ public static Orderer from(SecondaryIndexManager indexManager, RowFilter filter) var orderRowFilter = expressions.get(0); var index = indexManager.getBestIndexFor(orderRowFilter, StorageAttachedIndex.class) .orElseThrow(() -> new IllegalStateException("No index found for order by clause")); - return new Orderer(index.getIndexContext(), orderRowFilter.operator(), orderRowFilter.getIndexValue()); + + // Null if not specified explicitly in the CQL query. + Integer rerankK = filter.annOptions().rerankK; + return new Orderer(index.getIndexContext(), orderRowFilter.operator(), orderRowFilter.getIndexValue(), rerankK); } public static boolean isFilterExpressionOrderer(RowFilter.Expression expression) @@ -110,8 +135,9 @@ public static boolean isFilterExpressionOrderer(RowFilter.Expression expression) public String toString() { String direction = isAscending() ? "ASC" : "DESC"; + String rerankInfo = rerankK != null ? String.format(" (rerank_k=%d)", rerankK) : ""; return isANN() - ? context.getColumnName() + " ANN OF " + Arrays.toString(vector) + ' ' + direction + ? 
context.getColumnName() + " ANN OF " + Arrays.toString(vector) + ' ' + direction + rerankInfo : context.getColumnName() + ' ' + direction; } } diff --git a/src/java/org/apache/cassandra/index/sai/plan/QueryController.java b/src/java/org/apache/cassandra/index/sai/plan/QueryController.java index 15b0965fe82e..d87d763141c0 100644 --- a/src/java/org/apache/cassandra/index/sai/plan/QueryController.java +++ b/src/java/org/apache/cassandra/index/sai/plan/QueryController.java @@ -937,11 +937,11 @@ private long estimateMatchingRowCountUsingIndex(Expression predicate) @Override - public double estimateAnnSearchCost(Orderer ordering, int limit, long candidates) + public double estimateAnnSearchCost(Orderer orderer, int limit, long candidates) { Preconditions.checkArgument(limit > 0, "limit must be > 0"); - IndexContext context = ordering.context; + IndexContext context = orderer.context; Collection memtables = context.getLiveMemtables().values(); View queryView = context.getView(); @@ -965,7 +965,7 @@ public double estimateAnnSearchCost(Orderer ordering, int limit, long candidates continue; int segmentLimit = segment.proportionalAnnLimit(limit, totalRows); int segmentCandidates = max(1, (int) (candidates * (double) segment.metadata.numRows / totalRows)); - cost += segment.estimateAnnSearchCost(segmentLimit, segmentCandidates); + cost += segment.estimateAnnSearchCost(orderer, segmentLimit, segmentCandidates); } } return cost; diff --git a/test/unit/org/apache/cassandra/index/sai/cql/VectorSiftSmallTest.java b/test/unit/org/apache/cassandra/index/sai/cql/VectorSiftSmallTest.java index 34f3d6b6a321..c66f6889a7a3 100644 --- a/test/unit/org/apache/cassandra/index/sai/cql/VectorSiftSmallTest.java +++ b/test/unit/org/apache/cassandra/index/sai/cql/VectorSiftSmallTest.java @@ -68,9 +68,38 @@ public void testSiftSmall() throws Throwable double memoryRecall = testRecall(100, queryVectors, groundTruth); assertTrue("Memory recall is " + memoryRecall, memoryRecall > 0.975); + // Run a 
few queries with increasing rerank_k to validate that recall increases + ensureIncreasingRerankKIncreasesRecall(queryVectors, groundTruth); + flush(); var diskRecall = testRecall(100, queryVectors, groundTruth); assertTrue("Disk recall is " + diskRecall, diskRecall > 0.975); + + // Run a few queries with increasing rerank_k to validate that recall increases + ensureIncreasingRerankKIncreasesRecall(queryVectors, groundTruth); + } + + private void ensureIncreasingRerankKIncreasesRecall(List queryVectors, List> groundTruth) + { + // Validate that the recall increases as we increase the rerank_k parameter + double previousRecall = 0; + int limit = 10; + int strictlyIncreasedCount = 0; + // Testing shows that we achieve 100% recall at about rerank_k = 45, so no need to go higher + for (int rerankK = limit; rerankK <= 50; rerankK += 5) + { + var recall = testRecall(limit, queryVectors, groundTruth, rerankK); + // Recall varies, so we can only assert that it does not get worse on a per-run basis. However, it should + // get better strictly at least some of the time + assertTrue("Recall for rerank_k = " + rerankK + " is " + recall, recall >= previousRecall); + if (recall > previousRecall) + strictlyIncreasedCount++; + previousRecall = recall; + } + // This is a conservative assertion to prevent it from being too fragile. At the time of writing this test, + // we observed a strict increase of 6 times for in memory and 5 times for on disk. 
+ assertTrue("Recall should have strictly increased at least 4 times but only increased " + strictlyIncreasedCount + " times", + strictlyIncreasedCount > 3); } @Test @@ -198,6 +227,11 @@ private static ArrayList> readIvecs(String filename) } public double testRecall(int topK, List queryVectors, List> groundTruth) + { + return testRecall(topK, queryVectors, groundTruth, null); + } + + public double testRecall(int topK, List queryVectors, List> groundTruth, Integer rerankK) { AtomicInteger topKfound = new AtomicInteger(0); @@ -208,7 +242,11 @@ public double testRecall(int topK, List queryVectors, List Date: Tue, 11 Feb 2025 15:28:55 -0600 Subject: [PATCH 02/12] Replace limit with rerankK for cost estimates Before this, we had minor inconsistencies where the in memory and the on disk indexes diverged. This might have been intentional, but I'm not certain, so changing this here to force the discussion in the PR review. Now, everything is consistently rerankK when estimating the cost to search the graph in order to determine if brute force is cheaper. 
--- .../sai/disk/v2/V2VectorIndexSearcher.java | 4 +-- .../sai/disk/vector/VectorMemtableIndex.java | 32 +++++++++---------- .../index/sai/plan/QueryController.java | 4 ++- 3 files changed, 21 insertions(+), 19 deletions(-) diff --git a/src/java/org/apache/cassandra/index/sai/disk/v2/V2VectorIndexSearcher.java b/src/java/org/apache/cassandra/index/sai/disk/v2/V2VectorIndexSearcher.java index d150fb1ea742..7b9057f56834 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/v2/V2VectorIndexSearcher.java +++ b/src/java/org/apache/cassandra/index/sai/disk/v2/V2VectorIndexSearcher.java @@ -610,9 +610,9 @@ public static double logBase2(double number) { return Math.log(number) / Math.log(2); } - private int getRawExpectedNodes(int limit, int nPermittedOrdinals) + private int getRawExpectedNodes(int rerankK, int nPermittedOrdinals) { - return VectorMemtableIndex.expectedNodesVisited(limit, nPermittedOrdinals, graph.size()); + return VectorMemtableIndex.expectedNodesVisited(rerankK, nPermittedOrdinals, graph.size()); } @Override diff --git a/src/java/org/apache/cassandra/index/sai/disk/vector/VectorMemtableIndex.java b/src/java/org/apache/cassandra/index/sai/disk/vector/VectorMemtableIndex.java index ca84ba3eb6f2..a8d85325e4a8 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/vector/VectorMemtableIndex.java +++ b/src/java/org/apache/cassandra/index/sai/disk/vector/VectorMemtableIndex.java @@ -258,11 +258,11 @@ private CloseableIterator searchInternal(QueryContext con if (resultKeys.isEmpty()) return CloseableIterator.emptyIterator(); - int bruteForceRows = maxBruteForceRows(limit, resultKeys.size(), graph.size()); - logger.trace("Search range covers {} rows; max brute force rows is {} for memtable index with {} nodes, LIMIT {}", - resultKeys.size(), bruteForceRows, graph.size(), limit); - Tracing.trace("Search range covers {} rows; max brute force rows is {} for memtable index with {} nodes, LIMIT {}", - resultKeys.size(), bruteForceRows, graph.size(), limit); + 
int bruteForceRows = maxBruteForceRows(rerankK, resultKeys.size(), graph.size()); + logger.trace("Search range covers {} rows; max brute force rows is {} for memtable index with {} nodes, rerankK {}, LIMIT {}", + resultKeys.size(), bruteForceRows, graph.size(), rerankK, limit); + Tracing.trace("Search range covers {} rows; max brute force rows is {} for memtable index with {} nodes, rerankK {}, LIMIT {}", + resultKeys.size(), bruteForceRows, graph.size(), rerankK, limit); if (resultKeys.size() <= bruteForceRows) // When we have a threshold, we only need to filter the results, not order them, because it means we're // evaluating a boolean predicate in the SAI pipeline that wants to collate by PK @@ -308,9 +308,9 @@ public CloseableIterator orderResultsBy(QueryContext cont }); int rerankK = orderer.rerankKFor(limit, VectorCompression.NO_COMPRESSION); - int maxBruteForceRows = maxBruteForceRows(limit, relevantOrdinals.size(), graph.size()); - Tracing.logAndTrace(logger, "{} rows relevant to current memtable out of {} materialized by SAI; max brute force rows is {} for memtable index with {} nodes, LIMIT {}", - relevantOrdinals.size(), keys.size(), maxBruteForceRows, graph.size(), limit); + int maxBruteForceRows = maxBruteForceRows(rerankK, relevantOrdinals.size(), graph.size()); + Tracing.logAndTrace(logger, "{} rows relevant to current memtable out of {} materialized by SAI; max brute force rows is {} for memtable index with {} nodes, rerankK {}", + relevantOrdinals.size(), keys.size(), maxBruteForceRows, graph.size(), rerankK); // convert the expression value to query vector var qv = vts.createFloatVector(orderer.vector); @@ -378,15 +378,15 @@ private PrimaryKeyWithScore scoreKey(VectorSimilarityFunction similarityFunction return new PrimaryKeyWithScore(indexContext, mt, key, score); } - private int maxBruteForceRows(int limit, int nPermittedOrdinals, int graphSize) + private int maxBruteForceRows(int rerankK, int nPermittedOrdinals, int graphSize) { - int 
expectedNodesVisited = expectedNodesVisited(limit, nPermittedOrdinals, graphSize); - return min(max(limit, expectedNodesVisited), GLOBAL_BRUTE_FORCE_ROWS); + int expectedNodesVisited = expectedNodesVisited(rerankK, nPermittedOrdinals, graphSize); + return min(max(rerankK, expectedNodesVisited), GLOBAL_BRUTE_FORCE_ROWS); } - public int estimateAnnNodesVisited(int limit, int nPermittedOrdinals) + public int estimateAnnNodesVisited(int rerankK, int nPermittedOrdinals) { - return expectedNodesVisited(limit, nPermittedOrdinals, graph.size()); + return expectedNodesVisited(rerankK, nPermittedOrdinals, graph.size()); } /** @@ -398,9 +398,9 @@ public int estimateAnnNodesVisited(int limit, int nPermittedOrdinals) * !!! roughly `degree` times larger than the number of nodes whose edge lists we load! * !!! */ - public static int expectedNodesVisited(int limit, int nPermittedOrdinals, int graphSize) + public static int expectedNodesVisited(int rerankK, int nPermittedOrdinals, int graphSize) { - var K = limit; + var K = rerankK; var B = min(nPermittedOrdinals, graphSize); var N = graphSize; // These constants come from running many searches on a variety of datasets and graph sizes. @@ -415,7 +415,7 @@ public static int expectedNodesVisited(int limit, int nPermittedOrdinals, int gr // If we need to make this even more accurate, the relationship to B and to log(N) may be the best // places to start. 
var raw = (int) (100 + 0.025 * pow(log(N), 2) * pow(K, 0.95) * ((double) N / B)); - return ensureSaneEstimate(raw, limit, graphSize); + return ensureSaneEstimate(raw, rerankK, graphSize); } public static int ensureSaneEstimate(int rawEstimate, int rerankK, int graphSize) diff --git a/src/java/org/apache/cassandra/index/sai/plan/QueryController.java b/src/java/org/apache/cassandra/index/sai/plan/QueryController.java index d87d763141c0..dd70b4296b15 100644 --- a/src/java/org/apache/cassandra/index/sai/plan/QueryController.java +++ b/src/java/org/apache/cassandra/index/sai/plan/QueryController.java @@ -64,6 +64,7 @@ import org.apache.cassandra.index.sai.StorageAttachedIndex; import org.apache.cassandra.index.sai.disk.format.IndexFeatureSet; import org.apache.cassandra.index.sai.disk.v1.Segment; +import org.apache.cassandra.index.sai.disk.vector.VectorCompression; import org.apache.cassandra.index.sai.disk.vector.VectorMemtableIndex; import org.apache.cassandra.index.sai.iterators.KeyRangeIntersectionIterator; import org.apache.cassandra.index.sai.iterators.KeyRangeIterator; @@ -945,12 +946,13 @@ public double estimateAnnSearchCost(Orderer orderer, int limit, long candidates) Collection memtables = context.getLiveMemtables().values(); View queryView = context.getView(); + int memoryRerankK = orderer.rerankKFor(limit, VectorCompression.NO_COMPRESSION); double cost = 0; for (MemtableIndex index : memtables) { // FIXME convert nodes visited to search cost int memtableCandidates = (int) Math.min(Integer.MAX_VALUE, candidates); - cost += ((VectorMemtableIndex) index).estimateAnnNodesVisited(limit, memtableCandidates); + cost += ((VectorMemtableIndex) index).estimateAnnNodesVisited(memoryRerankK, memtableCandidates); } long totalRows = 0; From a676e0337fa942282caa1e5da410232b86b163ca Mon Sep 17 00:00:00 2001 From: Michael Marshall Date: Tue, 11 Feb 2025 16:21:40 -0600 Subject: [PATCH 03/12] Add guardrails; fail at 4x vector max_top_k --- 
.../cassandra/cql3/statements/SelectOptions.java | 6 ++++-- .../cassandra/cql3/statements/SelectStatement.java | 2 +- .../org/apache/cassandra/db/filter/ANNOptions.java | 11 +++++++++-- .../org/apache/cassandra/guardrails/Guardrails.java | 11 +++++++++++ .../cassandra/guardrails/GuardrailsConfig.java | 12 ++++++++++++ .../apache/cassandra/db/filter/ANNOptionsTest.java | 10 ++++++++++ 6 files changed, 47 insertions(+), 5 deletions(-) diff --git a/src/java/org/apache/cassandra/cql3/statements/SelectOptions.java b/src/java/org/apache/cassandra/cql3/statements/SelectOptions.java index 6c7c8f2e631d..4eada73aa563 100644 --- a/src/java/org/apache/cassandra/cql3/statements/SelectOptions.java +++ b/src/java/org/apache/cassandra/cql3/statements/SelectOptions.java @@ -22,6 +22,7 @@ import org.apache.cassandra.db.filter.ANNOptions; import org.apache.cassandra.exceptions.InvalidRequestException; import org.apache.cassandra.exceptions.RequestValidationException; +import org.apache.cassandra.service.QueryState; /** * {@code WITH option1=... AND option2=...} options for SELECT statements. @@ -36,13 +37,14 @@ public class SelectOptions extends PropertyDefinitions /** * Validates all the {@code SELECT} options. 
* + * @param state the query state * @param limit the {@code SELECT} query user-provided limit * @throws InvalidRequestException if any of the options are invalid */ - public void validate(int limit) throws RequestValidationException + public void validate(QueryState state, int limit) throws RequestValidationException { validate(keywords, Collections.emptySet()); - parseANNOptions().validate(limit); + parseANNOptions().validate(state, limit); } /** diff --git a/src/java/org/apache/cassandra/cql3/statements/SelectStatement.java b/src/java/org/apache/cassandra/cql3/statements/SelectStatement.java index 58b75aacab5a..c90f23feaddf 100644 --- a/src/java/org/apache/cassandra/cql3/statements/SelectStatement.java +++ b/src/java/org/apache/cassandra/cql3/statements/SelectStatement.java @@ -426,7 +426,7 @@ public ReadQuery getQuery(QueryState queryState, checkFalse(userOffset != NO_OFFSET, String.format(TOPK_OFFSET_ERROR, userOffset)); } - selectOptions.validate(userLimit); + selectOptions.validate(queryState, userLimit); return query; } diff --git a/src/java/org/apache/cassandra/db/filter/ANNOptions.java b/src/java/org/apache/cassandra/db/filter/ANNOptions.java index ea74ff6f1136..4ce213c64d26 100644 --- a/src/java/org/apache/cassandra/db/filter/ANNOptions.java +++ b/src/java/org/apache/cassandra/db/filter/ANNOptions.java @@ -24,10 +24,12 @@ import org.apache.cassandra.db.TypeSizes; import org.apache.cassandra.exceptions.InvalidRequestException; +import org.apache.cassandra.guardrails.Guardrails; import org.apache.cassandra.io.util.DataInputPlus; import org.apache.cassandra.io.util.DataOutputPlus; import org.apache.cassandra.locator.InetAddressAndPort; import org.apache.cassandra.net.MessagingService; +import org.apache.cassandra.service.QueryState; import org.apache.cassandra.utils.FBUtilities; /** @@ -60,10 +62,15 @@ public static ANNOptions create(@Nullable Integer rerankK) return rerankK == null ? 
NONE : new ANNOptions(rerankK); } - public void validate(int limit) + public void validate(QueryState state, int limit) { - if (rerankK != null && rerankK > 0 && rerankK < limit) + if (rerankK == null || rerankK <= 0) + return; + + if (rerankK < limit) throw new InvalidRequestException(String.format("Invalid rerank_k value %d lesser than limit %d", rerankK, limit)); + + Guardrails.annRerankKMaxValue.guard(rerankK, "ANN Option rerank_k", false, state); } /** diff --git a/src/java/org/apache/cassandra/guardrails/Guardrails.java b/src/java/org/apache/cassandra/guardrails/Guardrails.java index d02943c74385..b00b59ab6664 100644 --- a/src/java/org/apache/cassandra/guardrails/Guardrails.java +++ b/src/java/org/apache/cassandra/guardrails/Guardrails.java @@ -120,6 +120,17 @@ what, formatSize(v), formatSize(t))) format("%s has a vector of %s dimensions, this exceeds the %s threshold of %s.", what, value, isWarning ? "warning" : "failure", threshold)); + /** + * Guardrail on the maximum value for the rerank_k parameter, an ANN query option. + */ + public static final Threshold annRerankKMaxValue = + factory.threshold("sai_ann_rerank_k_max_value", + () -> config.sai_ann_rerank_k_warn_threshold, + () -> config.sai_ann_rerank_k_failure_threshold, + (isWarning, what, value, threshold) -> + format("%s specifies rerank_k=%s, this exceeds the %s threshold of %s.", + what, value, isWarning ? 
"warning" : "failure", threshold)); + public static final DisableFlag readBeforeWriteListOperationsEnabled = factory.disableFlag("read_before_write_list_operations", () -> !config.read_before_write_list_operations_enabled, diff --git a/src/java/org/apache/cassandra/guardrails/GuardrailsConfig.java b/src/java/org/apache/cassandra/guardrails/GuardrailsConfig.java index 015dad51ee82..9b4ee1160c7a 100644 --- a/src/java/org/apache/cassandra/guardrails/GuardrailsConfig.java +++ b/src/java/org/apache/cassandra/guardrails/GuardrailsConfig.java @@ -29,6 +29,7 @@ import com.google.common.collect.ImmutableSet; import com.google.common.collect.Sets; +import org.apache.cassandra.config.CassandraRelevantProperties; import org.apache.cassandra.config.Config; import org.apache.cassandra.config.DatabaseDescriptor; import org.apache.cassandra.cql3.statements.schema.TableAttributes; @@ -81,6 +82,8 @@ public class GuardrailsConfig public volatile Boolean read_before_write_list_operations_enabled; public volatile Integer vector_dimensions_warn_threshold; public volatile Integer vector_dimensions_failure_threshold; + public volatile Integer sai_ann_rerank_k_warn_threshold; + public volatile Integer sai_ann_rerank_k_failure_threshold; // Legacy 2i guardrail public volatile Integer secondary_index_per_table_failure_threshold; @@ -165,6 +168,11 @@ public void applyConfig() enforceDefault(tombstone_warn_threshold, v -> tombstone_warn_threshold = v, 1000, 1000); enforceDefault(tombstone_failure_threshold, v -> tombstone_failure_threshold = v, 100000, 100000); + // Default to no warning and failure at 4 times the maxTopK value + int maxTopK = CassandraRelevantProperties.SAI_VECTOR_SEARCH_MAX_TOP_K.getInt(); + enforceDefault(sai_ann_rerank_k_warn_threshold, v -> sai_ann_rerank_k_warn_threshold = v, -1, -1); + enforceDefault(sai_ann_rerank_k_failure_threshold, v -> sai_ann_rerank_k_failure_threshold = v, 4 * maxTopK, 4 * maxTopK); + // for write requests enforceDefault(logged_batch_enabled, v 
-> logged_batch_enabled = v, true, true); enforceDefault(batch_size_warn_threshold_in_kb, v -> batch_size_warn_threshold_in_kb = v, 64, 64); @@ -269,6 +277,10 @@ public void validate() validateStrictlyPositiveInteger(vector_dimensions_failure_threshold, "vector_dimensions_failure_threshold"); validateWarnLowerThanFail(vector_dimensions_warn_threshold, vector_dimensions_failure_threshold, "vector_dimensions"); + validateStrictlyPositiveInteger(sai_ann_rerank_k_warn_threshold, "sai_ann_rerank_k_warn_threshold"); + validateStrictlyPositiveInteger(sai_ann_rerank_k_failure_threshold, "sai_ann_rerank_k_failure_threshold"); + validateWarnLowerThanFail(sai_ann_rerank_k_warn_threshold, sai_ann_rerank_k_failure_threshold, "sai_ann_rerank_k"); + validateStrictlyPositiveInteger(tables_warn_threshold, "tables_warn_threshold"); validateStrictlyPositiveInteger(tables_failure_threshold, "tables_failure_threshold"); validateWarnLowerThanFail(tables_warn_threshold, tables_failure_threshold, "tables"); diff --git a/test/unit/org/apache/cassandra/db/filter/ANNOptionsTest.java b/test/unit/org/apache/cassandra/db/filter/ANNOptionsTest.java index 791068fab39c..7f5a94b20229 100644 --- a/test/unit/org/apache/cassandra/db/filter/ANNOptionsTest.java +++ b/test/unit/org/apache/cassandra/db/filter/ANNOptionsTest.java @@ -17,12 +17,14 @@ import java.io.IOException; import java.util.Objects; +import java.util.Optional; import javax.annotation.Nullable; import org.junit.BeforeClass; import org.junit.Test; +import com.datastax.driver.core.exceptions.InvalidQueryException; import org.apache.cassandra.config.CassandraRelevantProperties; import org.apache.cassandra.cql3.CQLTester; import org.apache.cassandra.cql3.Operator; @@ -43,6 +45,7 @@ import org.apache.cassandra.net.MessagingService; import org.apache.cassandra.schema.ColumnMetadata; import org.apache.cassandra.schema.IndexMetadata; +import org.apache.cassandra.transport.ProtocolVersion; import org.apache.cassandra.utils.Pair; import 
org.assertj.core.api.Assertions; import org.quicktheories.QuickTheory; @@ -95,6 +98,13 @@ public void testParseAndValidate() execute("SELECT * FROM %s ORDER BY v ANN OF [1, 1] LIMIT 10 WITH ann_options = {'rerank_k': '-1'}"); execute("SELECT * FROM %s ORDER BY v ANN OF [1, 1] LIMIT 10 WITH ann_options = {'rerank_k': '-1000'}"); + // Queries that exeed the failure threshold for the guardrail. Specifies a protocol version to trigger + // validation in the coordinator. + assertInvalidThrowMessage(Optional.of(ProtocolVersion.V5), + "ANN Option rerank_k specifies rerank_k=5000, this exceeds the failure threshold of 4000.", + InvalidQueryException.class, + "SELECT * FROM %s ORDER BY v ANN OF [1, 1] LIMIT 10 WITH ann_options = {'rerank_k': 5000}"); + String baseQuery = "SELECT * FROM %s ORDER BY v ANN OF [1, 1]"; // unknown SELECT options From ae3958d42d90c5b7778e3f1175f4f8ce28c77e21 Mon Sep 17 00:00:00 2001 From: Michael Marshall Date: Wed, 12 Feb 2025 09:36:53 -0600 Subject: [PATCH 04/12] Fix ANNOptionsDistributedTest --- .../test/sai/ANNOptionsDistributedTest.java | 22 ++++++++++++------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/test/distributed/org/apache/cassandra/distributed/test/sai/ANNOptionsDistributedTest.java b/test/distributed/org/apache/cassandra/distributed/test/sai/ANNOptionsDistributedTest.java index 26052606a5e6..68de1aa12e3d 100644 --- a/test/distributed/org/apache/cassandra/distributed/test/sai/ANNOptionsDistributedTest.java +++ b/test/distributed/org/apache/cassandra/distributed/test/sai/ANNOptionsDistributedTest.java @@ -54,7 +54,8 @@ public void testANNOptionsWithAllDS11() throws Throwable .withConfig(config -> config.with(GOSSIP).with(NETWORK)) .start(), RF)) { - test(cluster, "SAI doesn't support ANN options yet."); + // null indicates that the query should succeed + testSelectWithAnnOptions(cluster, null); } } @@ -70,7 +71,7 @@ public void testANNOptionsWithAllDS10() throws Throwable .withConfig(config -> 
config.with(GOSSIP).with(NETWORK)) .start(), RF)) { - test(cluster, "ANN options are not supported in clusters below DS 11."); + testSelectWithAnnOptions(cluster, "ANN options are not supported in clusters below DS 11."); } } @@ -80,16 +81,18 @@ public void testANNOptionsWithAllDS10() throws Throwable @Test public void testANNOptionsWithMixedDS10AndDS11() throws Throwable { + assert CassandraRelevantProperties.DS_CURRENT_MESSAGING_VERSION.getInt() >= MessagingService.VERSION_DS_11; + try (Cluster cluster = init(Cluster.build(NUM_REPLICAS) .withInstanceInitializer(BB::install) .withConfig(config -> config.with(GOSSIP).with(NETWORK).with(NATIVE_PROTOCOL)) .start(), RF)) { - test(cluster, "ANN options are not supported in clusters below DS 11."); + testSelectWithAnnOptions(cluster, "ANN options are not supported in clusters below DS 11."); } } - private static void test(Cluster cluster, String expectedErrorMessage) + private static void testSelectWithAnnOptions(Cluster cluster, String expectedErrorMessage) { cluster.schemaChange(withKeyspace("CREATE TABLE %s.t (k int PRIMARY KEY, n int, v vector)")); cluster.schemaChange(withKeyspace("CREATE CUSTOM INDEX ON %s.t(v) USING 'StorageAttachedIndex'")); @@ -100,13 +103,16 @@ private static void test(Cluster cluster, String expectedErrorMessage) for (int i = 1; i <= cluster.size(); i++) { ICoordinator coordinator = cluster.coordinator(i); - Assertions.assertThatThrownBy(() -> coordinator.execute(select, ConsistencyLevel.ONE)) - .hasMessageContaining(expectedErrorMessage); + if (expectedErrorMessage == null) + coordinator.execute(select, ConsistencyLevel.ONE); + else + Assertions.assertThatThrownBy(() -> coordinator.execute(select, ConsistencyLevel.ONE)) + .hasMessageContaining(expectedErrorMessage); } } /** - * Injection to set the current version of the first cluster node to DS 11. + * Injection to set the current version of the first cluster node to DS 10. 
*/ public static class BB { @@ -126,7 +132,7 @@ public static void install(ClassLoader classLoader, int node) @SuppressWarnings("unused") public static int currentVersion() { - return MessagingService.VERSION_DS_11; + return MessagingService.VERSION_DS_10; } } } From de766aab528d281e1cf6ad1df52d5d7b8c5d9aed Mon Sep 17 00:00:00 2001 From: Michael Marshall Date: Wed, 12 Feb 2025 09:44:00 -0600 Subject: [PATCH 05/12] Minor cleanup from code review feedback --- src/java/org/apache/cassandra/db/filter/ANNOptions.java | 2 +- .../org/apache/cassandra/index/sai/plan/Orderer.java | 9 +++++---- .../org/apache/cassandra/db/filter/ANNOptionsTest.java | 4 ++-- 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/src/java/org/apache/cassandra/db/filter/ANNOptions.java b/src/java/org/apache/cassandra/db/filter/ANNOptions.java index 4ce213c64d26..b9420610b8ab 100644 --- a/src/java/org/apache/cassandra/db/filter/ANNOptions.java +++ b/src/java/org/apache/cassandra/db/filter/ANNOptions.java @@ -70,7 +70,7 @@ public void validate(QueryState state, int limit) if (rerankK < limit) throw new InvalidRequestException(String.format("Invalid rerank_k value %d lesser than limit %d", rerankK, limit)); - Guardrails.annRerankKMaxValue.guard(rerankK, "ANN Option rerank_k", false, state); + Guardrails.annRerankKMaxValue.guard(rerankK, "ANN options", false, state); } /** diff --git a/src/java/org/apache/cassandra/index/sai/plan/Orderer.java b/src/java/org/apache/cassandra/index/sai/plan/Orderer.java index 8eba8f43b537..76d7d0e92da4 100644 --- a/src/java/org/apache/cassandra/index/sai/plan/Orderer.java +++ b/src/java/org/apache/cassandra/index/sai/plan/Orderer.java @@ -59,7 +59,7 @@ public class Orderer * @param term the term to order by (not always relevant) * @param rerankK optional rerank K parameter for ANN queries */ - public Orderer(IndexContext context, Operator operator, ByteBuffer term, Integer rerankK) + public Orderer(IndexContext context, Operator operator, ByteBuffer term, 
@Nullable Integer rerankK) { this.context = context; assert ORDER_BY_OPERATORS.contains(operator) : "Invalid operator for order by clause " + operator; @@ -98,6 +98,7 @@ public boolean isANN() /** * Provide rerankK for ANN queries. Use the user provided rerankK if available, otherwise use the model's default * based on the limit and compression type. + * * @param limit the query limit or the proportional segment limit to use when calculating a reasonable rerankK * default value * @param vc the compression type of the vectors in the index @@ -117,13 +118,13 @@ public static Orderer from(SecondaryIndexManager indexManager, RowFilter filter) var expressions = filter.root().expressions().stream().filter(Orderer::isFilterExpressionOrderer).collect(Collectors.toList()); if (expressions.isEmpty()) return null; - var orderRowFilter = expressions.get(0); - var index = indexManager.getBestIndexFor(orderRowFilter, StorageAttachedIndex.class) + var orderExpression = expressions.get(0); + var index = indexManager.getBestIndexFor(orderExpression, StorageAttachedIndex.class) .orElseThrow(() -> new IllegalStateException("No index found for order by clause")); // Null if not specified explicitly in the CQL query. 
Integer rerankK = filter.annOptions().rerankK; - return new Orderer(index.getIndexContext(), orderRowFilter.operator(), orderRowFilter.getIndexValue(), rerankK); + return new Orderer(index.getIndexContext(), orderExpression.operator(), orderExpression.getIndexValue(), rerankK); } public static boolean isFilterExpressionOrderer(RowFilter.Expression expression) diff --git a/test/unit/org/apache/cassandra/db/filter/ANNOptionsTest.java b/test/unit/org/apache/cassandra/db/filter/ANNOptionsTest.java index 7f5a94b20229..1b3c12f316ff 100644 --- a/test/unit/org/apache/cassandra/db/filter/ANNOptionsTest.java +++ b/test/unit/org/apache/cassandra/db/filter/ANNOptionsTest.java @@ -98,10 +98,10 @@ public void testParseAndValidate() execute("SELECT * FROM %s ORDER BY v ANN OF [1, 1] LIMIT 10 WITH ann_options = {'rerank_k': '-1'}"); execute("SELECT * FROM %s ORDER BY v ANN OF [1, 1] LIMIT 10 WITH ann_options = {'rerank_k': '-1000'}"); - // Queries that exeed the failure threshold for the guardrail. Specifies a protocol version to trigger + // Queries that exceed the failure threshold for the guardrail. Specifies a protocol version to trigger // validation in the coordinator. 
assertInvalidThrowMessage(Optional.of(ProtocolVersion.V5), - "ANN Option rerank_k specifies rerank_k=5000, this exceeds the failure threshold of 4000.", + "ANN options specifies rerank_k=5000, this exceeds the failure threshold of 4000.", InvalidQueryException.class, "SELECT * FROM %s ORDER BY v ANN OF [1, 1] LIMIT 10 WITH ann_options = {'rerank_k': 5000}"); From b3b812c360bf9415e34e5c7476d9d21eb3a20369 Mon Sep 17 00:00:00 2001 From: Michael Marshall Date: Wed, 12 Feb 2025 10:05:45 -0600 Subject: [PATCH 06/12] Add GuardrailSAIAnnRerankKTest --- .../GuardrailSAIAnnRerankKTest.java | 146 ++++++++++++++++++ 1 file changed, 146 insertions(+) create mode 100644 test/unit/org/apache/cassandra/guardrails/GuardrailSAIAnnRerankKTest.java diff --git a/test/unit/org/apache/cassandra/guardrails/GuardrailSAIAnnRerankKTest.java b/test/unit/org/apache/cassandra/guardrails/GuardrailSAIAnnRerankKTest.java new file mode 100644 index 000000000000..463a35238a51 --- /dev/null +++ b/test/unit/org/apache/cassandra/guardrails/GuardrailSAIAnnRerankKTest.java @@ -0,0 +1,146 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.cassandra.guardrails; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import org.apache.cassandra.config.CassandraRelevantProperties; +import org.apache.cassandra.config.DatabaseDescriptor; + +import static org.junit.Assert.assertEquals; + +public class GuardrailSAIAnnRerankKTest extends GuardrailTester +{ + private static final int WARN_THRESHOLD = 50; + private static final int FAIL_THRESHOLD = 100; + + private int defaultWarnThreshold; + private int defaultFailThreshold; + private int defaultMaxTopK; + + @Before + public void before() + { + defaultWarnThreshold = DatabaseDescriptor.getGuardrailsConfig().sai_ann_rerank_k_warn_threshold; + defaultFailThreshold = DatabaseDescriptor.getGuardrailsConfig().sai_ann_rerank_k_failure_threshold; + defaultMaxTopK = CassandraRelevantProperties.SAI_VECTOR_SEARCH_MAX_TOP_K.getInt(); + + DatabaseDescriptor.getGuardrailsConfig().sai_ann_rerank_k_warn_threshold = WARN_THRESHOLD; + DatabaseDescriptor.getGuardrailsConfig().sai_ann_rerank_k_failure_threshold = FAIL_THRESHOLD; + } + + @After + public void after() + { + DatabaseDescriptor.getGuardrailsConfig().sai_ann_rerank_k_warn_threshold = defaultWarnThreshold; + DatabaseDescriptor.getGuardrailsConfig().sai_ann_rerank_k_failure_threshold = defaultFailThreshold; + CassandraRelevantProperties.SAI_VECTOR_SEARCH_MAX_TOP_K.setInt(defaultMaxTopK); + } + + @Test + public void testConfigValidation() + { + DatabaseDescriptor.getGuardrailsConfig().sai_ann_rerank_k_failure_threshold = -1; + testValidationOfStrictlyPositiveProperty((c, v) -> c.sai_ann_rerank_k_warn_threshold = v.intValue(), + "sai_ann_rerank_k_warn_threshold"); + + DatabaseDescriptor.getGuardrailsConfig().sai_ann_rerank_k_warn_threshold = -1; + testValidationOfStrictlyPositiveProperty((c, v) -> c.sai_ann_rerank_k_failure_threshold = v.intValue(), + "sai_ann_rerank_k_failure_threshold"); + } + + @Test + public void testDefaultValues() + { + // Reset to defaults + 
DatabaseDescriptor.getGuardrailsConfig().sai_ann_rerank_k_warn_threshold = defaultWarnThreshold; + DatabaseDescriptor.getGuardrailsConfig().sai_ann_rerank_k_failure_threshold = defaultFailThreshold; + + // Test that default failure threshold is 4 times the max top K + int maxTopK = CassandraRelevantProperties.SAI_VECTOR_SEARCH_MAX_TOP_K.getInt(); + assertEquals(-1, (int) DatabaseDescriptor.getGuardrailsConfig().sai_ann_rerank_k_warn_threshold); + assertEquals(4 * maxTopK, (int) DatabaseDescriptor.getGuardrailsConfig().sai_ann_rerank_k_failure_threshold); + } + + @Test + public void testSAIAnnRerankKThresholds() throws Throwable + { + createTable("CREATE TABLE %s (k int PRIMARY KEY, v vector)"); + createIndex("CREATE CUSTOM INDEX ON %s(v) USING 'StorageAttachedIndex'"); + + // Test values below and at warning threshold + assertValid("SELECT * FROM %s ORDER BY v ANN OF [1.0, 1.0, 1.0] LIMIT 10 WITH ann_options = {'rerank_k': 10}"); + assertValid("SELECT * FROM %s ORDER BY v ANN OF [1.0, 1.0, 1.0] LIMIT 10 WITH ann_options = {'rerank_k': " + (WARN_THRESHOLD - 1) + '}'); + assertValid("SELECT * FROM %s ORDER BY v ANN OF [1.0, 1.0, 1.0] LIMIT 10 WITH ann_options = {'rerank_k': " + (WARN_THRESHOLD) + '}'); + + // Test values between warning and failure thresholds + assertWarns(String.format("ANN options specifies rerank_k=%d, this exceeds the warning threshold of %d.", + WARN_THRESHOLD + 1, WARN_THRESHOLD), + "SELECT * FROM %s ORDER BY v ANN OF [1.0, 1.0, 1.0] LIMIT 10 WITH ann_options = {'rerank_k': " + (WARN_THRESHOLD + 1) + '}'); + + // Test values at failure threshold (should still warn) + assertWarns(String.format("ANN options specifies rerank_k=%d, this exceeds the warning threshold of %d.", + FAIL_THRESHOLD, WARN_THRESHOLD), + "SELECT * FROM %s ORDER BY v ANN OF [1.0, 1.0, 1.0] LIMIT 10 WITH ann_options = {'rerank_k': " + FAIL_THRESHOLD + '}'); + + // Test values above failure threshold + assertFails(String.format("ANN options specifies rerank_k=%d, this exceeds 
the failure threshold of %d.", + FAIL_THRESHOLD + 1, FAIL_THRESHOLD), + "SELECT * FROM %s ORDER BY v ANN OF [1.0, 1.0, 1.0] LIMIT 10 WITH ann_options = {'rerank_k': " + (FAIL_THRESHOLD + 1) + '}'); + } + + @Test + public void testDisabledThresholds() throws Throwable + { + createTable("CREATE TABLE %s (k int PRIMARY KEY, v vector)"); + createIndex("CREATE CUSTOM INDEX ON %s(v) USING 'StorageAttachedIndex'"); + + // Test with warning threshold disabled + DatabaseDescriptor.getGuardrailsConfig().sai_ann_rerank_k_warn_threshold = -1; + assertValid("SELECT * FROM %s ORDER BY v ANN OF [1.0, 1.0, 1.0] LIMIT 10 WITH ann_options = {'rerank_k': " + (WARN_THRESHOLD + 1) + '}'); + + // Test with failure threshold disabled + DatabaseDescriptor.getGuardrailsConfig().sai_ann_rerank_k_failure_threshold = -1; + assertValid("SELECT * FROM %s ORDER BY v ANN OF [1.0, 1.0, 1.0] LIMIT 10 WITH ann_options = {'rerank_k': " + (FAIL_THRESHOLD + 1) + '}'); + } + + @Test + public void testNegativeRerankK() throws Throwable + { + createTable("CREATE TABLE %s (k int PRIMARY KEY, v vector)"); + createIndex("CREATE CUSTOM INDEX ON %s(v) USING 'StorageAttachedIndex'"); + + // Negative rerank_k values should be valid and not trigger warnings + assertValid("SELECT * FROM %s ORDER BY v ANN OF [1.0, 1.0, 1.0] LIMIT 10 WITH ann_options = {'rerank_k': -1}"); + assertValid("SELECT * FROM %s ORDER BY v ANN OF [1.0, 1.0, 1.0] LIMIT 10 WITH ann_options = {'rerank_k': -1000}"); + } + + @Test + public void testMissingRerankK() throws Throwable + { + createTable("CREATE TABLE %s (k int PRIMARY KEY, v vector)"); + createIndex("CREATE CUSTOM INDEX ON %s(v) USING 'StorageAttachedIndex'"); + + // Queries without rerank_k should be valid + assertValid("SELECT * FROM %s ORDER BY v ANN OF [1.0, 1.0, 1.0] LIMIT 10"); + assertValid("SELECT * FROM %s ORDER BY v ANN OF [1.0, 1.0, 1.0] LIMIT 10 WITH ann_options = {}"); + } +} \ No newline at end of file From 33c7b8a481d242019d38f607cef5e8fa76acf3e7 Mon Sep 17 00:00:00 
2001 From: Michael Marshall Date: Wed, 12 Feb 2025 11:14:54 -0600 Subject: [PATCH 07/12] Replace failure with fail to match 5.0 convention --- .../cassandra/guardrails/Guardrails.java | 2 +- .../cassandra/guardrails/GuardrailsConfig.java | 8 ++++---- .../guardrails/GuardrailSAIAnnRerankKTest.java | 18 +++++++++--------- 3 files changed, 14 insertions(+), 14 deletions(-) diff --git a/src/java/org/apache/cassandra/guardrails/Guardrails.java b/src/java/org/apache/cassandra/guardrails/Guardrails.java index b00b59ab6664..845a00c2fd86 100644 --- a/src/java/org/apache/cassandra/guardrails/Guardrails.java +++ b/src/java/org/apache/cassandra/guardrails/Guardrails.java @@ -126,7 +126,7 @@ what, formatSize(v), formatSize(t))) public static final Threshold annRerankKMaxValue = factory.threshold("sai_ann_rerank_k_max_value", () -> config.sai_ann_rerank_k_warn_threshold, - () -> config.sai_ann_rerank_k_failure_threshold, + () -> config.sai_ann_rerank_k_fail_threshold, (isWarning, what, value, threshold) -> format("%s specifies rerank_k=%s, this exceeds the %s threshold of %s.", what, value, isWarning ? 
"warning" : "failure", threshold)); diff --git a/src/java/org/apache/cassandra/guardrails/GuardrailsConfig.java b/src/java/org/apache/cassandra/guardrails/GuardrailsConfig.java index 9b4ee1160c7a..2338de2ca3a3 100644 --- a/src/java/org/apache/cassandra/guardrails/GuardrailsConfig.java +++ b/src/java/org/apache/cassandra/guardrails/GuardrailsConfig.java @@ -83,7 +83,7 @@ public class GuardrailsConfig public volatile Integer vector_dimensions_warn_threshold; public volatile Integer vector_dimensions_failure_threshold; public volatile Integer sai_ann_rerank_k_warn_threshold; - public volatile Integer sai_ann_rerank_k_failure_threshold; + public volatile Integer sai_ann_rerank_k_fail_threshold; // Legacy 2i guardrail public volatile Integer secondary_index_per_table_failure_threshold; @@ -171,7 +171,7 @@ public void applyConfig() // Default to no warning and failure at 4 times the maxTopK value int maxTopK = CassandraRelevantProperties.SAI_VECTOR_SEARCH_MAX_TOP_K.getInt(); enforceDefault(sai_ann_rerank_k_warn_threshold, v -> sai_ann_rerank_k_warn_threshold = v, -1, -1); - enforceDefault(sai_ann_rerank_k_failure_threshold, v -> sai_ann_rerank_k_failure_threshold = v, 4 * maxTopK, 4 * maxTopK); + enforceDefault(sai_ann_rerank_k_fail_threshold, v -> sai_ann_rerank_k_fail_threshold = v, 4 * maxTopK, 4 * maxTopK); // for write requests enforceDefault(logged_batch_enabled, v -> logged_batch_enabled = v, true, true); @@ -278,8 +278,8 @@ public void validate() validateWarnLowerThanFail(vector_dimensions_warn_threshold, vector_dimensions_failure_threshold, "vector_dimensions"); validateStrictlyPositiveInteger(sai_ann_rerank_k_warn_threshold, "sai_ann_rerank_k_warn_threshold"); - validateStrictlyPositiveInteger(sai_ann_rerank_k_failure_threshold, "sai_ann_rerank_k_failure_threshold"); - validateWarnLowerThanFail(sai_ann_rerank_k_warn_threshold, sai_ann_rerank_k_failure_threshold, "sai_ann_rerank_k"); + validateStrictlyPositiveInteger(sai_ann_rerank_k_fail_threshold, 
"sai_ann_rerank_k_fail_threshold"); + validateWarnLowerThanFail(sai_ann_rerank_k_warn_threshold, sai_ann_rerank_k_fail_threshold, "sai_ann_rerank_k"); validateStrictlyPositiveInteger(tables_warn_threshold, "tables_warn_threshold"); validateStrictlyPositiveInteger(tables_failure_threshold, "tables_failure_threshold"); diff --git a/test/unit/org/apache/cassandra/guardrails/GuardrailSAIAnnRerankKTest.java b/test/unit/org/apache/cassandra/guardrails/GuardrailSAIAnnRerankKTest.java index 463a35238a51..67072c8bbc88 100644 --- a/test/unit/org/apache/cassandra/guardrails/GuardrailSAIAnnRerankKTest.java +++ b/test/unit/org/apache/cassandra/guardrails/GuardrailSAIAnnRerankKTest.java @@ -40,31 +40,31 @@ public class GuardrailSAIAnnRerankKTest extends GuardrailTester public void before() { defaultWarnThreshold = DatabaseDescriptor.getGuardrailsConfig().sai_ann_rerank_k_warn_threshold; - defaultFailThreshold = DatabaseDescriptor.getGuardrailsConfig().sai_ann_rerank_k_failure_threshold; + defaultFailThreshold = DatabaseDescriptor.getGuardrailsConfig().sai_ann_rerank_k_fail_threshold; defaultMaxTopK = CassandraRelevantProperties.SAI_VECTOR_SEARCH_MAX_TOP_K.getInt(); DatabaseDescriptor.getGuardrailsConfig().sai_ann_rerank_k_warn_threshold = WARN_THRESHOLD; - DatabaseDescriptor.getGuardrailsConfig().sai_ann_rerank_k_failure_threshold = FAIL_THRESHOLD; + DatabaseDescriptor.getGuardrailsConfig().sai_ann_rerank_k_fail_threshold = FAIL_THRESHOLD; } @After public void after() { DatabaseDescriptor.getGuardrailsConfig().sai_ann_rerank_k_warn_threshold = defaultWarnThreshold; - DatabaseDescriptor.getGuardrailsConfig().sai_ann_rerank_k_failure_threshold = defaultFailThreshold; + DatabaseDescriptor.getGuardrailsConfig().sai_ann_rerank_k_fail_threshold = defaultFailThreshold; CassandraRelevantProperties.SAI_VECTOR_SEARCH_MAX_TOP_K.setInt(defaultMaxTopK); } @Test public void testConfigValidation() { - DatabaseDescriptor.getGuardrailsConfig().sai_ann_rerank_k_failure_threshold = -1; + 
DatabaseDescriptor.getGuardrailsConfig().sai_ann_rerank_k_fail_threshold = -1; testValidationOfStrictlyPositiveProperty((c, v) -> c.sai_ann_rerank_k_warn_threshold = v.intValue(), "sai_ann_rerank_k_warn_threshold"); DatabaseDescriptor.getGuardrailsConfig().sai_ann_rerank_k_warn_threshold = -1; - testValidationOfStrictlyPositiveProperty((c, v) -> c.sai_ann_rerank_k_failure_threshold = v.intValue(), - "sai_ann_rerank_k_failure_threshold"); + testValidationOfStrictlyPositiveProperty((c, v) -> c.sai_ann_rerank_k_fail_threshold = v.intValue(), + "sai_ann_rerank_k_fail_threshold"); } @Test @@ -72,12 +72,12 @@ public void testDefaultValues() { // Reset to defaults DatabaseDescriptor.getGuardrailsConfig().sai_ann_rerank_k_warn_threshold = defaultWarnThreshold; - DatabaseDescriptor.getGuardrailsConfig().sai_ann_rerank_k_failure_threshold = defaultFailThreshold; + DatabaseDescriptor.getGuardrailsConfig().sai_ann_rerank_k_fail_threshold = defaultFailThreshold; // Test that default failure threshold is 4 times the max top K int maxTopK = CassandraRelevantProperties.SAI_VECTOR_SEARCH_MAX_TOP_K.getInt(); assertEquals(-1, (int) DatabaseDescriptor.getGuardrailsConfig().sai_ann_rerank_k_warn_threshold); - assertEquals(4 * maxTopK, (int) DatabaseDescriptor.getGuardrailsConfig().sai_ann_rerank_k_failure_threshold); + assertEquals(4 * maxTopK, (int) DatabaseDescriptor.getGuardrailsConfig().sai_ann_rerank_k_fail_threshold); } @Test @@ -118,7 +118,7 @@ public void testDisabledThresholds() throws Throwable assertValid("SELECT * FROM %s ORDER BY v ANN OF [1.0, 1.0, 1.0] LIMIT 10 WITH ann_options = {'rerank_k': " + (WARN_THRESHOLD + 1) + '}'); // Test with failure threshold disabled - DatabaseDescriptor.getGuardrailsConfig().sai_ann_rerank_k_failure_threshold = -1; + DatabaseDescriptor.getGuardrailsConfig().sai_ann_rerank_k_fail_threshold = -1; assertValid("SELECT * FROM %s ORDER BY v ANN OF [1.0, 1.0, 1.0] LIMIT 10 WITH ann_options = {'rerank_k': " + (FAIL_THRESHOLD + 1) + '}'); } From 
6a3ec582efd1d962d4eade1b7f660898ebdf4da8 Mon Sep 17 00:00:00 2001 From: Michael Marshall Date: Wed, 19 Feb 2025 14:59:38 -0600 Subject: [PATCH 08/12] Remove support for non-positive rerankk This can be reverted later when we support non-positive numbers, which will mean that we can do a rerankless search. --- .../cassandra/db/filter/ANNOptions.java | 2 +- .../cassandra/db/filter/ANNOptionsTest.java | 25 ++++++++++++------- .../GuardrailSAIAnnRerankKTest.java | 3 ++- 3 files changed, 19 insertions(+), 11 deletions(-) diff --git a/src/java/org/apache/cassandra/db/filter/ANNOptions.java b/src/java/org/apache/cassandra/db/filter/ANNOptions.java index b9420610b8ab..ece3dcef6d53 100644 --- a/src/java/org/apache/cassandra/db/filter/ANNOptions.java +++ b/src/java/org/apache/cassandra/db/filter/ANNOptions.java @@ -64,7 +64,7 @@ public static ANNOptions create(@Nullable Integer rerankK) public void validate(QueryState state, int limit) { - if (rerankK == null || rerankK <= 0) + if (rerankK == null) return; if (rerankK < limit) diff --git a/test/unit/org/apache/cassandra/db/filter/ANNOptionsTest.java b/test/unit/org/apache/cassandra/db/filter/ANNOptionsTest.java index 1b3c12f316ff..8b414f06bca0 100644 --- a/test/unit/org/apache/cassandra/db/filter/ANNOptionsTest.java +++ b/test/unit/org/apache/cassandra/db/filter/ANNOptionsTest.java @@ -91,12 +91,18 @@ public void testParseAndValidate() execute("SELECT * FROM %s WHERE k=0 ORDER BY v ANN OF [1, 1] WITH ann_options = {}"); // correct queries with specific ANN options - execute("SELECT * FROM %s ORDER BY v ANN OF [1, 1] LIMIT 10 WITH ann_options = {'rerank_k': 0}"); + execute("SELECT * FROM %s ORDER BY v ANN OF [1, 1] LIMIT 10 WITH ann_options = {'rerank_k': 10}"); + execute("SELECT * FROM %s ORDER BY v ANN OF [1, 1] LIMIT 10 WITH ann_options = {'rerank_k': 11}"); execute("SELECT * FROM %s ORDER BY v ANN OF [1, 1] LIMIT 10 WITH ann_options = {'rerank_k': 1000}"); - execute("SELECT * FROM %s ORDER BY v ANN OF [1, 1] LIMIT 10 WITH 
ann_options = {'rerank_k': '0'}"); execute("SELECT * FROM %s ORDER BY v ANN OF [1, 1] LIMIT 10 WITH ann_options = {'rerank_k': '1000'}"); - execute("SELECT * FROM %s ORDER BY v ANN OF [1, 1] LIMIT 10 WITH ann_options = {'rerank_k': '-1'}"); - execute("SELECT * FROM %s ORDER BY v ANN OF [1, 1] LIMIT 10 WITH ann_options = {'rerank_k': '-1000'}"); + + // Queries with invalid ann options that will eventually be valid when we support disabling reranking + assertInvalidThrowMessage("Invalid rerank_k value -1 lesser than limit 100", + InvalidRequestException.class, + "SELECT * FROM %s ORDER BY v ANN OF [1, 1] LIMIT 100 WITH ann_options = {'rerank_k': -1}"); + assertInvalidThrowMessage("Invalid rerank_k value 0 lesser than limit 100", + InvalidRequestException.class, + "SELECT * FROM %s ORDER BY v ANN OF [1, 1] LIMIT 100 WITH ann_options = {'rerank_k': 0}"); // Queries that exceed the failure threshold for the guardrail. Specifies a protocol version to trigger // validation in the coordinator. @@ -192,12 +198,13 @@ public void testTransport() testTransport("SELECT * FROM %s ORDER BY v ANN OF [1, 1]", ANNOptions.NONE); testTransport("SELECT * FROM %s ORDER BY v ANN OF [1, 1] WITH ann_options = {}", ANNOptions.NONE); + // TODO re-enable this test when we support negative rerank_k values // some random negative values, all should be accepted and not be mapped to NONE - String negativeQuery = "SELECT * FROM %%s ORDER BY v ANN OF [1, 1] LIMIT 10 WITH ann_options = {'rerank_k': %d}"; - QuickTheory.qt() - .withExamples(100) - .forAll(integers().allPositive()) - .checkAssert(i -> testTransport(String.format(negativeQuery, -i), ANNOptions.create(-i))); +// String negativeQuery = "SELECT * FROM %%s ORDER BY v ANN OF [1, 1] LIMIT 10 WITH ann_options = {'rerank_k': %d}"; +// QuickTheory.qt() +// .withExamples(100) +// .forAll(integers().allPositive()) +// .checkAssert(i -> testTransport(String.format(negativeQuery, -i), ANNOptions.create(-i))); // some random positive values, all 
should be accepted String positiveQuery = "SELECT * FROM %%s ORDER BY v ANN OF [1, 1] LIMIT %d WITH ann_options = {'rerank_k': %)"); From 76f8e6a5ed760579d8e337bc6153ec8414d8b963 Mon Sep 17 00:00:00 2001 From: Michael Marshall Date: Wed, 19 Feb 2025 16:16:05 -0600 Subject: [PATCH 09/12] Only consider nodes in keyspace when ensuring messaging version for ann_options --- .../cql3/restrictions/StatementRestrictions.java | 2 +- .../cassandra/cql3/statements/SelectOptions.java | 8 ++++---- .../cassandra/cql3/statements/SelectStatement.java | 2 +- .../org/apache/cassandra/db/filter/ANNOptions.java | 4 ++-- .../org/apache/cassandra/net/MessagingService.java | 10 ++++++---- 5 files changed, 14 insertions(+), 12 deletions(-) diff --git a/src/java/org/apache/cassandra/cql3/restrictions/StatementRestrictions.java b/src/java/org/apache/cassandra/cql3/restrictions/StatementRestrictions.java index b20d658867e8..f77f5140c6ba 100644 --- a/src/java/org/apache/cassandra/cql3/restrictions/StatementRestrictions.java +++ b/src/java/org/apache/cassandra/cql3/restrictions/StatementRestrictions.java @@ -1000,7 +1000,7 @@ public RowFilter getRowFilter(IndexRegistry indexManager, QueryOptions options, return RowFilter.NONE; } - ANNOptions annOptions = selectOptions.parseANNOptions(); + ANNOptions annOptions = selectOptions.parseANNOptions(options.getKeyspace()); RowFilter rowFilter = RowFilter.builder(indexManager) .buildFromRestrictions(this, table, options, queryState, annOptions); diff --git a/src/java/org/apache/cassandra/cql3/statements/SelectOptions.java b/src/java/org/apache/cassandra/cql3/statements/SelectOptions.java index 4eada73aa563..58ba92f377e7 100644 --- a/src/java/org/apache/cassandra/cql3/statements/SelectOptions.java +++ b/src/java/org/apache/cassandra/cql3/statements/SelectOptions.java @@ -41,23 +41,23 @@ public class SelectOptions extends PropertyDefinitions * @param limit the {@code SELECT} query user-provided limit * @throws InvalidRequestException if any of the 
options are invalid */ - public void validate(QueryState state, int limit) throws RequestValidationException + public void validate(QueryState state, String keyspace, int limit) throws RequestValidationException { validate(keywords, Collections.emptySet()); - parseANNOptions().validate(state, limit); + parseANNOptions(keyspace).validate(state, limit); } /** * @return the ANN options within these options, or {@link ANNOptions#NONE} if no options are present * @throws InvalidRequestException if the ANN options are invalid */ - public ANNOptions parseANNOptions() throws RequestValidationException + public ANNOptions parseANNOptions(String keyspace) throws RequestValidationException { Map options = getMap(ANN_OPTIONS); return options == null ? ANNOptions.NONE - : ANNOptions.fromMap(options); + : ANNOptions.fromMap(keyspace, options); } /** diff --git a/src/java/org/apache/cassandra/cql3/statements/SelectStatement.java b/src/java/org/apache/cassandra/cql3/statements/SelectStatement.java index c90f23feaddf..403a9ff68955 100644 --- a/src/java/org/apache/cassandra/cql3/statements/SelectStatement.java +++ b/src/java/org/apache/cassandra/cql3/statements/SelectStatement.java @@ -426,7 +426,7 @@ public ReadQuery getQuery(QueryState queryState, checkFalse(userOffset != NO_OFFSET, String.format(TOPK_OFFSET_ERROR, userOffset)); } - selectOptions.validate(queryState, userLimit); + selectOptions.validate(queryState, options.getKeyspace(), userLimit); return query; } diff --git a/src/java/org/apache/cassandra/db/filter/ANNOptions.java b/src/java/org/apache/cassandra/db/filter/ANNOptions.java index ece3dcef6d53..c9032fc078c0 100644 --- a/src/java/org/apache/cassandra/db/filter/ANNOptions.java +++ b/src/java/org/apache/cassandra/db/filter/ANNOptions.java @@ -79,10 +79,10 @@ public void validate(QueryState state, int limit) * @param map the map of options in the {@code WITH ANN_OPTION} of a {@code SELECT} query * @return the ANN options in the specified {@code SELECT} options, or 
{@link #NONE} if no options are present */ - public static ANNOptions fromMap(Map map) + public static ANNOptions fromMap(String keyspace, Map map) { // ensure that all nodes in the cluster are in a version that supports ANN options, including this one - Set badNodes = MessagingService.instance().endpointsWithVersionBelow(MessagingService.VERSION_DS_11); + Set badNodes = MessagingService.instance().endpointsWithVersionBelow(keyspace, MessagingService.VERSION_DS_11); if (MessagingService.current_version < MessagingService.VERSION_DS_11) badNodes.add(FBUtilities.getBroadcastAddressAndPort()); if (!badNodes.isEmpty()) diff --git a/src/java/org/apache/cassandra/net/MessagingService.java b/src/java/org/apache/cassandra/net/MessagingService.java index c3cd867884d2..1865ef63d4ae 100644 --- a/src/java/org/apache/cassandra/net/MessagingService.java +++ b/src/java/org/apache/cassandra/net/MessagingService.java @@ -645,15 +645,17 @@ public void waitUntilListening() throws InterruptedException } /** - * Returns the endpoints that are known to be alive and are using a messaging version older than the given version. + * Returns the endpoints for the given keyspace that are known to be alive and are using a messaging version older + * than the given version. 
* + * @param keyspace a keyspace * @param version a messaging version - * @return a set of alive endpoints with messaging version below the given version + * @return a set of alive endpoints in the given keyspace with messaging version below the given version */ - public Set endpointsWithVersionBelow(int version) + public Set endpointsWithVersionBelow(String keyspace, int version) { Set nodes = new HashSet<>(); - for (InetAddressAndPort node : StorageService.instance.getTokenMetadata().getAllEndpoints()) + for (InetAddressAndPort node : StorageService.instance.getTokenMetadataForKeyspace(keyspace).getAllEndpoints()) { if (versions.knows(node) && versions.getRaw(node) < version) nodes.add(node); From 57bf94df6032e8f446de482d186a6c6171dd4727 Mon Sep 17 00:00:00 2001 From: Michael Marshall Date: Fri, 21 Feb 2025 10:47:37 -0600 Subject: [PATCH 10/12] Ensure keyspace param is not null --- .../cassandra/cql3/restrictions/StatementRestrictions.java | 2 +- .../org/apache/cassandra/cql3/statements/SelectStatement.java | 2 +- src/java/org/apache/cassandra/db/filter/ANNOptions.java | 1 + 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/java/org/apache/cassandra/cql3/restrictions/StatementRestrictions.java b/src/java/org/apache/cassandra/cql3/restrictions/StatementRestrictions.java index f77f5140c6ba..e1f29a20d0b9 100644 --- a/src/java/org/apache/cassandra/cql3/restrictions/StatementRestrictions.java +++ b/src/java/org/apache/cassandra/cql3/restrictions/StatementRestrictions.java @@ -1000,7 +1000,7 @@ public RowFilter getRowFilter(IndexRegistry indexManager, QueryOptions options, return RowFilter.NONE; } - ANNOptions annOptions = selectOptions.parseANNOptions(options.getKeyspace()); + ANNOptions annOptions = selectOptions.parseANNOptions(table.keyspace); RowFilter rowFilter = RowFilter.builder(indexManager) .buildFromRestrictions(this, table, options, queryState, annOptions); diff --git a/src/java/org/apache/cassandra/cql3/statements/SelectStatement.java 
b/src/java/org/apache/cassandra/cql3/statements/SelectStatement.java index 403a9ff68955..7b63d6402bba 100644 --- a/src/java/org/apache/cassandra/cql3/statements/SelectStatement.java +++ b/src/java/org/apache/cassandra/cql3/statements/SelectStatement.java @@ -426,7 +426,7 @@ public ReadQuery getQuery(QueryState queryState, checkFalse(userOffset != NO_OFFSET, String.format(TOPK_OFFSET_ERROR, userOffset)); } - selectOptions.validate(queryState, options.getKeyspace(), userLimit); + selectOptions.validate(queryState, table.keyspace, userLimit); return query; } diff --git a/src/java/org/apache/cassandra/db/filter/ANNOptions.java b/src/java/org/apache/cassandra/db/filter/ANNOptions.java index c9032fc078c0..0ffaf951b958 100644 --- a/src/java/org/apache/cassandra/db/filter/ANNOptions.java +++ b/src/java/org/apache/cassandra/db/filter/ANNOptions.java @@ -81,6 +81,7 @@ public void validate(QueryState state, int limit) */ public static ANNOptions fromMap(String keyspace, Map map) { + assert keyspace != null; // ensure that all nodes in the cluster are in a version that supports ANN options, including this one Set badNodes = MessagingService.instance().endpointsWithVersionBelow(keyspace, MessagingService.VERSION_DS_11); if (MessagingService.current_version < MessagingService.VERSION_DS_11) From 55c0d3eee104e92d9fd7ac691fb4cad1addf985b Mon Sep 17 00:00:00 2001 From: Michael Marshall Date: Fri, 21 Feb 2025 10:48:07 -0600 Subject: [PATCH 11/12] Fix license --- .../guardrails/GuardrailSAIAnnRerankKTest.java | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/test/unit/org/apache/cassandra/guardrails/GuardrailSAIAnnRerankKTest.java b/test/unit/org/apache/cassandra/guardrails/GuardrailSAIAnnRerankKTest.java index 3dd6e69de227..8bcaf20a4041 100644 --- a/test/unit/org/apache/cassandra/guardrails/GuardrailSAIAnnRerankKTest.java +++ b/test/unit/org/apache/cassandra/guardrails/GuardrailSAIAnnRerankKTest.java @@ -1,13 +1,11 @@ /* - * Licensed to the Apache 
Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at + * Copyright DataStax, Inc. * - * http://www.apache.org/licenses/LICENSE-2.0 + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, From ab79d27c1fba168cd613eb8ab27c245fb7840093 Mon Sep 17 00:00:00 2001 From: Michael Marshall Date: Mon, 24 Feb 2025 11:33:05 -0600 Subject: [PATCH 12/12] Address review feedback --- .../cassandra/guardrails/GuardrailSAIAnnRerankKTest.java | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/unit/org/apache/cassandra/guardrails/GuardrailSAIAnnRerankKTest.java b/test/unit/org/apache/cassandra/guardrails/GuardrailSAIAnnRerankKTest.java index 8bcaf20a4041..995345b3d2fd 100644 --- a/test/unit/org/apache/cassandra/guardrails/GuardrailSAIAnnRerankKTest.java +++ b/test/unit/org/apache/cassandra/guardrails/GuardrailSAIAnnRerankKTest.java @@ -88,7 +88,7 @@ public void testSAIAnnRerankKThresholds() throws Throwable // Test values below and at warning threshold assertValid("SELECT * FROM %s ORDER BY v ANN OF [1.0, 1.0, 1.0] LIMIT 10 WITH ann_options = {'rerank_k': 10}"); assertValid("SELECT * FROM %s ORDER BY v ANN OF [1.0, 1.0, 1.0] LIMIT 10 WITH ann_options = {'rerank_k': " + (WARN_THRESHOLD - 1) + '}'); - assertValid("SELECT * FROM %s ORDER BY v ANN OF [1.0, 1.0, 1.0] LIMIT 10 WITH ann_options = {'rerank_k': " + 
(WARN_THRESHOLD) + '}'); + assertValid("SELECT * FROM %s ORDER BY v ANN OF [1.0, 1.0, 1.0] LIMIT 10 WITH ann_options = {'rerank_k': " + WARN_THRESHOLD + '}'); // Test values between warning and failure thresholds assertWarns(String.format("ANN options specifies rerank_k=%d, this exceeds the warning threshold of %d.", @@ -121,7 +121,7 @@ public void testDisabledThresholds() throws Throwable assertValid("SELECT * FROM %s ORDER BY v ANN OF [1.0, 1.0, 1.0] LIMIT 10 WITH ann_options = {'rerank_k': " + (FAIL_THRESHOLD + 1) + '}'); } - @Ignore + @Ignore // TODO: re-enable this test when we support negative rerank_k values public void testNegativeRerankK() throws Throwable { createTable("CREATE TABLE %s (k int PRIMARY KEY, v vector)"); @@ -138,7 +138,7 @@ public void testMissingRerankK() throws Throwable createTable("CREATE TABLE %s (k int PRIMARY KEY, v vector)"); createIndex("CREATE CUSTOM INDEX ON %s(v) USING 'StorageAttachedIndex'"); - // Queries without rerank_k should be valid + // Queries without rerank_k should be valid and not trigger warnings. assertValid("SELECT * FROM %s ORDER BY v ANN OF [1.0, 1.0, 1.0] LIMIT 10"); assertValid("SELECT * FROM %s ORDER BY v ANN OF [1.0, 1.0, 1.0] LIMIT 10 WITH ann_options = {}"); }