Performance: build DocIdSetIterator in ArrayHitCounter to enable future optimizations #718

Merged: 15 commits, Aug 28, 2024
9 changes: 8 additions & 1 deletion build.sbt
@@ -1,4 +1,5 @@
import ElasticsearchPluginPlugin.autoImport.*
import org.typelevel.sbt.tpolecat.{CiMode, DevMode}
import org.typelevel.scalacoptions.*

Global / scalaVersion := "3.3.3"
@@ -9,7 +10,13 @@ lazy val CirceVersion = "0.14.9"
lazy val ElasticsearchVersion = "8.15.0"
lazy val Elastic4sVersion = "8.14.1"
lazy val ElastiknnVersion = IO.read(file("version")).strip()
lazy val LuceneVersion = "9.10.0"
lazy val LuceneVersion = "9.11.1"

// Setting this to simplify local development.
// https://github.com/typelevel/sbt-tpolecat/tree/v0.5.1?tab=readme-ov-file#modes
ThisBuild / tpolecatOptionsMode := {
if (sys.env.get("CI").contains("true")) CiMode else DevMode
}

lazy val TestSettings = Seq(
Test / parallelExecution := false,
2 changes: 1 addition & 1 deletion docs/pages/performance/fashion-mnist/plot.b64

Large diffs are not rendered by default.

Binary file modified docs/pages/performance/fashion-mnist/plot.png
16 changes: 8 additions & 8 deletions docs/pages/performance/fashion-mnist/results.md
@@ -1,10 +1,10 @@
|Model|Parameters|Recall|Queries per Second|
|---|---|---|---|
|eknn-l2lsh|L=100 k=4 w=1024 candidates=500 probes=0|0.379|378.846|
|eknn-l2lsh|L=100 k=4 w=1024 candidates=1000 probes=0|0.447|310.273|
|eknn-l2lsh|L=100 k=4 w=1024 candidates=500 probes=3|0.635|290.668|
|eknn-l2lsh|L=100 k=4 w=1024 candidates=1000 probes=3|0.717|248.644|
|eknn-l2lsh|L=100 k=4 w=2048 candidates=500 probes=0|0.767|332.671|
|eknn-l2lsh|L=100 k=4 w=2048 candidates=1000 probes=0|0.847|278.984|
|eknn-l2lsh|L=100 k=4 w=2048 candidates=500 probes=3|0.922|219.114|
|eknn-l2lsh|L=100 k=4 w=2048 candidates=1000 probes=3|0.960|196.862|
|eknn-l2lsh|L=100 k=4 w=1024 candidates=500 probes=0|0.378|375.370|
|eknn-l2lsh|L=100 k=4 w=1024 candidates=1000 probes=0|0.447|320.039|
|eknn-l2lsh|L=100 k=4 w=1024 candidates=500 probes=3|0.635|294.600|
|eknn-l2lsh|L=100 k=4 w=1024 candidates=1000 probes=3|0.716|257.913|
|eknn-l2lsh|L=100 k=4 w=2048 candidates=500 probes=0|0.767|332.779|
|eknn-l2lsh|L=100 k=4 w=2048 candidates=1000 probes=0|0.846|289.472|
|eknn-l2lsh|L=100 k=4 w=2048 candidates=500 probes=3|0.921|220.716|
|eknn-l2lsh|L=100 k=4 w=2048 candidates=1000 probes=3|0.960|204.668|
@@ -1,7 +1,7 @@
package com.klibisz.elastiknn.jmhbenchmarks

import org.openjdk.jmh.annotations._
import org.apache.lucene.util.hppc.IntIntHashMap
import org.apache.lucene.internal.hppc.IntIntHashMap
import org.eclipse.collections.impl.map.mutable.primitive.IntShortHashMap

import scala.util.Random
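Lucene 9.11 relocated its vendored HPPC collections from org.apache.lucene.util.hppc to org.apache.lucene.internal.hppc, which is the only change this benchmark needs. The benchmark body is collapsed above; as a rough sketch of the comparison it draws (Java here for illustration, and the addTo/addToValue calls reflect the published HPPC and Eclipse Collections APIs, so treat the exact calls as assumptions):

```java
import org.apache.lucene.internal.hppc.IntIntHashMap;
import org.eclipse.collections.impl.map.mutable.primitive.IntShortHashMap;

import java.util.Random;

public final class HashMapCountingSketch {
    public static void main(String[] args) {
        Random rng = new Random(0);
        int[] docIds = rng.ints(10_000, 0, 60_000).toArray();

        // Lucene's vendored HPPC map: addTo increments the value stored at a key.
        IntIntHashMap luceneMap = new IntIntHashMap();
        for (int id : docIds) luceneMap.addTo(id, 1);

        // Eclipse Collections primitive map: addToValue is the analogue.
        IntShortHashMap eclipseMap = new IntShortHashMap();
        for (int id : docIds) eclipseMap.addToValue(id, (short) 1);

        // Both maps should agree on any key's count.
        System.out.println(luceneMap.get(docIds[0]) + " == " + eclipseMap.get(docIds[0]));
    }
}
```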
ArrayHitCounter.java
@@ -1,11 +1,8 @@
package com.klibisz.elastiknn.search;

/**
* Use an array of counts to count hits. The index of the array is the doc id.
* Hopefully there's a way to do this that doesn't require O(num docs in segment) time and memory,
* but so far I haven't found anything on the JVM that's faster than simple arrays of primitives.
*/
public class ArrayHitCounter implements HitCounter {
import org.apache.lucene.search.DocIdSetIterator;

public final class ArrayHitCounter implements HitCounter {

private final short[] counts;
private int numHits;
@@ -44,38 +41,18 @@ public void increment(int key, short count) {
if (after > maxValue) maxValue = after;
}

@Override
public boolean isEmpty() {
return numHits == 0;
}

@Override
public short get(int key) {
return counts[key];
}

@Override
public int numHits() {
return numHits;
}

@Override
public int capacity() {
return counts.length;
}

@Override
public int minKey() {
return minKey;
}

@Override
public int maxKey() {
return maxKey;
}

@Override
public KthGreatestResult kthGreatest(int k) {
private KthGreatestResult kthGreatest(int k) {
// Find the kth greatest document hit count in O(n) time and O(n) space.
// Though the space is typically negligibly small in practice.
// This implementation exploits the fact that we're specifically counting document hit counts.
@@ -105,4 +82,70 @@ public KthGreatestResult kthGreatest(int k) {
if (kthGreatest == 0) numGreater = numHits;
return new KthGreatestResult(kthGreatest, numGreater, numHits);
}
}

@Override
public DocIdSetIterator docIdSetIterator(int candidates) {
if (numHits == 0) return DocIdSetIterator.empty();
else {

KthGreatestResult kgr = kthGreatest(candidates);

// Return an iterator over the doc ids >= the min candidate count.
return new DocIdSetIterator() {

// The iterator must start at docID -1; the boolean tracks whether iteration has started.
private int docID = -1;
private boolean started = false;

// Track the number of ids emitted, and how many emitted ids have count equal to kgr.kthGreatest.
private int numEmitted = 0;
private int numEq = 0;

@Override
public int docID() {
return docID;
}

@Override
public int nextDoc() {

if (!started) {
started = true;
docID = minKey - 1;
}

// Ensure that docs with count = kgr.kthGreatest are only emitted when there are fewer
// than `candidates` docs with count > kgr.kthGreatest.
while (true) {
if (numEmitted == candidates || docID + 1 > maxKey) {
docID = DocIdSetIterator.NO_MORE_DOCS;
return docID;
} else {
docID++;
if (counts[docID] > kgr.kthGreatest) {
numEmitted++;
return docID;
} else if (counts[docID] == kgr.kthGreatest && numEq < candidates - kgr.numGreaterThan) {
numEq++;
numEmitted++;
return docID;
}
}
}
}

@Override
public int advance(int target) {
while (docID < target) nextDoc();
return docID();
}

@Override
public long cost() {
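// Upper bound on the number of docs this iterator may visit.
// Note: the previous implementation (removed further below) reported
// counter.numHits() as its cost.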
return maxKey - minKey;
}
};
}
}

}
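The body of kthGreatest is collapsed in the hunk above; per its comments, it finds the kth greatest hit count in O(n) time by exploiting the fact that counts are small non-negative shorts. A minimal sketch of that counting-sort idea, as a hypothetical standalone helper rather than the exact method body:

```java
// Sketch: kth greatest value among non-negative short counts via a histogram.
// O(n + maxValue) time, O(maxValue) space; maxValue is the largest count seen.
static short kthGreatest(short[] counts, short maxValue, int k) {
    int[] histogram = new int[maxValue + 1];
    for (short c : counts) histogram[c]++;
    int seen = 0;
    for (int v = maxValue; v >= 0; v--) {
        seen += histogram[v];
        if (seen >= k) return (short) v; // at least k counts are >= v
    }
    return 0; // k exceeds the total number of counts
}
```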
EmptyHitCounter.java
@@ -1,5 +1,7 @@
package com.klibisz.elastiknn.search;

import org.apache.lucene.search.DocIdSetIterator;

public final class EmptyHitCounter implements HitCounter {

@Override
@@ -8,38 +10,18 @@ public void increment(int key) {}
@Override
public void increment(int key, short count) {}

@Override
public boolean isEmpty() {
return true;
}

@Override
public short get(int key) {
return 0;
}

@Override
public int numHits() {
return 0;
}

@Override
public int capacity() {
return 0;
}

@Override
public int minKey() {
return 0;
}

@Override
public int maxKey() {
return 0;
}

@Override
public KthGreatestResult kthGreatest(int k) {
return new KthGreatestResult((short) 0, 0, 0);
public DocIdSetIterator docIdSetIterator(int k) {
return DocIdSetIterator.empty();
}
}
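EmptyHitCounter is the null-object used when a segment has nothing to count, so callers can ask for a DocIdSetIterator unconditionally. A plausible reconstruction of how the two implementations get selected (the actual branch sits in countHits further below and is partly collapsed, so the terms-null check here is an assumption):

```java
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.Terms;
import java.io.IOException;

// Hypothetical helper: fall back to the empty counter when the segment
// has no postings for the field.
static HitCounter counterFor(LeafReader reader, String field) throws IOException {
    Terms terms = reader.terms(field);
    if (terms == null) return new EmptyHitCounter();
    return new ArrayHitCounter(reader.maxDoc());
}
```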
HitCounter.java
@@ -1,5 +1,7 @@
package com.klibisz.elastiknn.search;

import org.apache.lucene.search.DocIdSetIterator;

/**
* Abstraction for counting hits for a particular query.
*/
@@ -9,18 +11,11 @@ public interface HitCounter {

void increment(int key, short count);

boolean isEmpty();

short get(int key);

int numHits();

int capacity();

int minKey();

int maxKey();

KthGreatestResult kthGreatest(int k);
DocIdSetIterator docIdSetIterator(int k);

}
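After this change the interface is down to three methods, with each implementation owning both counting and candidate selection. Consolidated from the diff, the full shape is:

```java
import org.apache.lucene.search.DocIdSetIterator;

/** Abstraction for counting hits for a particular query. */
public interface HitCounter {
    void increment(int key);
    void increment(int key, short count);
    DocIdSetIterator docIdSetIterator(int k);
}
```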
MatchHashesAndScoreQuery.java
@@ -1,19 +1,16 @@
package org.apache.lucene.search;

import com.klibisz.elastiknn.models.HashAndFreq;
import com.klibisz.elastiknn.search.ArrayHitCounter;
import com.klibisz.elastiknn.search.EmptyHitCounter;
import com.klibisz.elastiknn.search.HitCounter;
import com.klibisz.elastiknn.search.KthGreatestResult;
import com.klibisz.elastiknn.search.*;
import org.apache.lucene.index.*;
import org.apache.lucene.util.BytesRef;

import java.io.IOException;
import java.util.Arrays;
import java.util.Objects;
import java.util.Set;
import java.util.function.Function;

import static java.lang.Math.max;
import static java.lang.Math.min;

/**
@@ -64,9 +61,8 @@ private HitCounter countHits(LeafReader reader) throws IOException {
} else {
TermsEnum termsEnum = terms.iterator();
PostingsEnum docs = null;

HitCounter counter = new ArrayHitCounter(reader.maxDoc());
// TODO: Is this the right place to use the live docs bitset to check for deleted docs?
// Bits liveDocs = reader.getLiveDocs();
for (HashAndFreq hf : hashAndFrequencies) {
// We take two different paths here, depending on the frequency of the current hash.
// If the frequency is one, we avoid checking the frequency of matching docs when
@@ -92,76 +88,6 @@ private HitCounter countHits(LeafReader reader) throws IOException {
}
}

private DocIdSetIterator buildDocIdSetIterator(HitCounter counter) {
// TODO: Add back this logging once log4j mess has settled.
// if (counter.numHits() < candidates) {
// logger.warn(String.format(
// "Found fewer approximate matches [%d] than the requested number of candidates [%d]",
// counter.numHits(), candidates));
// }
if (counter.isEmpty()) return DocIdSetIterator.empty();
else {

KthGreatestResult kgr = counter.kthGreatest(candidates);

// Return an iterator over the doc ids >= the min candidate count.
return new DocIdSetIterator() {

// Important that this starts at -1. Need a boolean to denote that it has started iterating.
private int docID = -1;
private boolean started = false;

// Track the number of ids emitted, and the number of ids with count = kgr.kthGreatest emitted.
private int numEmitted = 0;
private int numEq = 0;

@Override
public int docID() {
return docID;
}

@Override
public int nextDoc() {

if (!started) {
started = true;
docID = counter.minKey() - 1;
}

// Ensure that docs with count = kgr.kthGreatest are only emitted when there are fewer
// than `candidates` docs with count > kgr.kthGreatest.
while (true) {
if (numEmitted == candidates || docID + 1 > counter.maxKey()) {
docID = DocIdSetIterator.NO_MORE_DOCS;
return docID();
} else {
docID++;
if (counter.get(docID) > kgr.kthGreatest) {
numEmitted++;
return docID();
} else if (counter.get(docID) == kgr.kthGreatest && numEq < candidates - kgr.numGreaterThan) {
numEq++;
numEmitted++;
return docID();
}
}
}
}

@Override
public int advance(int target) {
while (docID < target) nextDoc();
return docID();
}

@Override
public long cost() {
return counter.numHits();
}
};
}
}

@Override
public Explanation explain(LeafReaderContext context, int doc) throws IOException {
HitCounter counter = countHits(context.reader());
@@ -179,7 +105,7 @@ public Scorer scorer(LeafReaderContext context) throws IOException {
ScoreFunction scoreFunction = scoreFunctionBuilder.apply(context);
LeafReader reader = context.reader();
HitCounter counter = countHits(reader);
DocIdSetIterator disi = buildDocIdSetIterator(counter);
DocIdSetIterator disi = counter.docIdSetIterator(candidates);

return new Scorer(this) {
@Override
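With iterator construction moved into the counter, the scorer simply drains whatever the counter hands back. A minimal usage sketch (the driver method is hypothetical; nextDoc can throw IOException per the DocIdSetIterator contract):

```java
import org.apache.lucene.search.DocIdSetIterator;
import java.io.IOException;

// Hypothetical consumer of the new HitCounter API.
static void drainCandidates(HitCounter counter, int candidates) throws IOException {
    DocIdSetIterator disi = counter.docIdSetIterator(candidates);
    for (int doc = disi.nextDoc(); doc != DocIdSetIterator.NO_MORE_DOCS; doc = disi.nextDoc()) {
        // doc is one of the (up to) `candidates` highest-count doc ids in this segment.
    }
}
```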