Skip to content

Commit

Permalink
Use a random pivot for Quickselect
Browse files Browse the repository at this point in the history
  • Loading branch information
alexklibisz committed Nov 26, 2023
1 parent e6208a5 commit eb51d04
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ class KthGreatestBenchmarkFixtures {
val numDocs = 60000
val shortCounts: Array[Short] = (0 until numDocs).map(_ => rng.nextInt(Short.MaxValue).toShort).toArray
val copy = new Array[Short](shortCounts.length)
val expected = shortCounts.sorted.reverse.apply(k)
}

class KthGreatestBenchmarks {
Expand All @@ -24,8 +25,8 @@ class KthGreatestBenchmarks {
@Measurement(time = 5, iterations = 5)
def sortBaseline(f: KthGreatestBenchmarkFixtures): Unit = {
val sorted = f.shortCounts.sorted
val _ = sorted.apply(f.shortCounts.length - f.k)
()
val actual = sorted.apply(f.shortCounts.length - f.k)
require(actual == f.expected, (actual, f.expected))
}

@Benchmark
Expand All @@ -34,8 +35,8 @@ class KthGreatestBenchmarks {
@Warmup(time = 5, iterations = 5)
@Measurement(time = 5, iterations = 5)
def kthGreatest(f: KthGreatestBenchmarkFixtures): Unit = {
KthGreatest.kthGreatest(f.shortCounts, f.k)
()
val actual = KthGreatest.kthGreatest(f.shortCounts, f.k)
require(actual.kthGreatest == f.expected, (actual.kthGreatest, f.expected))
}

@Benchmark
Expand All @@ -45,7 +46,7 @@ class KthGreatestBenchmarks {
@Measurement(time = 5, iterations = 5)
def unnikedRecursive(f: KthGreatestBenchmarkFixtures): Unit = {
System.arraycopy(f.shortCounts, 0, f.copy, 0, f.copy.length)
QuickSelect.selectRecursive(f.copy, f.k)
()
val actual = QuickSelect.selectRecursive(f.copy, f.k)
require(actual == f.expected, (actual, f.expected))
}
}
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
package com.klibisz.elastiknn.search;

import java.util.Random;

public class QuickSelect {

private static final Random rng = new Random(0);

public static short selectRecursive(short[] array, int n) {
return recursive(array, 0, array.length - 1, n);
}
Expand All @@ -12,7 +16,7 @@ private static short recursive(short[] array, int left, int right, int k) {
}

// select a pivotIndex between left and right
int pivotIndex = left + (right - left) / 2;
int pivotIndex = left + rng.nextInt(right - left);
pivotIndex = partition(array, left, right, pivotIndex);
// The pivot is in its final sorted position
if (k == pivotIndex) {
Expand Down

0 comments on commit eb51d04

Please sign in to comment.