From e6208a56c996fef0c515585e91758d4b5532647f Mon Sep 17 00:00:00 2001 From: Alex Klibisz <8015228+alexklibisz@users.noreply.github.com> Date: Sun, 26 Nov 2023 15:36:14 -0800 Subject: [PATCH] Use shorts instead of ints --- .../jmhbenchmarks/KthGreatestBenchmarks.scala | 11 +++++------ .../klibisz/elastiknn/search/QuickSelect.java | 16 ++++++---------- 2 files changed, 11 insertions(+), 16 deletions(-) diff --git a/elastiknn-jmh-benchmarks/src/main/scala/com/klibisz/elastiknn/jmhbenchmarks/KthGreatestBenchmarks.scala b/elastiknn-jmh-benchmarks/src/main/scala/com/klibisz/elastiknn/jmhbenchmarks/KthGreatestBenchmarks.scala index b723d702e..bc3688973 100644 --- a/elastiknn-jmh-benchmarks/src/main/scala/com/klibisz/elastiknn/jmhbenchmarks/KthGreatestBenchmarks.scala +++ b/elastiknn-jmh-benchmarks/src/main/scala/com/klibisz/elastiknn/jmhbenchmarks/KthGreatestBenchmarks.scala @@ -11,9 +11,8 @@ class KthGreatestBenchmarkFixtures { val rng = new Random(0) val k = 1000 val numDocs = 60000 - val intCounts: Array[Int] = (0 until numDocs).map(_ => rng.nextInt(Short.MaxValue)).toArray - val shortCounts: Array[Short] = intCounts.map(_.toShort) - val copy = new Array[Int](intCounts.length) + val shortCounts: Array[Short] = (0 until numDocs).map(_ => rng.nextInt(Short.MaxValue).toShort).toArray + val copy = new Array[Short](shortCounts.length) } class KthGreatestBenchmarks { @@ -24,8 +23,8 @@ class KthGreatestBenchmarks { @Warmup(time = 5, iterations = 5) @Measurement(time = 5, iterations = 5) def sortBaseline(f: KthGreatestBenchmarkFixtures): Unit = { - val sorted = f.intCounts.sorted - val _ = sorted.apply(f.intCounts.length - f.k) + val sorted = f.shortCounts.sorted + val _ = sorted.apply(f.shortCounts.length - f.k) () } @@ -45,7 +44,7 @@ class KthGreatestBenchmarks { @Warmup(time = 5, iterations = 5) @Measurement(time = 5, iterations = 5) def unnikedRecursive(f: KthGreatestBenchmarkFixtures): Unit = { - System.arraycopy(f.intCounts, 0, f.copy, 0, f.copy.length) + System.arraycopy(f.shortCounts, 0, f.copy, 0, f.copy.length) QuickSelect.selectRecursive(f.copy, f.k) () } diff --git a/elastiknn-lucene/src/main/java/com/klibisz/elastiknn/search/QuickSelect.java b/elastiknn-lucene/src/main/java/com/klibisz/elastiknn/search/QuickSelect.java index 9afba9e4c..c862319e5 100644 --- a/elastiknn-lucene/src/main/java/com/klibisz/elastiknn/search/QuickSelect.java +++ b/elastiknn-lucene/src/main/java/com/klibisz/elastiknn/search/QuickSelect.java @@ -2,17 +2,17 @@ public class QuickSelect { - public static int selectRecursive(int[] array, int n) { + public static short selectRecursive(short[] array, int n) { return recursive(array, 0, array.length - 1, n); } - private static int recursive(int[] array, int left, int right, int k) { + private static short recursive(short[] array, int left, int right, int k) { if (left == right) { // If the list contains only one element, return array[left]; // return that element } // select a pivotIndex between left and right - int pivotIndex = middlePivot(left, right); + int pivotIndex = left + (right - left) / 2; pivotIndex = partition(array, left, right, pivotIndex); // The pivot is in its final sorted position if (k == pivotIndex) { @@ -24,7 +24,7 @@ private static int recursive(int[] array, int left, int right, int k) { } } - private static int partition(int[] array, int left, int right, int pivotIndex) { + private static int partition(short[] array, int left, int right, int pivotIndex) { int pivotValue = array[pivotIndex]; swap(array, pivotIndex, right); // move pivot to end int storeIndex = left; @@ -38,13 +38,9 @@ private static int partition(int[] array, int left, int right, int pivotIndex) { return storeIndex; } - private static void swap(int[] array, int a, int b) { - int tmp = array[a]; + private static void swap(short[] array, int a, int b) { + short tmp = array[a]; array[a] = array[b]; array[b] = tmp; } - - private static int middlePivot(int left, int right) { - return left + (right - left) / 2; - } }