From a0b0040d11b24f9ad7f0ea3fe68eaa08eebaf0f4 Mon Sep 17 00:00:00 2001 From: Alex Klibisz Date: Wed, 28 Aug 2024 10:10:00 -0700 Subject: [PATCH] Some progress on the tests --- .../elastiknn/search/ArrayHitCounter.java | 3 + .../search/ArrayHitCounterSpec.scala | 83 +++++++++++++------ 2 files changed, 62 insertions(+), 24 deletions(-) diff --git a/elastiknn-lucene/src/main/java/com/klibisz/elastiknn/search/ArrayHitCounter.java b/elastiknn-lucene/src/main/java/com/klibisz/elastiknn/search/ArrayHitCounter.java index 827a80f7..8bfb3440 100644 --- a/elastiknn-lucene/src/main/java/com/klibisz/elastiknn/search/ArrayHitCounter.java +++ b/elastiknn-lucene/src/main/java/com/klibisz/elastiknn/search/ArrayHitCounter.java @@ -90,6 +90,8 @@ public DocIdSetIterator docIdSetIterator(int candidates) { KthGreatestResult kgr = kthGreatest(candidates); + System.out.printf("kth greatest = %d\n", kgr.kthGreatest); + // Return an iterator over the doc ids >= the min candidate count. return new DocIdSetIterator() { @@ -122,6 +124,7 @@ public int nextDoc() { return docID; } else { docID++; + System.out.printf("docID=%d, counts[%d]=%d\n", docID, docID, counts[docID]); if (counts[docID] > kgr.kthGreatest) { numEmitted++; return docID; diff --git a/elastiknn-lucene/src/test/scala/com/klibisz/elastiknn/search/ArrayHitCounterSpec.scala b/elastiknn-lucene/src/test/scala/com/klibisz/elastiknn/search/ArrayHitCounterSpec.scala index 14df616b..3effe98c 100644 --- a/elastiknn-lucene/src/test/scala/com/klibisz/elastiknn/search/ArrayHitCounterSpec.scala +++ b/elastiknn-lucene/src/test/scala/com/klibisz/elastiknn/search/ArrayHitCounterSpec.scala @@ -4,14 +4,28 @@ import org.apache.lucene.search.DocIdSetIterator import org.scalatest.freespec.AnyFreeSpec import org.scalatest.matchers.should.Matchers +import scala.collection.mutable.ArrayBuffer import scala.util.Random final class ArrayHitCounterSpec extends AnyFreeSpec with Matchers { + final class ReferenceDocIdSetIterator(docIds: Array[Int]) extends DocIdSetIterator { + + private var currentDocIdIndex = -1; + override def docID(): Int = if (currentDocIdIndex < docIds.length) docIds(currentDocIdIndex) else DocIdSetIterator.NO_MORE_DOCS + override def nextDoc(): Int = { + currentDocIdIndex += 1 + docID() + } + override def advance(target: Int): Int = { + while (docID() < target) nextDoc() + docID() + } + override def cost(): Long = docIds.length + } + final class Reference(referenceCapacity: Int) extends HitCounter { - private val counts = scala.collection.mutable.Map[Int, Short]( - (0 until referenceCapacity).map(_ -> 0.toShort): _* - ) + private val counts = scala.collection.mutable.Map[Int, Short]().withDefaultValue(0) override def increment(key: Int): Unit = counts.update(key, (counts(key) + 1).toShort) @@ -21,7 +35,25 @@ final class ArrayHitCounterSpec extends AnyFreeSpec with Matchers { override def capacity(): Int = this.referenceCapacity - override def docIdSetIterator(k: Int): DocIdSetIterator = DocIdSetIterator.empty() + override def docIdSetIterator(k: Int): DocIdSetIterator = { + // A very naive/inefficient way to implement the DocIdSetIterator. + val valuesSorted = counts.values.toArray.sorted.reverse + val kthGreatest = valuesSorted.take(k).last + val greaterDocIds = counts.filter(_._2 > kthGreatest).keys.toArray + val equalDocIds = counts.filter(_._2 == kthGreatest).keys.toArray.sorted.take(k - greaterDocIds.length) + val selectedDocIds = (equalDocIds ++ greaterDocIds).sorted + println(counts.toList.sorted) + + new ReferenceDocIdSetIterator(selectedDocIds) + } + } + + private def consumeDocIdSetIterator(disi: DocIdSetIterator): List[Int] = { + val docIds = new ArrayBuffer[Int] + while (disi.nextDoc() != DocIdSetIterator.NO_MORE_DOCS) { + docIds.append(disi.docID()) + } + docIds.toList } "reference examples" - { @@ -33,27 +65,26 @@ final class ArrayHitCounterSpec extends AnyFreeSpec with Matchers { c.increment(0) c.get(0) shouldBe 1 - c.get(5) shouldBe 0 - c.increment(5, 5) - c.get(5) shouldBe 5 + c.get(1) shouldBe 0 + c.increment(1, 5) + c.get(1) shouldBe 5 - c.get(9) shouldBe 0 - c.increment(9) - c.get(9) shouldBe 1 - c.increment(9) - c.get(9) shouldBe 2 + c.get(2) shouldBe 0 + c.increment(2) + c.get(2) shouldBe 1 + c.increment(2) + c.get(2) shouldBe 2 -// val kgr = c.kthGreatest(2) -// kgr.kthGreatest shouldBe 1 -// kgr.numGreaterThan shouldBe 2 -// kgr.numNonZero shouldBe 3 + // The k=2 most frequent doc IDs are 1 and 2. + val docIds = consumeDocIdSetIterator(c.docIdSetIterator(2)) + docIds shouldBe List(1, 2) } } "randomized comparison to reference" in { - val seed = System.currentTimeMillis() + val seed = 0L // System.currentTimeMillis() val rng = new Random(seed) - val numDocs = 60000 + val numDocs = 10 val numMatches = numDocs / 2 info(s"Using seed $seed") for (_ <- 0 until 99) { @@ -69,12 +100,16 @@ final class ArrayHitCounterSpec extends AnyFreeSpec with Matchers { ahc.increment(doc, count) ahc.get(doc) shouldBe ref.get(doc) } -// ahc.minKey() shouldBe ref.minKey() -// ahc.maxKey() shouldBe ref.maxKey() -// val k = rng.nextInt(numDocs) -// val ahcKgr = ahc.kthGreatest(k) -// val refKgr = ref.kthGreatest(k) -// ahcKgr shouldBe refKgr + val k = rng.nextInt(numDocs) + val actualDocIds = consumeDocIdSetIterator(ahc.docIdSetIterator(k)) + val referenceDocIds = consumeDocIdSetIterator(ref.docIdSetIterator(k)) + + println(k) + println((actualDocIds.length, actualDocIds)) + println((referenceDocIds.length, referenceDocIds)) + println("---") + + referenceDocIds shouldBe actualDocIds } } }