Hashing query performance improvements (1.5-2x faster on benchmarks) (#114)

Rewrote the custom hashing query (`MatchHashesAndScoreQuery`) in Java. This doesn't necessarily make it faster; rather, it makes it less likely that you accidentally introduce an expensive Scala abstraction, and it's easier to get help from Lucene users.

Made match counting faster by using an array instead of a map. This works because each counter only deals with the consecutive doc ids in a single segment. So instead of a map from doc id to count, you have an array where the index is the doc id and the value is the count.

Made candidate identification faster using a similar construct. Since the highest possible count is the number of query terms, you can use an array to build a histogram of the counts, then traverse from the end of the array to find the kth largest count. A sketch of both tricks follows below.

Specific timing improvements (p90 benchmark times):

- Angular LSH: 121ms -> 50ms
- L2 LSH: 18ms -> 11ms
- Jaccard LSH: 58ms -> 36ms

Still need to understand how `PrefixCodedTerms` works and whether there's any possible optimization there.
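To make the two array tricks concrete, here is a minimal sketch under assumed inputs (the method signatures are illustrative; the committed implementations live in `MatchHashesAndScoreQuery` and `ArrayUtils.kthGreatest`):

```java
// 1. Doc ids within a segment are consecutive ints starting at 0, so the
//    map from doc id to match count can be replaced by a plain array.
static short[] countMatches(int[][] docIdsPerMatchedTerm, int numDocsInSegment) {
    short[] counts = new short[numDocsInSegment]; // index = doc id, value = count
    for (int[] docIds : docIdsPerMatchedTerm)
        for (int docId : docIds)
            counts[docId] += 1;
    return counts;
}

// 2. No count can exceed the number of query terms, so a histogram over
//    [0, numTerms] finds the kth greatest count in O(numDocs + numTerms),
//    with no sorting and no heap.
static int kthGreatest(short[] counts, int k, int numTerms) {
    int[] histogram = new int[numTerms + 1];
    for (short c : counts) histogram[c] += 1;
    int seen = 0;
    for (int c = numTerms; c >= 0; c--) { // traverse from the highest count down
        seen += histogram[c];
        if (seen >= k) return c; // at least k docs have a count >= c
    }
    return 0;
}
```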
1 parent 30c9e87 · commit c75b23f
Showing 18 changed files with 501 additions and 271 deletions.
43 changes: 43 additions & 0 deletions
core/src/main/java/com/klibisz/elastiknn/storage/BitBuffer.java
@@ -0,0 +1,43 @@
package com.klibisz.elastiknn.storage;

public interface BitBuffer {
    void putOne();
    void putZero();
    byte[] toByteArray();

    // Accumulates up to 32 bits into a single int, least-significant bit first,
    // then serializes the int via UnsafeSerialization, preceded by an optional prefix.
    class IntBuffer implements BitBuffer {

        private final byte[] prefix;
        private int i = 0; // index of the next bit
        private int b = 0; // accumulated bits

        public IntBuffer(byte[] prefix) {
            this.prefix = prefix;
        }

        public IntBuffer() {
            this.prefix = new byte[0];
        }

        @Override
        public void putOne() {
            this.b += (1 << this.i);
            this.i += 1;
        }

        @Override
        public void putZero() {
            this.i += 1;
        }

        @Override
        public byte[] toByteArray() {
            byte[] barr = UnsafeSerialization.writeInt(b);
            byte[] res = new byte[prefix.length + barr.length];
            System.arraycopy(prefix, 0, res, 0, prefix.length);
            System.arraycopy(barr, 0, res, prefix.length, barr.length);
            return res;
        }
    }

}
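For context, a minimal usage sketch of the new interface (hypothetical, not part of the commit): the bit index advances on every put, and `putOne` sets the bit at the current index, so bits are packed least-significant first.

```java
// Hypothetical snippet: pack the bit sequence 1, 0, 1 (LSB first).
BitBuffer buf = new BitBuffer.IntBuffer();
buf.putOne();  // b = 0b001 = 1, next index = 1
buf.putZero(); // b unchanged,   next index = 2
buf.putOne();  // b = 0b101 = 5, next index = 3
byte[] bytes = buf.toByteArray(); // 5, serialized via UnsafeSerialization.writeInt
```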
29 changes: 0 additions & 29 deletions
core/src/main/scala/com/klibisz/elastiknn/storage/BitBuffer.scala
This file was deleted.
180 changes: 180 additions & 0 deletions
plugin/src/main/java/org/apache/lucene/search/MatchHashesAndScoreQuery.java
@@ -0,0 +1,180 @@
package org.apache.lucene.search;

import com.klibisz.elastiknn.utils.ArrayUtils;
import org.apache.lucene.index.*;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.BytesRef;

import java.io.IOException;
import java.util.Objects;
import java.util.Set;
import java.util.function.Function;

public class MatchHashesAndScoreQuery extends Query {

    public interface ScoreFunction {
        double score(int docId, int numMatchingHashes);
    }

    private final String field;
    private final BytesRef[] hashes;
    private final int candidates;
    private final IndexReader indexReader;
    private final Function<LeafReaderContext, ScoreFunction> scoreFunctionBuilder;
    private final PrefixCodedTerms prefixCodedTerms;
    private final int numDocsInSegment;

    private static PrefixCodedTerms makePrefixCodedTerms(String field, BytesRef[] hashes) {
        // PrefixCodedTerms.Builder expects the hashes in sorted order.
        ArrayUtil.timSort(hashes);
        PrefixCodedTerms.Builder builder = new PrefixCodedTerms.Builder();
        for (BytesRef br : hashes) builder.add(field, br);
        return builder.finish();
    }

    public MatchHashesAndScoreQuery(final String field,
                                    final BytesRef[] hashes,
                                    final int candidates,
                                    final IndexReader indexReader,
                                    final Function<LeafReaderContext, ScoreFunction> scoreFunctionBuilder) {
        this.field = field;
        this.hashes = hashes;
        this.candidates = candidates;
        this.indexReader = indexReader;
        this.scoreFunctionBuilder = scoreFunctionBuilder;
        this.prefixCodedTerms = makePrefixCodedTerms(field, hashes);
        this.numDocsInSegment = indexReader.numDocs();
    }

    @Override
    public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) {

        return new Weight(this) {

            // Count how many of the query hashes each doc in the segment matches.
            // Doc ids within a segment are consecutive ints starting at 0, so an
            // array indexed by doc id replaces a map from doc id to count.
            private short[] countMatches(LeafReaderContext context) throws IOException {
                LeafReader reader = context.reader();
                Terms terms = reader.terms(field);
                TermsEnum termsEnum = terms.iterator();
                PrefixCodedTerms.TermIterator iterator = prefixCodedTerms.iterator();
                short[] counts = new short[numDocsInSegment];
                PostingsEnum docs = null;
                BytesRef term = iterator.next();
                while (term != null) {
                    if (termsEnum.seekExact(term)) {
                        docs = termsEnum.postings(docs, PostingsEnum.NONE);
                        for (int i = 0; i < docs.cost(); i++) {
                            int docId = docs.nextDoc();
                            counts[docId] += 1;
                        }
                    }
                    term = iterator.next();
                }
                return counts;
            }

            private DocIdSetIterator buildDocIdSetIterator(short[] counts) {
                if (candidates >= numDocsInSegment) return DocIdSetIterator.all(indexReader.maxDoc());
                else {
                    int minCandidateCount = ArrayUtils.kthGreatest(counts, candidates);
                    // DocIdSetIterator that iterates over the doc ids but only emits
                    // the ids whose count is >= the min candidate count.
                    return new DocIdSetIterator() {

                        private int doc = 0;

                        @Override
                        public int docID() {
                            return doc;
                        }

                        @Override
                        public int nextDoc() {
                            // Increment doc until its count meets or exceeds the min candidate count.
                            do doc++;
                            while (doc < counts.length && counts[doc] < minCandidateCount);
                            if (doc == counts.length) return DocIdSetIterator.NO_MORE_DOCS;
                            else return docID();
                        }

                        @Override
                        public int advance(int target) {
                            while (doc < target) nextDoc();
                            return docID();
                        }

                        @Override
                        public long cost() {
                            return counts.length;
                        }
                    };
                }
            }

            @Override
            public void extractTerms(Set<Term> terms) { }

            @Override
            public Explanation explain(LeafReaderContext context, int doc) {
                return Explanation.match(0, "If someone knows what this should return, please submit a PR. :)");
            }

            @Override
            public Scorer scorer(LeafReaderContext context) throws IOException {
                ScoreFunction scoreFunction = scoreFunctionBuilder.apply(context);
                short[] counts = countMatches(context);
                DocIdSetIterator disi = buildDocIdSetIterator(counts);

                return new Scorer(this) {
                    @Override
                    public DocIdSetIterator iterator() {
                        return disi;
                    }

                    @Override
                    public float getMaxScore(int upTo) {
                        return Float.MAX_VALUE;
                    }

                    @Override
                    public float score() {
                        return (float) scoreFunction.score(docID(), counts[docID()]);
                    }

                    @Override
                    public int docID() {
                        return disi.docID();
                    }
                };
            }

            @Override
            public boolean isCacheable(LeafReaderContext ctx) {
                return false;
            }
        };
    }

    @Override
    public String toString(String field) {
        return String.format(
                "%s for field [%s] with [%d] hashes and [%d] candidates",
                this.getClass().getSimpleName(),
                this.field,
                this.hashes.length,
                this.candidates);
    }

    @Override
    public boolean equals(Object obj) {
        if (obj instanceof MatchHashesAndScoreQuery) {
            MatchHashesAndScoreQuery q = (MatchHashesAndScoreQuery) obj;
            return q.hashCode() == this.hashCode();
        } else {
            return false;
        }
    }

    @Override
    public int hashCode() {
        return Objects.hash(field, hashes, candidates, indexReader, scoreFunctionBuilder);
    }
}
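To show how the pieces fit together, here is a hypothetical usage sketch; the field name, hashes, and score function are placeholders, and only the constructor signature comes from the diff above:

```java
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.MatchHashesAndScoreQuery;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.store.Directory;
import org.apache.lucene.util.BytesRef;

static TopDocs searchExample(Directory directory) throws Exception {
    IndexReader reader = DirectoryReader.open(directory);
    IndexSearcher searcher = new IndexSearcher(reader);
    // Stand-ins for hashes produced by an LSH model for the query vector.
    BytesRef[] queryHashes = { new BytesRef("h1"), new BytesRef("h2") };
    MatchHashesAndScoreQuery query = new MatchHashesAndScoreQuery(
            "vec_hashes",  // assumed field name
            queryHashes,
            200,           // docs with the top 200 match counts become candidates
            reader,
            ctx -> (docId, numMatchingHashes) -> numMatchingHashes // trivial ScoreFunction
    );
    return searcher.search(query, 10); // top 10 candidates by score
}
```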