Skip to content

Commit

Permalink
Performance: simplify and optimize kth-greatest computation (96% reca…
Browse files Browse the repository at this point in the history
…ll at 195 qps) (#616)
  • Loading branch information
alexklibisz authored Dec 1, 2023
1 parent 504589b commit ea383d8
Show file tree
Hide file tree
Showing 11 changed files with 183 additions and 149 deletions.
2 changes: 1 addition & 1 deletion docs/pages/performance/fashion-mnist/plot.b64

Large diffs are not rendered by default.

Binary file modified docs/pages/performance/fashion-mnist/plot.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
16 changes: 8 additions & 8 deletions docs/pages/performance/fashion-mnist/results.md
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
|Model|Parameters|Recall|Queries per Second|
|---|---|---|---|
|eknn-l2lsh|L=100 k=4 w=1024 candidates=500 probes=0|0.378|337.457|
|eknn-l2lsh|L=100 k=4 w=1024 candidates=1000 probes=0|0.446|281.828|
|eknn-l2lsh|L=100 k=4 w=1024 candidates=500 probes=3|0.634|272.814|
|eknn-l2lsh|L=100 k=4 w=1024 candidates=1000 probes=3|0.716|232.698|
|eknn-l2lsh|L=100 k=4 w=2048 candidates=500 probes=0|0.767|303.686|
|eknn-l2lsh|L=100 k=4 w=2048 candidates=1000 probes=0|0.846|254.121|
|eknn-l2lsh|L=100 k=4 w=2048 candidates=500 probes=3|0.922|215.233|
|eknn-l2lsh|L=100 k=4 w=2048 candidates=1000 probes=3|0.960|190.689|
|eknn-l2lsh|L=100 k=4 w=1024 candidates=500 probes=0|0.379|353.162|
|eknn-l2lsh|L=100 k=4 w=1024 candidates=1000 probes=0|0.447|295.007|
|eknn-l2lsh|L=100 k=4 w=1024 candidates=500 probes=3|0.634|286.531|
|eknn-l2lsh|L=100 k=4 w=1024 candidates=1000 probes=3|0.716|245.690|
|eknn-l2lsh|L=100 k=4 w=2048 candidates=500 probes=0|0.767|312.826|
|eknn-l2lsh|L=100 k=4 w=2048 candidates=1000 probes=0|0.846|265.204|
|eknn-l2lsh|L=100 k=4 w=2048 candidates=500 probes=3|0.921|221.817|
|eknn-l2lsh|L=100 k=4 w=2048 candidates=1000 probes=3|0.960|195.653|
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
package com.klibisz.elastiknn.search;

import org.apache.lucene.search.KthGreatest;

/**
* Use an array of counts to count hits. The index of the array is the doc id.
* Hopefully there's a way to do this that doesn't require O(num docs in segment) time and memory,
Expand All @@ -14,29 +12,36 @@ public class ArrayHitCounter implements HitCounter {
private int minKey;
private int maxKey;

private short maxValue;

/**
 * Creates a counter with one slot per key in the range [0, capacity).
 *
 * @param capacity number of slots allocated in the backing counts array.
 */
public ArrayHitCounter(int capacity) {
counts = new short[capacity];
numHits = 0;
// minKey starts at its upper bound and maxKey at its lower bound so that the
// Math.min / Math.max updates in increment() snap them to the first key that is hit.
minKey = capacity;
maxKey = 0;
// Largest count written by increment() so far; kthGreatest() uses it to size its histogram.
maxValue = 0;
}

@Override
public void increment(int key) {
if (counts[key]++ == 0) {
short after = ++counts[key];
if (after == 1) {
numHits++;
minKey = Math.min(key, minKey);
maxKey = Math.max(key, maxKey);
}
if (after > maxValue) maxValue = after;
}

@Override
public void increment(int key, short count) {
if ((counts[key] += count) == count) {
short after = (counts[key] += count);
if (after == count) {
numHits++;
minKey = Math.min(key, minKey);
maxKey = Math.max(key, maxKey);
}
if (after > maxValue) maxValue = after;
}

@Override
Expand Down Expand Up @@ -70,8 +75,34 @@ public int maxKey() {
}

@Override
public KthGreatest.Result kthGreatest(int k) {
return KthGreatest.kthGreatest(counts, Math.min(k, counts.length - 1));
}
/**
 * Finds the kth greatest hit count (k is zero-based) among all slots of {@code counts}.
 *
 * Runs one pass over {@code counts} plus one pass over a histogram of {@code maxValue + 1} buckets.
 * This exploits the fact that hit counts are small non-negative integers, so the histogram
 * is typically tiny compared to the counts array.
 *
 * @param k zero-based rank of the count to find; if fewer than k + 1 slots have a non-zero
 *          count, the result's kthGreatest is 0.
 * @return the kth greatest count, the number of slots with a strictly greater count
 *         (or the number of hits when the kth greatest count is 0), and the number of hits.
 */
public KthGreatestResult kthGreatest(int k) {
    // Start by building a histogram of all counts.
    // e.g., if the counts are [0, 4, 1, 1, 2],
    // then the histogram is [1, 2, 1, 0, 1],
    // because 0 occurs once, 1 occurs twice, 2 occurs once, 3 occurs zero times, and 4 occurs once.
    // The buckets are ints, not shorts: a bucket counts slots, and more than Short.MAX_VALUE
    // slots can share the same count, which would silently wrap a short bucket negative.
    int[] hist = new int[maxValue + 1];
    for (short c : counts) hist[c]++;

    // Now we start at the max value and iterate backwards through the histogram,
    // accumulating counts of counts until we've exceeded k.
    int numGreaterEqual = 0;
    short kthGreatest = maxValue;
    while (kthGreatest > 0) {
        numGreaterEqual += hist[kthGreatest];
        if (numGreaterEqual > k) break;
        else kthGreatest--;
    }

    // Finally we find the number that were greater than the kth greatest count.
    // There's a special case if kthGreatest is zero, then the number that were greater is the number of hits.
    int numGreater = numGreaterEqual - hist[kthGreatest];
    if (kthGreatest == 0) numGreater = numHits;
    return new KthGreatestResult(kthGreatest, numGreater, numHits);
}
}
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
package com.klibisz.elastiknn.search;

import org.apache.lucene.search.KthGreatest;

public final class EmptyHitCounter implements HitCounter {

@Override
Expand Down Expand Up @@ -41,7 +39,7 @@ public int maxKey() {
}

@Override
public KthGreatest.Result kthGreatest(int k) {
return new KthGreatest.Result((short) 0, 0, 0);
public KthGreatestResult kthGreatest(int k) {
return new KthGreatestResult((short) 0, 0, 0);
}
}
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
package com.klibisz.elastiknn.search;

import org.apache.lucene.search.KthGreatest;

/**
* Abstraction for counting hits for a particular query.
*/
Expand All @@ -23,6 +21,6 @@ public interface HitCounter {

int maxKey();

KthGreatest.Result kthGreatest(int k);
KthGreatestResult kthGreatest(int k);

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package com.klibisz.elastiknn.search;

public class KthGreatestResult {
public final short kthGreatest;
public final int numGreaterThan;
public final int numNonZero;
public KthGreatestResult(short kthGreatest, int numGreaterThan, int numNonZero) {
this.kthGreatest = kthGreatest;
this.numGreaterThan = numGreaterThan;
this.numNonZero = numNonZero;
}

@Override
public boolean equals(Object o) {
if (o == this) {
return true;
} else if (!(o instanceof KthGreatestResult other)) {
return false;
} else {
return kthGreatest == other.kthGreatest && numGreaterThan == other.numGreaterThan && numNonZero == other.numNonZero;
}
}

@Override
public String toString() {
return String.format("KthGreatestResult(kthGreatest=%d, numGreaterThan=%d, numNonZero=%d)", kthGreatest, numGreaterThan, numNonZero);
}
}

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import com.klibisz.elastiknn.search.ArrayHitCounter;
import com.klibisz.elastiknn.search.EmptyHitCounter;
import com.klibisz.elastiknn.search.HitCounter;
import com.klibisz.elastiknn.search.KthGreatestResult;
import org.apache.lucene.index.*;
import org.apache.lucene.util.BytesRef;

Expand Down Expand Up @@ -101,7 +102,7 @@ private DocIdSetIterator buildDocIdSetIterator(HitCounter counter) {
if (counter.isEmpty()) return DocIdSetIterator.empty();
else {

KthGreatest.Result kgr = counter.kthGreatest(candidates);
KthGreatestResult kgr = counter.kthGreatest(candidates);

// Return an iterator over the doc ids >= the min candidate count.
return new DocIdSetIterator() {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
package com.klibisz.elastiknn.search

import org.scalatest.freespec.AnyFreeSpec
import org.scalatest.matchers.should.Matchers

import scala.util.Random

final class ArrayHitCounterSpec extends AnyFreeSpec with Matchers {

  /** Deliberately naive, map-backed [[HitCounter]] used as the source of truth in these tests.
    * Every key in [0, capacity) is pre-populated with a zero count.
    */
  final class Reference(capacity: Int) extends HitCounter {
    private val counts = scala.collection.mutable.Map[Int, Short](
      (0 until capacity).map(key => (key, 0.toShort)): _*
    )

    override def increment(key: Int): Unit =
      counts(key) = (counts(key) + 1).toShort

    override def increment(key: Int, count: Short): Unit =
      counts(key) = (counts(key) + count).toShort

    override def isEmpty: Boolean = counts.values.forall(_ <= 0)

    override def get(key: Int): Short = counts(key)

    override def numHits(): Int = counts.valuesIterator.count(_ > 0)

    override def capacity(): Int = capacity

    override def minKey(): Int =
      counts.iterator.collect { case (key, hits) if hits > 0 => key }.min

    override def maxKey(): Int =
      counts.iterator.collect { case (key, hits) if hits > 0 => key }.max

    // Sort every count in descending order and read the answer off directly.
    override def kthGreatest(k: Int): KthGreatestResult = {
      val descending = counts.values.toArray.sortWith(_ > _)
      val kth = descending(k)
      new KthGreatestResult(kth, descending.count(_ > kth), descending.count(_ != 0))
    }
  }

  "reference examples" - {
    "example 1" in {
      // Sanity-check the reference implementation itself on a small hand-computed case.
      val reference = new Reference(10)
      reference.isEmpty shouldBe true
      reference.capacity() shouldBe 10

      // Single increment of key 0.
      reference.get(0) shouldBe 0
      reference.increment(0)
      reference.get(0) shouldBe 1
      reference.numHits() shouldBe 1
      reference.minKey() shouldBe 0
      reference.maxKey() shouldBe 0

      // Bulk increment of key 5.
      reference.get(5) shouldBe 0
      reference.increment(5, 5)
      reference.get(5) shouldBe 5
      reference.numHits() shouldBe 2
      reference.minKey() shouldBe 0
      reference.maxKey() shouldBe 5

      // Two single increments of key 9.
      reference.get(9) shouldBe 0
      reference.increment(9)
      reference.get(9) shouldBe 1
      reference.increment(9)
      reference.get(9) shouldBe 2
      reference.numHits() shouldBe 3
      reference.minKey() shouldBe 0
      reference.maxKey() shouldBe 9

      // Counts sorted descending are 5, 2, 1, 0, ...; index 2 holds 1, with two counts above it.
      val result = reference.kthGreatest(2)
      result.kthGreatest shouldBe 1
      result.numGreaterThan shouldBe 2
      result.numNonZero shouldBe 3
    }
  }

  "randomized comparison to reference" in {
    val seed = System.currentTimeMillis()
    val rng = new Random(seed)
    val numDocs = 60000
    val numMatches = numDocs / 2
    // Report the seed so a failing run can be reproduced.
    info(s"Using seed $seed")
    (0 until 99).foreach { _ =>
      val docs = Vector.fill(numMatches)(rng.nextInt(numDocs))
      val expected = new Reference(numDocs)
      val actual = new ArrayHitCounter(numDocs)
      docs.foreach { doc =>
        // Exercise the single-increment overload ...
        expected.increment(doc)
        actual.increment(doc)
        actual.get(doc) shouldBe expected.get(doc)
        // ... then the bulk overload with a random count in [0, 10).
        val count = rng.nextInt(10).toShort
        expected.increment(doc, count)
        actual.increment(doc, count)
        actual.get(doc) shouldBe expected.get(doc)
      }
      actual.minKey() shouldBe expected.minKey()
      actual.maxKey() shouldBe expected.maxKey()
      actual.numHits() shouldBe expected.numHits()
      val k = rng.nextInt(numDocs)
      val actualKgr = actual.kthGreatest(k)
      val expectedKgr = expected.kthGreatest(k)
      actualKgr shouldBe expectedKgr
    }
  }
}
Loading

0 comments on commit ea383d8

Please sign in to comment.