Skip to content

Commit

Permalink
Remove usages of Scala's Array so we can avoid the extra permissions
Browse files Browse the repository at this point in the history
  • Loading branch information
alexklibisz committed Mar 23, 2024
1 parent 2cb1d43 commit fb5e474
Show file tree
Hide file tree
Showing 7 changed files with 35 additions and 13 deletions.
2 changes: 1 addition & 1 deletion ann-benchmarks/ann-benchmarks
Submodule ann-benchmarks updated 166 files
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ public class FloatArrayBuffer {
private int index = 0;

public FloatArrayBuffer() {
// Allocate the backing array; nextInitialCapacity is declared outside this view.
// NOTE(review): presumably a tuned/adaptive starting size — confirm at its declaration.
// System.out.printf("Starting at %d\n", nextInitialCapacity);
this.array = new float[nextInitialCapacity];
}

Expand All @@ -28,6 +29,7 @@ public void append(float f) {
// this.array[index - 1] = f;
// }
if (index == this.array.length) {
// System.out.printf("Growing from %d to %d\n", this.array.length, this.array.length * 2);
this.array = Arrays.copyOf(this.array, this.array.length * 2);
}
this.array[index++] = f;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package com.klibisz.elastiknn;

import org.elasticsearch.index.mapper.FieldMapper;

/**
 * Java-side helper for the vector field mappers.
 *
 * This constant is defined in Java rather than Scala because Scala's Array wrapper
 * uses ClassTag, which requires the extra permission
 * java.lang.RuntimePermission "getClassLoader".
 */
public class VectorMapperUtil {

    /**
     * Shared empty parameter array for FieldMapper#getParameters implementations.
     * Declared final so the shared reference cannot be reassigned by callers.
     */
    public static final FieldMapper.Parameter<?>[] EMPTY_ARRAY_FIELD_MAPPER_PARAMETER = new FieldMapper.Parameter[0];

}
Original file line number Diff line number Diff line change
@@ -1,3 +1,2 @@
grant {
permission java.lang.RuntimePermission "getClassLoader";
};
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ object VectorMapper {
else {
val sorted = vec.sorted() // Sort for faster intersections on the query side.
mapping match {
case Mapping.SparseBool(_) => Try(ExactQuery.index(field, sorted))
case Mapping.SparseBool(_) => Try(Seq(ExactQuery.index(field, sorted)))
case m: Mapping.JaccardLsh =>
Try(HashingQuery.index(field, luceneFieldType, sorted, modelCache(m).hash(vec.trueIndices, vec.totalIndices)))
case m: Mapping.HammingLsh =>
Expand All @@ -51,7 +51,7 @@ object VectorMapper {
Failure(ElastiknnException.vectorDimensions(vec.values.length, mapping.dims))
else
mapping match {
case Mapping.DenseFloat(_) => Try(ExactQuery.index(field, vec))
case Mapping.DenseFloat(_) => Try(Seq(ExactQuery.index(field, vec)))
case m: Mapping.CosineLsh => Try(HashingQuery.index(field, luceneFieldType, vec, modelCache(m).hash(vec.values)))
case m: Mapping.L2Lsh => Try(HashingQuery.index(field, luceneFieldType, vec, modelCache(m).hash(vec.values)))
case m: Mapping.PermutationLsh => Try(HashingQuery.index(field, luceneFieldType, vec, modelCache(m).hash(vec.values)))
Expand Down Expand Up @@ -138,6 +138,9 @@ abstract class VectorMapper[V <: Vec: XContentCodec.Decoder] { self =>
override def getMergeBuilder: FieldMapper.Builder = new Builder(simpleName(), mapping)
}

override def getParameters: Array[FieldMapper.Parameter[_]] =
  // This has to be defined in Java because Scala's Array wrapper uses ClassTag,
  // which requires the extra permission: java.lang.RuntimePermission "getClassLoader".
  VectorMapperUtil.EMPTY_ARRAY_FIELD_MAPPER_PARAMETER
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ final class ExactQuery[V <: Vec, S <: StoredVec](field: String, queryVec: V, sim
}

object ExactQuery {

  /** Build the single Lucene field that stores the exact (binary doc values) copy of the vector.
    *
    * Returns one IndexableField rather than a Seq so callers can combine it with other
    * fields without allocating an intermediate collection.
    *
    * @param field name of the Lucene field
    * @param vec   the vector to encode and store
    */
  def index[V <: Vec: StoredVec.Encoder](field: String, vec: V): IndexableField = {
    val storedVec = implicitly[StoredVec.Encoder[V]].apply(vec)
    new BinaryDocValuesField(field, new BytesRef(storedVec))
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import org.apache.lucene.util.BytesRef
import org.elasticsearch.common.lucene.search.function.{CombineFunction, LeafScoreFunction, ScoreFunction}

import java.util.Objects
import scala.collection.mutable.ListBuffer

final class HashingQuery[V <: Vec, S <: StoredVec: Decoder](
field: String,
Expand Down Expand Up @@ -52,10 +53,15 @@ final class HashingQuery[V <: Vec, S <: StoredVec: Decoder](
private val reader = ctx.reader()
private val terms = reader.terms(field)
private val termsEnum = terms.iterator()
private val postings: Seq[PostingsEnum] = {
  // Collect a postings enum for each hash that exists in this segment's terms dictionary,
  // preserving the sorted-hash order. Using append (O(1) on ListBuffer) avoids the
  // prepend-then-reverse dance; the dead `else None` (discarded in a Unit foreach) is removed.
  val buf = new ListBuffer[PostingsEnum]()
  hashes.sorted.foreach { h =>
    // seekExact positions termsEnum on the hash if present; hashes absent from this segment are skipped.
    if (termsEnum.seekExact(new BytesRef(h.hash))) buf.append(termsEnum.postings(null, PostingsEnum.NONE))
  }
  buf.toList
}

override def score(docId: Int, subQueryScore: Float): Double = {
val intersection = postings.count { p => p.docID() != DocIdSetIterator.NO_MORE_DOCS && p.advance(docId) == docId }
simFunc.maxScore * (intersection * 1d / hashes.length)
Expand Down Expand Up @@ -84,8 +90,11 @@ object HashingQuery {
fieldType: FieldType,
vec: V,
hashes: Array[HashAndFreq]
): Seq[IndexableField] = {
  // Build the result without Scala Array combinators (which need ClassTag and hence
  // the getClassLoader permission). Use append so the hash fields keep their original
  // order — prepending inside the loop would reverse them relative to the old behavior.
  val buffer = ListBuffer.empty[IndexableField]
  hashes.foreach { h =>
    // One Field instance per occurrence; h.freq controls how many times the hash is indexed.
    (0 until h.freq).foreach(_ => buffer.append(new Field(field, h.hash, fieldType)))
  }
  // The exact stored-vector field goes first, matching the pre-change ordering.
  buffer.prepend(ExactQuery.index(field, vec))
  buffer.toList
}
}

0 comments on commit fb5e474

Please sign in to comment.