Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CNDB-12922: Implement rerank_k in SAI ANN queries #1562

Open
wants to merge 14 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -581,7 +581,7 @@ public enum CassandraRelevantProperties
* The current messaging version. This is used when we add new messaging versions without adopting them immediately,
* or to force the node to use a specific version for testing purposes.
*/
DS_CURRENT_MESSAGING_VERSION("ds.current_messaging_version", Integer.toString(MessagingService.VERSION_DS_10));
DS_CURRENT_MESSAGING_VERSION("ds.current_messaging_version", Integer.toString(MessagingService.VERSION_DS_11));

CassandraRelevantProperties(String key, String defaultVal)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.apache.cassandra.db.filter.ANNOptions;
import org.apache.cassandra.exceptions.InvalidRequestException;
import org.apache.cassandra.exceptions.RequestValidationException;
import org.apache.cassandra.service.QueryState;

/**
* {@code WITH option1=... AND option2=...} options for SELECT statements.
Expand All @@ -36,13 +37,14 @@ public class SelectOptions extends PropertyDefinitions
/**
* Validates all the {@code SELECT} options.
*
* @param state the query state
* @param limit the {@code SELECT} query user-provided limit
* @throws InvalidRequestException if any of the options are invalid
*/
public void validate(int limit) throws RequestValidationException
public void validate(QueryState state, int limit) throws RequestValidationException
{
validate(keywords, Collections.emptySet());
parseANNOptions().validate(limit);
parseANNOptions().validate(state, limit);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -426,7 +426,7 @@ public ReadQuery getQuery(QueryState queryState,
checkFalse(userOffset != NO_OFFSET, String.format(TOPK_OFFSET_ERROR, userOffset));
}

selectOptions.validate(userLimit);
selectOptions.validate(queryState, userLimit);

return query;
}
Expand Down
11 changes: 9 additions & 2 deletions src/java/org/apache/cassandra/db/filter/ANNOptions.java
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,12 @@

import org.apache.cassandra.db.TypeSizes;
import org.apache.cassandra.exceptions.InvalidRequestException;
import org.apache.cassandra.guardrails.Guardrails;
import org.apache.cassandra.io.util.DataInputPlus;
import org.apache.cassandra.io.util.DataOutputPlus;
import org.apache.cassandra.locator.InetAddressAndPort;
import org.apache.cassandra.net.MessagingService;
import org.apache.cassandra.service.QueryState;
import org.apache.cassandra.utils.FBUtilities;

/**
Expand Down Expand Up @@ -60,10 +62,15 @@ public static ANNOptions create(@Nullable Integer rerankK)
return rerankK == null ? NONE : new ANNOptions(rerankK);
}

public void validate(int limit)
public void validate(QueryState state, int limit)
{
if (rerankK != null && rerankK > 0 && rerankK < limit)
if (rerankK == null || rerankK <= 0)
return;

if (rerankK < limit)
throw new InvalidRequestException(String.format("Invalid rerank_k value %d lesser than limit %d", rerankK, limit));

Guardrails.annRerankKMaxValue.guard(rerankK, "ANN Option rerank_k", false, state);
}

/**
Expand Down
11 changes: 11 additions & 0 deletions src/java/org/apache/cassandra/guardrails/Guardrails.java
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,17 @@ what, formatSize(v), formatSize(t)))
format("%s has a vector of %s dimensions, this exceeds the %s threshold of %s.",
what, value, isWarning ? "warning" : "failure", threshold));

/**
* Guardrail on the maximum value for the rerank_k parameter, an ANN query option.
*/
public static final Threshold annRerankKMaxValue =
factory.threshold("sai_ann_rerank_k_max_value",
() -> config.sai_ann_rerank_k_warn_threshold,
() -> config.sai_ann_rerank_k_failure_threshold,
(isWarning, what, value, threshold) ->
format("%s specifies rerank_k=%s, this exceeds the %s threshold of %s.",
what, value, isWarning ? "warning" : "failure", threshold));

