Commit c5a8e21

Performance: build DocIdSetIterator in ArrayHitCounter to enable future optimizations (#718)

alexklibisz authored Aug 28, 2024
1 parent bbbaeea commit c5a8e21

Showing 11 changed files with 184 additions and 198 deletions.
9 changes: 8 additions & 1 deletion build.sbt
@@ -1,4 +1,5 @@
import ElasticsearchPluginPlugin.autoImport.*
+import org.typelevel.sbt.tpolecat.{CiMode, DevMode}
import org.typelevel.scalacoptions.*

Global / scalaVersion := "3.3.3"
@@ -9,7 +10,13 @@
lazy val CirceVersion = "0.14.9"
lazy val ElasticsearchVersion = "8.15.0"
lazy val Elastic4sVersion = "8.14.1"
lazy val ElastiknnVersion = IO.read(file("version")).strip()
lazy val LuceneVersion = "9.10.0"
lazy val LuceneVersion = "9.11.1"

+// Setting this to simplify local development.
+// https://github.com/typelevel/sbt-tpolecat/tree/v0.5.1?tab=readme-ov-file#modes
+ThisBuild / tpolecatOptionsMode := {
+  if (sys.env.get("CI").contains("true")) CiMode else DevMode
+}

lazy val TestSettings = Seq(
Test / parallelExecution := false,
2 changes: 1 addition & 1 deletion docs/pages/performance/fashion-mnist/plot.b64


Binary file modified docs/pages/performance/fashion-mnist/plot.png
16 changes: 8 additions & 8 deletions docs/pages/performance/fashion-mnist/results.md
@@ -1,10 +1,10 @@
|Model|Parameters|Recall|Queries per Second|
|---|---|---|---|
-|eknn-l2lsh|L=100 k=4 w=1024 candidates=500 probes=0|0.379|378.846|
-|eknn-l2lsh|L=100 k=4 w=1024 candidates=1000 probes=0|0.447|310.273|
-|eknn-l2lsh|L=100 k=4 w=1024 candidates=500 probes=3|0.635|290.668|
-|eknn-l2lsh|L=100 k=4 w=1024 candidates=1000 probes=3|0.717|248.644|
-|eknn-l2lsh|L=100 k=4 w=2048 candidates=500 probes=0|0.767|332.671|
-|eknn-l2lsh|L=100 k=4 w=2048 candidates=1000 probes=0|0.847|278.984|
-|eknn-l2lsh|L=100 k=4 w=2048 candidates=500 probes=3|0.922|219.114|
-|eknn-l2lsh|L=100 k=4 w=2048 candidates=1000 probes=3|0.960|196.862|
+|eknn-l2lsh|L=100 k=4 w=1024 candidates=500 probes=0|0.378|375.370|
+|eknn-l2lsh|L=100 k=4 w=1024 candidates=1000 probes=0|0.447|320.039|
+|eknn-l2lsh|L=100 k=4 w=1024 candidates=500 probes=3|0.635|294.600|
+|eknn-l2lsh|L=100 k=4 w=1024 candidates=1000 probes=3|0.716|257.913|
+|eknn-l2lsh|L=100 k=4 w=2048 candidates=500 probes=0|0.767|332.779|
+|eknn-l2lsh|L=100 k=4 w=2048 candidates=1000 probes=0|0.846|289.472|
+|eknn-l2lsh|L=100 k=4 w=2048 candidates=500 probes=3|0.921|220.716|
+|eknn-l2lsh|L=100 k=4 w=2048 candidates=1000 probes=3|0.960|204.668|
com/klibisz/elastiknn/jmhbenchmarks/ (file name not rendered)
@@ -1,7 +1,7 @@
package com.klibisz.elastiknn.jmhbenchmarks

import org.openjdk.jmh.annotations._
-import org.apache.lucene.util.hppc.IntIntHashMap
+import org.apache.lucene.internal.hppc.IntIntHashMap
import org.eclipse.collections.impl.map.mutable.primitive.IntShortHashMap

import scala.util.Random
com/klibisz/elastiknn/search/ArrayHitCounter.java
@@ -1,11 +1,8 @@
package com.klibisz.elastiknn.search;

-/**
- * Use an array of counts to count hits. The index of the array is the doc id.
- * Hopefully there's a way to do this that doesn't require O(num docs in segment) time and memory,
- * but so far I haven't found anything on the JVM that's faster than simple arrays of primitives.
- */
-public class ArrayHitCounter implements HitCounter {
+import org.apache.lucene.search.DocIdSetIterator;
+
+public final class ArrayHitCounter implements HitCounter {

private final short[] counts;
private int numHits;
@@ -44,38 +41,18 @@ public void increment(int key, short count) {
if (after > maxValue) maxValue = after;
}

-  @Override
-  public boolean isEmpty() {
-    return numHits == 0;
-  }
-
-  @Override
-  public short get(int key) {
-    return counts[key];
-  }
-
-  @Override
-  public int numHits() {
-    return numHits;
-  }
-
-  @Override
-  public int capacity() {
-    return counts.length;
-  }
-
-  @Override
-  public int minKey() {
-    return minKey;
-  }
-
-  @Override
-  public int maxKey() {
-    return maxKey;
-  }
-
-  @Override
-  public KthGreatestResult kthGreatest(int k) {
+  private KthGreatestResult kthGreatest(int k) {
// Find the kth greatest document hit count in O(n) time and O(n) space.
// Though the space is typically negligibly small in practice.
// This implementation exploits the fact that we're specifically counting document hit counts.
@@ -105,4 +82,70 @@
if (kthGreatest == 0) numGreater = numHits;
return new KthGreatestResult(kthGreatest, numGreater, numHits);
}
-}
+
+  @Override
+  public DocIdSetIterator docIdSetIterator(int candidates) {
+    if (numHits == 0) return DocIdSetIterator.empty();
+    else {
+
+      KthGreatestResult kgr = kthGreatest(candidates);
+
+      // Return an iterator over the doc ids >= the min candidate count.
+      return new DocIdSetIterator() {
+
+        // Important that this starts at -1. Need a boolean to denote that it has started iterating.
+        private int docID = -1;
+        private boolean started = false;
+
+        // Track the number of ids emitted, and the number of ids with count = kgr.kthGreatest emitted.
+        private int numEmitted = 0;
+        private int numEq = 0;
+
+        @Override
+        public int docID() {
+          return docID;
+        }
+
+        @Override
+        public int nextDoc() {
+
+          if (!started) {
+            started = true;
+            docID = minKey - 1;
+          }
+
+          // Ensure that docs with count = kgr.kthGreatest are only emitted when there are fewer
+          // than `candidates` docs with count > kgr.kthGreatest.
+          while (true) {
+            if (numEmitted == candidates || docID + 1 > maxKey) {
+              docID = DocIdSetIterator.NO_MORE_DOCS;
+              return docID;
+            } else {
+              docID++;
+              if (counts[docID] > kgr.kthGreatest) {
+                numEmitted++;
+                return docID;
+              } else if (counts[docID] == kgr.kthGreatest && numEq < candidates - kgr.numGreaterThan) {
+                numEq++;
+                numEmitted++;
+                return docID;
+              }
+            }
+          }
+        }
+
+        @Override
+        public int advance(int target) {
+          while (docID < target) nextDoc();
+          return docID();
+        }
+
+        @Override
+        public long cost() {
+          return maxKey - minKey;
+        }
+      };
+    }
+  }
+
+}
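A quick usage sketch of the new method (illustrative, not from the commit): fill a counter the way the query's countHits loop does, then drain the iterator the way its scorer does. The doc ids and counts here are made up.

import com.klibisz.elastiknn.search.ArrayHitCounter;
import com.klibisz.elastiknn.search.HitCounter;
import org.apache.lucene.search.DocIdSetIterator;

import java.io.IOException;

public final class ArrayHitCounterExample {
  public static void main(String[] args) throws IOException {
    // Capacity is the number of docs in the segment (reader.maxDoc() in the query).
    HitCounter counter = new ArrayHitCounter(8);
    counter.increment(3);             // doc 3 matches three hashes in total
    counter.increment(3, (short) 2);
    counter.increment(5, (short) 2);  // doc 5 matches two
    counter.increment(1);             // doc 1 matches one
    // With candidates = 2 the kth greatest count is 2, so the iterator emits
    // doc 3 (count 3 > 2) and doc 5 (count == 2, within the remaining budget).
    DocIdSetIterator disi = counter.docIdSetIterator(2);
    for (int doc = disi.nextDoc(); doc != DocIdSetIterator.NO_MORE_DOCS; doc = disi.nextDoc()) {
      System.out.println("candidate doc: " + doc); // prints 3, then 5
    }
  }
}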
com/klibisz/elastiknn/search/EmptyHitCounter.java
@@ -1,5 +1,7 @@
package com.klibisz.elastiknn.search;

+import org.apache.lucene.search.DocIdSetIterator;

public final class EmptyHitCounter implements HitCounter {

@Override
@@ -8,38 +10,18 @@ public void increment(int key) {}
@Override
public void increment(int key, short count) {}

-  @Override
-  public boolean isEmpty() {
-    return true;
-  }
-
-  @Override
-  public short get(int key) {
-    return 0;
-  }
-
-  @Override
-  public int numHits() {
-    return 0;
-  }
-
-  @Override
-  public int capacity() {
-    return 0;
-  }
-
-  @Override
-  public int minKey() {
-    return 0;
-  }
-
-  @Override
-  public int maxKey() {
-    return 0;
-  }

@Override
-  public KthGreatestResult kthGreatest(int k) {
-    return new KthGreatestResult((short) 0, 0, 0);
+  public DocIdSetIterator docIdSetIterator(int k) {
+    return DocIdSetIterator.empty();
}
}
com/klibisz/elastiknn/search/HitCounter.java
@@ -1,5 +1,7 @@
package com.klibisz.elastiknn.search;

+import org.apache.lucene.search.DocIdSetIterator;

/**
* Abstraction for counting hits for a particular query.
*/
@@ -9,18 +11,11 @@ public interface HitCounter {

void increment(int key, short count);

-  boolean isEmpty();
-
-  short get(int key);
-
-  int numHits();
-
-  int capacity();
-
-  int minKey();
-
-  int maxKey();
-
-  KthGreatestResult kthGreatest(int k);
+  DocIdSetIterator docIdSetIterator(int k);

}
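The slimmed-down interface is what the commit title means by enabling future optimizations: a counter now owns iterator construction, so an implementation can return something cheaper than a linear scan over a counts array without widening the interface again. A hypothetical sketch of such an implementation (not part of this commit; the class name and the simplification are invented, and it ignores the candidates budget for brevity):

import com.klibisz.elastiknn.search.HitCounter;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.util.BitSetIterator;
import org.apache.lucene.util.FixedBitSet;

// Hypothetical: if every query hash can match a doc at most once, counts are all
// 0 or 1, so "top candidates" degenerates to "docs with at least one hit" and a
// bitset iterator replaces the O(maxDoc) scan in ArrayHitCounter.
public final class BitsetHitCounter implements HitCounter {

  private final FixedBitSet hits;
  private int numHits = 0;

  public BitsetHitCounter(int capacity) {
    this.hits = new FixedBitSet(capacity);
  }

  @Override
  public void increment(int key) {
    if (!hits.getAndSet(key)) numHits++;
  }

  @Override
  public void increment(int key, short count) {
    increment(key); // counts collapse to presence in this sketch
  }

  @Override
  public DocIdSetIterator docIdSetIterator(int candidates) {
    // BitSetIterator walks the set bits directly; its cost is the number of hits.
    return numHits == 0 ? DocIdSetIterator.empty() : new BitSetIterator(hits, numHits);
  }
}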
org/apache/lucene/search/ (custom query class; file name not rendered)
@@ -1,19 +1,16 @@
package org.apache.lucene.search;

import com.klibisz.elastiknn.models.HashAndFreq;
-import com.klibisz.elastiknn.search.ArrayHitCounter;
-import com.klibisz.elastiknn.search.EmptyHitCounter;
-import com.klibisz.elastiknn.search.HitCounter;
-import com.klibisz.elastiknn.search.KthGreatestResult;
+import com.klibisz.elastiknn.search.*;
import org.apache.lucene.index.*;
import org.apache.lucene.util.BytesRef;

import java.io.IOException;
import java.util.Arrays;
import java.util.Objects;
import java.util.Set;
import java.util.function.Function;

import static java.lang.Math.max;
import static java.lang.Math.min;

/**
@@ -64,9 +61,8 @@ private HitCounter countHits(LeafReader reader) throws IOException {
} else {
TermsEnum termsEnum = terms.iterator();
PostingsEnum docs = null;

HitCounter counter = new ArrayHitCounter(reader.maxDoc());
-      // TODO: Is this the right place to use the live docs bitset to check for deleted docs?
-      // Bits liveDocs = reader.getLiveDocs();
for (HashAndFreq hf : hashAndFrequencies) {
// We take two different paths here, depending on the frequency of the current hash.
// If the frequency is one, we avoid checking the frequency of matching docs when
@@ -92,76 +88,6 @@ private HitCounter countHits(LeafReader reader) throws IOException {
}
}

-  private DocIdSetIterator buildDocIdSetIterator(HitCounter counter) {
-    // TODO: Add back this logging once log4j mess has settled.
-    // if (counter.numHits() < candidates) {
-    //   logger.warn(String.format(
-    //     "Found fewer approximate matches [%d] than the requested number of candidates [%d]",
-    //     counter.numHits(), candidates));
-    // }
-    if (counter.isEmpty()) return DocIdSetIterator.empty();
-    else {
-
-      KthGreatestResult kgr = counter.kthGreatest(candidates);
-
-      // Return an iterator over the doc ids >= the min candidate count.
-      return new DocIdSetIterator() {
-
-        // Important that this starts at -1. Need a boolean to denote that it has started iterating.
-        private int docID = -1;
-        private boolean started = false;
-
-        // Track the number of ids emitted, and the number of ids with count = kgr.kthGreatest emitted.
-        private int numEmitted = 0;
-        private int numEq = 0;
-
-        @Override
-        public int docID() {
-          return docID;
-        }
-
-        @Override
-        public int nextDoc() {
-
-          if (!started) {
-            started = true;
-            docID = counter.minKey() - 1;
-          }
-
-          // Ensure that docs with count = kgr.kthGreatest are only emitted when there are fewer
-          // than `candidates` docs with count > kgr.kthGreatest.
-          while (true) {
-            if (numEmitted == candidates || docID + 1 > counter.maxKey()) {
-              docID = DocIdSetIterator.NO_MORE_DOCS;
-              return docID();
-            } else {
-              docID++;
-              if (counter.get(docID) > kgr.kthGreatest) {
-                numEmitted++;
-                return docID();
-              } else if (counter.get(docID) == kgr.kthGreatest && numEq < candidates - kgr.numGreaterThan) {
-                numEq++;
-                numEmitted++;
-                return docID();
-              }
-            }
-          }
-        }
-
-        @Override
-        public int advance(int target) {
-          while (docID < target) nextDoc();
-          return docID();
-        }
-
-        @Override
-        public long cost() {
-          return counter.numHits();
-        }
-      };
-    }
-  }
-
@Override
public Explanation explain(LeafReaderContext context, int doc) throws IOException {
HitCounter counter = countHits(context.reader());
@@ -179,7 +105,7 @@ public Scorer scorer(LeafReaderContext context) throws IOException {
ScoreFunction scoreFunction = scoreFunctionBuilder.apply(context);
LeafReader reader = context.reader();
HitCounter counter = countHits(reader);
-    DocIdSetIterator disi = buildDocIdSetIterator(counter);
+    DocIdSetIterator disi = counter.docIdSetIterator(candidates);

return new Scorer(this) {
@Override
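For orientation, here is a hedged sketch of the two-path counting loop described by the truncated comment in countHits. The structure follows the visible diff lines; the branch bodies and the HashAndFreq accessors (hash, freq) are assumptions rather than the committed code.

import com.klibisz.elastiknn.models.HashAndFreq;
import com.klibisz.elastiknn.search.ArrayHitCounter;
import com.klibisz.elastiknn.search.EmptyHitCounter;
import com.klibisz.elastiknn.search.HitCounter;
import org.apache.lucene.index.*;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.util.BytesRef;

import java.io.IOException;

final class CountHitsSketch {
  // Count, for every doc in a segment, how many of the query's hashes it matches.
  static HitCounter countHits(LeafReader reader, String field, HashAndFreq[] hashAndFrequencies) throws IOException {
    Terms terms = reader.terms(field);
    if (terms == null) return new EmptyHitCounter();
    TermsEnum termsEnum = terms.iterator();
    PostingsEnum docs = null;
    HitCounter counter = new ArrayHitCounter(reader.maxDoc());
    for (HashAndFreq hf : hashAndFrequencies) {
      if (termsEnum.seekExact(new BytesRef(hf.hash))) {
        if (hf.freq == 1) {
          // Frequency one: matching at all is enough, so skip reading term frequencies.
          docs = termsEnum.postings(docs, PostingsEnum.NONE);
          while (docs.nextDoc() != DocIdSetIterator.NO_MORE_DOCS)
            counter.increment(docs.docID());
        } else {
          // Repeated hash: count min(query frequency, doc frequency) matches.
          docs = termsEnum.postings(docs, PostingsEnum.FREQS);
          while (docs.nextDoc() != DocIdSetIterator.NO_MORE_DOCS)
            counter.increment(docs.docID(), (short) Math.min(hf.freq, docs.freq()));
        }
      }
    }
    return counter;
  }
}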