From 63c778cce05e77fd1d45226283390db8e01c3734 Mon Sep 17 00:00:00 2001 From: Ethan Emoto Date: Fri, 17 Jan 2025 16:53:46 -0800 Subject: [PATCH] Add IT for testing rescore enabled and disabled Signed-off-by: Ethan Emoto Signed-off-by: Ethan Emoto --- .../nativelib/NativeEngineKnnVectorQuery.java | 2 +- .../index/query/rescore/RescoreContext.java | 5 +- .../knn/integ/ModeAndCompressionIT.java | 119 ++++++++++++++++++ 3 files changed, 122 insertions(+), 4 deletions(-) diff --git a/src/main/java/org/opensearch/knn/index/query/nativelib/NativeEngineKnnVectorQuery.java b/src/main/java/org/opensearch/knn/index/query/nativelib/NativeEngineKnnVectorQuery.java index 9b9354abd..1ffaa804d 100644 --- a/src/main/java/org/opensearch/knn/index/query/nativelib/NativeEngineKnnVectorQuery.java +++ b/src/main/java/org/opensearch/knn/index/query/nativelib/NativeEngineKnnVectorQuery.java @@ -60,7 +60,7 @@ public Weight createWeight(IndexSearcher indexSearcher, ScoreMode scoreMode, flo List perLeafResults; RescoreContext rescoreContext = knnQuery.getRescoreContext(); final int finalK = knnQuery.getK(); - if (rescoreContext == null || rescoreContext.isRescoreDisabled()) { + if (rescoreContext == null || !rescoreContext.isRescoreEnabled()) { perLeafResults = doSearch(indexSearcher, leafReaderContexts, knnWeight, finalK); } else { boolean isShardLevelRescoringDisabled = KNNSettings.isShardLevelRescoringDisabledForDiskBasedVector(knnQuery.getIndexName()); diff --git a/src/main/java/org/opensearch/knn/index/query/rescore/RescoreContext.java b/src/main/java/org/opensearch/knn/index/query/rescore/RescoreContext.java index 2884350e0..4e89b1b04 100644 --- a/src/main/java/org/opensearch/knn/index/query/rescore/RescoreContext.java +++ b/src/main/java/org/opensearch/knn/index/query/rescore/RescoreContext.java @@ -51,12 +51,11 @@ public final class RescoreContext { * Flag to track whether rescoring has been disabled by the query parameters. */ @Builder.Default - private boolean rescoreDisabled = false; + private boolean rescoreEnabled = true; - // Rescore context to be used when rescoring should be explicitly disabled public static final RescoreContext EXPLICITLY_DISABLED_RESCORE_CONTEXT = RescoreContext.builder() .oversampleFactor(DEFAULT_OVERSAMPLE_FACTOR) - .rescoreDisabled(true) + .rescoreEnabled(false) .build(); /** diff --git a/src/test/java/org/opensearch/knn/integ/ModeAndCompressionIT.java b/src/test/java/org/opensearch/knn/integ/ModeAndCompressionIT.java index 6d24cc3f1..ef2cee8f2 100644 --- a/src/test/java/org/opensearch/knn/integ/ModeAndCompressionIT.java +++ b/src/test/java/org/opensearch/knn/integ/ModeAndCompressionIT.java @@ -206,6 +206,125 @@ public void testIndexCreation_whenValid_ThenSucceed() { } } + @SneakyThrows + public void testQueryRescoreEnabledAndDisabled() { + XContentBuilder builder; + String mode = Mode.ON_DISK.getName(); + String compressionLevel = CompressionLevel.x32.getName(); + String indexName = INDEX_NAME + compressionLevel; + builder = XContentFactory.jsonBuilder() + .startObject() + .startObject("properties") + .startObject(FIELD_NAME) + .field("type", "knn_vector") + .field("dimension", DIMENSION) + .field(MODE_PARAMETER, mode) + .field(COMPRESSION_LEVEL_PARAMETER, compressionLevel) + .endObject() + .endObject() + .endObject(); + String mapping = builder.toString(); + validateIndex(indexName, mapping); + logger.info("Compression level {}", compressionLevel); + // Do exact search and gather right scores for the documents + Response exactSearchResponse = searchKNNIndex( + indexName, + XContentFactory.jsonBuilder() + .startObject() + .startObject("query") + .startObject("script_score") + .startObject("query") + .field("match_all") + .startObject() + .endObject() + .endObject() + .startObject("script") + .field("source", "knn_score") + .field("lang", "knn") + .startObject("params") + .field("field", FIELD_NAME) + .field("query_value", TEST_VECTOR) + .field("space_type", SpaceType.L2.getValue()) + .endObject() + .endObject() + .endObject() + .endObject() + .endObject(), + K + ); + assertOK(exactSearchResponse); + String exactSearchResponseBody = EntityUtils.toString(exactSearchResponse.getEntity()); + List exactSearchKnnResults = parseSearchResponseScore(exactSearchResponseBody, FIELD_NAME); + assertEquals(NUM_DOCS, exactSearchKnnResults.size()); + // Search without rescore + Response response = searchKNNIndex( + indexName, + XContentFactory.jsonBuilder() + .startObject() + .startObject("query") + .startObject("knn") + .startObject(FIELD_NAME) + .field("vector", TEST_VECTOR) + .field("k", K) + .field(RescoreParser.RESCORE_PARAMETER, false) + .endObject() + .endObject() + .endObject() + .endObject(), + K + ); + assertOK(response); + String responseBody = EntityUtils.toString(response.getEntity()); + List knnResults = parseSearchResponseScore(responseBody, FIELD_NAME); + assertEquals(K, knnResults.size()); + Assert.assertNotEquals(exactSearchKnnResults, knnResults); + // Search with explicit rescore + response = searchKNNIndex( + indexName, + XContentFactory.jsonBuilder() + .startObject() + .startObject("query") + .startObject("knn") + .startObject(FIELD_NAME) + .field("vector", TEST_VECTOR) + .field("k", K) + .startObject(RescoreParser.RESCORE_PARAMETER) + .field(RescoreParser.RESCORE_OVERSAMPLE_PARAMETER, 2.0f) + .endObject() + .endObject() + .endObject() + .endObject() + .endObject(), + K + ); + assertOK(response); + responseBody = EntityUtils.toString(response.getEntity()); + knnResults = parseSearchResponseScore(responseBody, FIELD_NAME); + assertEquals(K, knnResults.size()); + Assert.assertEquals(exactSearchKnnResults, knnResults); + // Search with default rescore + response = searchKNNIndex( + indexName, + XContentFactory.jsonBuilder() + .startObject() + .startObject("query") + .startObject("knn") + .startObject(FIELD_NAME) + .field("vector", TEST_VECTOR) + .field("k", K) + .endObject() + .endObject() + .endObject() + .endObject(), + K + ); + assertOK(response); + responseBody = EntityUtils.toString(response.getEntity()); + knnResults = parseSearchResponseScore(responseBody, FIELD_NAME); + assertEquals(K, knnResults.size()); + Assert.assertEquals(exactSearchKnnResults, knnResults); + } + @SneakyThrows public void testDeletedDocsWithSegmentMerge_whenValid_ThenSucceed() { XContentBuilder builder;