public static final DisableFlag readBeforeWriteListOperationsEnabled =
factory.disableFlag("read_before_write_list_operations",
() -> !config.read_before_write_list_operations_enabled,
Expand Down
12 changes: 12 additions & 0 deletions src/java/org/apache/cassandra/guardrails/GuardrailsConfig.java
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Sets;

import org.apache.cassandra.config.CassandraRelevantProperties;
import org.apache.cassandra.config.Config;
import org.apache.cassandra.config.DatabaseDescriptor;
import org.apache.cassandra.cql3.statements.schema.TableAttributes;
Expand Down Expand Up @@ -81,6 +82,8 @@ public class GuardrailsConfig
public volatile Boolean read_before_write_list_operations_enabled;
public volatile Integer vector_dimensions_warn_threshold;
public volatile Integer vector_dimensions_failure_threshold;
public volatile Integer sai_ann_rerank_k_warn_threshold;
public volatile Integer sai_ann_rerank_k_failure_threshold;

// Legacy 2i guardrail
public volatile Integer secondary_index_per_table_failure_threshold;
Expand Down Expand Up @@ -165,6 +168,11 @@ public void applyConfig()
enforceDefault(tombstone_warn_threshold, v -> tombstone_warn_threshold = v, 1000, 1000);
enforceDefault(tombstone_failure_threshold, v -> tombstone_failure_threshold = v, 100000, 100000);

// Defaults: no warning threshold (-1), and a failure threshold of 4 times the maxTopK value
int maxTopK = CassandraRelevantProperties.SAI_VECTOR_SEARCH_MAX_TOP_K.getInt();
enforceDefault(sai_ann_rerank_k_warn_threshold, v -> sai_ann_rerank_k_warn_threshold = v, -1, -1);
enforceDefault(sai_ann_rerank_k_failure_threshold, v -> sai_ann_rerank_k_failure_threshold = v, 4 * maxTopK, 4 * maxTopK);
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Another approach could be to cap the ratio of LIMIT to rerank_k.


// for write requests
enforceDefault(logged_batch_enabled, v -> logged_batch_enabled = v, true, true);
enforceDefault(batch_size_warn_threshold_in_kb, v -> batch_size_warn_threshold_in_kb = v, 64, 64);
Expand Down Expand Up @@ -269,6 +277,10 @@ public void validate()
validateStrictlyPositiveInteger(vector_dimensions_failure_threshold, "vector_dimensions_failure_threshold");
validateWarnLowerThanFail(vector_dimensions_warn_threshold, vector_dimensions_failure_threshold, "vector_dimensions");

validateStrictlyPositiveInteger(sai_ann_rerank_k_warn_threshold, "sai_ann_rerank_k_warn_threshold");
validateStrictlyPositiveInteger(sai_ann_rerank_k_failure_threshold, "sai_ann_rerank_k_failure_threshold");
validateWarnLowerThanFail(sai_ann_rerank_k_warn_threshold, sai_ann_rerank_k_failure_threshold, "sai_ann_rerank_k");

validateStrictlyPositiveInteger(tables_warn_threshold, "tables_warn_threshold");
validateStrictlyPositiveInteger(tables_failure_threshold, "tables_failure_threshold");
validateWarnLowerThanFail(tables_warn_threshold, tables_failure_threshold, "tables");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -711,10 +711,6 @@ public void validate(ReadCommand command) throws InvalidRequestException
throw new InvalidRequestException(String.format("SAI based ORDER BY clause requires a LIMIT that is not greater than %s. LIMIT was %s",
MAX_TOP_K, command.limits().isUnlimited() ? "NO LIMIT" : command.limits().count()));

ANNOptions annOptions = command.rowFilter().annOptions();
if (annOptions != ANNOptions.NONE)
throw new InvalidRequestException("SAI doesn't support ANN options yet.");

indexContext.validate(command.rowFilter());
}

Expand Down
7 changes: 4 additions & 3 deletions src/java/org/apache/cassandra/index/sai/disk/v1/Segment.java
Original file line number Diff line number Diff line change
Expand Up @@ -205,10 +205,11 @@ public String toString()
* the number of candidates, the more nodes we expect to visit just to find
* results that are in that set.)
*/
public double estimateAnnSearchCost(int limit, int candidates)
public double estimateAnnSearchCost(Orderer orderer, int limit, int candidates)
{
IndexSearcher searcher = getIndexSearcher();
return ((V2VectorIndexSearcher) searcher).estimateAnnSearchCost(limit, candidates);
V2VectorIndexSearcher searcher = (V2VectorIndexSearcher) getIndexSearcher();
int rerankK = orderer.rerankKFor(limit, searcher.getCompression());
return searcher.estimateAnnSearchCost(rerankK, candidates);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,8 @@ public List<CloseableIterator<PrimaryKeyWithSortKey>> orderBy(Orderer orderer, E
{
if (segment.intersects(keyRange))
{
// Note that the proportionality is not used when the user supplies a rerank_k value in the
// ANN_OPTIONS map.
var segmentLimit = segment.proportionalAnnLimit(limit, totalRows);
iterators.add(segment.orderBy(orderer, slice, keyRange, context, segmentLimit));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ public CloseableIterator<PrimaryKeyWithSortKey> orderBy(Orderer orderer, Express
if (orderer.vector == null)
throw new IllegalArgumentException(indexContext.logMessage("Unsupported expression during ANN index query: " + orderer));

int rerankK = indexContext.getIndexWriterConfig().getSourceModel().rerankKFor(limit, graph.getCompression());
int rerankK = orderer.rerankKFor(limit, graph.getCompression());
var queryVector = vts.createFloatVector(orderer.vector);

var result = searchInternal(keyRange, context, queryVector, limit, rerankK, 0);
Expand Down Expand Up @@ -428,9 +428,8 @@ public double cost()
}
}

public double estimateAnnSearchCost(int limit, int candidates)
public double estimateAnnSearchCost(int rerankK, int candidates)
{
int rerankK = indexContext.getIndexWriterConfig().getSourceModel().rerankKFor(limit, graph.getCompression());
var estimate = estimateCost(rerankK, candidates);
return estimate.cost();
}
Expand Down Expand Up @@ -472,7 +471,7 @@ public CloseableIterator<PrimaryKeyWithSortKey> orderResultsBy(SSTableReader rea
if (keys.isEmpty())
return CloseableIterator.emptyIterator();

int rerankK = indexContext.getIndexWriterConfig().getSourceModel().rerankKFor(limit, graph.getCompression());
int rerankK = orderer.rerankKFor(limit, graph.getCompression());
// Convert PKs to segment row ids and map to ordinals, skipping any that don't exist in this segment
var segmentOrdinalPairs = flatmapPrimaryKeysToBitsAndRows(keys);
var numRows = segmentOrdinalPairs.size();
Expand Down Expand Up @@ -611,9 +610,9 @@ public static double logBase2(double number) {
return Math.log(number) / Math.log(2);
}

private int getRawExpectedNodes(int limit, int nPermittedOrdinals)
private int getRawExpectedNodes(int rerankK, int nPermittedOrdinals)
{
return VectorMemtableIndex.expectedNodesVisited(limit, nPermittedOrdinals, graph.size());
return VectorMemtableIndex.expectedNodesVisited(rerankK, nPermittedOrdinals, graph.size());
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,7 @@ public void remove(ByteBuffer term, T key)
/**
* @return an iterator over {@link PrimaryKeyWithSortKey} in the graph's {@link SearchResult} order
*/
public CloseableIterator<SearchResult.NodeScore> search(QueryContext context, VectorFloat<?> queryVector, int limit, float threshold, Bits toAccept)
public CloseableIterator<SearchResult.NodeScore> search(QueryContext context, VectorFloat<?> queryVector, int limit, int rerankK, float threshold, Bits toAccept)
{
VectorValidation.validateIndexable(queryVector, similarityFunction);

Expand All @@ -326,7 +326,6 @@ public CloseableIterator<SearchResult.NodeScore> search(QueryContext context, Ve
try
{
var ssf = SearchScoreProvider.exact(queryVector, similarityFunction, vectorValues);
var rerankK = sourceModel.rerankKFor(limit, VectorCompression.NO_COMPRESSION);
var result = searcher.search(ssf, limit, rerankK, threshold, 0.0f, bits);
Tracing.trace("ANN search for {}/{} visited {} nodes, reranked {} to return {} results from {}",
limit, rerankK, result.getVisitedCount(), result.getRerankedCount(), result.getNodes().length, source);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ public KeyRangeIterator search(QueryContext context, Expression expr, AbstractBo
float threshold = expr.getEuclideanSearchThreshold();

SortingIterator.Builder<PrimaryKey> keyQueue;
try (var pkIterator = searchInternal(context, qv, keyRange, graph.size(), threshold))
try (var pkIterator = searchInternal(context, qv, keyRange, graph.size(), graph.size(), threshold))
{
keyQueue = new SortingIterator.Builder<>();
while (pkIterator.hasNext())
Expand Down Expand Up @@ -223,14 +223,16 @@ public List<CloseableIterator<PrimaryKeyWithSortKey>> orderBy(QueryContext conte
assert orderer.isANN() : "Only ANN is supported for vector search, received " + orderer.operator;

var qv = vts.createFloatVector(orderer.vector);
var rerankK = orderer.rerankKFor(limit, VectorCompression.NO_COMPRESSION);

return List.of(searchInternal(context, qv, keyRange, limit, 0));
return List.of(searchInternal(context, qv, keyRange, limit, rerankK, 0));
}

private CloseableIterator<PrimaryKeyWithSortKey> searchInternal(QueryContext context,
VectorFloat<?> queryVector,
AbstractBounds<PartitionPosition> keyRange,
int limit,
int rerankK,
float threshold)
{
Bits bits;
Expand All @@ -256,11 +258,11 @@ private CloseableIterator<PrimaryKeyWithSortKey> searchInternal(QueryContext con
if (resultKeys.isEmpty())
return CloseableIterator.emptyIterator();

int bruteForceRows = maxBruteForceRows(limit, resultKeys.size(), graph.size());
logger.trace("Search range covers {} rows; max brute force rows is {} for memtable index with {} nodes, LIMIT {}",
resultKeys.size(), bruteForceRows, graph.size(), limit);
Tracing.trace("Search range covers {} rows; max brute force rows is {} for memtable index with {} nodes, LIMIT {}",
resultKeys.size(), bruteForceRows, graph.size(), limit);
int bruteForceRows = maxBruteForceRows(rerankK, resultKeys.size(), graph.size());
logger.trace("Search range covers {} rows; max brute force rows is {} for memtable index with {} nodes, rerankK {}, LIMIT {}",
resultKeys.size(), bruteForceRows, graph.size(), rerankK, limit);
Tracing.trace("Search range covers {} rows; max brute force rows is {} for memtable index with {} nodes, rerankK {}, LIMIT {}",
resultKeys.size(), bruteForceRows, graph.size(), rerankK, limit);
if (resultKeys.size() <= bruteForceRows)
// When we have a threshold, we only need to filter the results, not order them, because it means we're
// evaluating a boolean predicate in the SAI pipeline that wants to collate by PK
Expand All @@ -272,7 +274,7 @@ private CloseableIterator<PrimaryKeyWithSortKey> searchInternal(QueryContext con
bits = new KeyRangeFilteringBits(keyRange);
}

var nodeScoreIterator = graph.search(context, queryVector, limit, threshold, bits);
var nodeScoreIterator = graph.search(context, queryVector, limit, rerankK, threshold, bits);
return new NodeScoreToScoredPrimaryKeyIterator(nodeScoreIterator);
}

Expand Down Expand Up @@ -305,9 +307,10 @@ public CloseableIterator<PrimaryKeyWithSortKey> orderResultsBy(QueryContext cont
relevantOrdinals.add(i);
});

int maxBruteForceRows = maxBruteForceRows(limit, relevantOrdinals.size(), graph.size());
Tracing.logAndTrace(logger, "{} rows relevant to current memtable out of {} materialized by SAI; max brute force rows is {} for memtable index with {} nodes, LIMIT {}",
relevantOrdinals.size(), keys.size(), maxBruteForceRows, graph.size(), limit);
int rerankK = orderer.rerankKFor(limit, VectorCompression.NO_COMPRESSION);
int maxBruteForceRows = maxBruteForceRows(rerankK, relevantOrdinals.size(), graph.size());
Tracing.logAndTrace(logger, "{} rows relevant to current memtable out of {} materialized by SAI; max brute force rows is {} for memtable index with {} nodes, rerankK {}",
relevantOrdinals.size(), keys.size(), maxBruteForceRows, graph.size(), rerankK);

// convert the expression value to query vector
var qv = vts.createFloatVector(orderer.vector);
Expand All @@ -319,7 +322,7 @@ public CloseableIterator<PrimaryKeyWithSortKey> orderResultsBy(QueryContext cont
return orderByBruteForce(qv, keysInGraph);
}
// indexed path
var nodeScoreIterator = graph.search(context, qv, limit, 0, relevantOrdinals::contains);
var nodeScoreIterator = graph.search(context, qv, limit, rerankK, 0, relevantOrdinals::contains);
return new NodeScoreToScoredPrimaryKeyIterator(nodeScoreIterator);
}

Expand Down Expand Up @@ -375,15 +378,15 @@ private PrimaryKeyWithScore scoreKey(VectorSimilarityFunction similarityFunction
return new PrimaryKeyWithScore(indexContext, mt, key, score);
}

private int maxBruteForceRows(int limit, int nPermittedOrdinals, int graphSize)
private int maxBruteForceRows(int rerankK, int nPermittedOrdinals, int graphSize)
{
int expectedNodesVisited = expectedNodesVisited(limit, nPermittedOrdinals, graphSize);
return min(max(limit, expectedNodesVisited), GLOBAL_BRUTE_FORCE_ROWS);
int expectedNodesVisited = expectedNodesVisited(rerankK, nPermittedOrdinals, graphSize);
return min(max(rerankK, expectedNodesVisited), GLOBAL_BRUTE_FORCE_ROWS);
}

public int estimateAnnNodesVisited(int limit, int nPermittedOrdinals)
public int estimateAnnNodesVisited(int rerankK, int nPermittedOrdinals)
{
return expectedNodesVisited(limit, nPermittedOrdinals, graph.size());
return expectedNodesVisited(rerankK, nPermittedOrdinals, graph.size());
}

/**
Expand All @@ -395,9 +398,9 @@ public int estimateAnnNodesVisited(int limit, int nPermittedOrdinals)
* !!! roughly `degree` times larger than the number of nodes whose edge lists we load!
* !!!
*/
public static int expectedNodesVisited(int limit, int nPermittedOrdinals, int graphSize)
public static int expectedNodesVisited(int rerankK, int nPermittedOrdinals, int graphSize)
{
var K = limit;
var K = rerankK;
var B = min(nPermittedOrdinals, graphSize);
var N = graphSize;
// These constants come from running many searches on a variety of datasets and graph sizes.
Expand All @@ -412,7 +415,7 @@ public static int expectedNodesVisited(int limit, int nPermittedOrdinals, int gr
// If we need to make this even more accurate, the relationship to B and to log(N) may be the best
// places to start.
var raw = (int) (100 + 0.025 * pow(log(N), 2) * pow(K, 0.95) * ((double) N / B));
return ensureSaneEstimate(raw, limit, graphSize);
return ensureSaneEstimate(raw, rerankK, graphSize);
}

public static int ensureSaneEstimate(int rawEstimate, int rerankK, int graphSize)
Expand Down
Loading
Loading