Skip to content

Commit

Permalink
Some progress on the tests
Browse files Browse the repository at this point in the history
  • Loading branch information
alexklibisz committed Aug 28, 2024
1 parent 56348f2 commit a0b0040
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,8 @@ public DocIdSetIterator docIdSetIterator(int candidates) {

KthGreatestResult kgr = kthGreatest(candidates);

System.out.printf("kth greatest = %d\n", kgr.kthGreatest);

// Return an iterator over the doc ids >= the min candidate count.
return new DocIdSetIterator() {

Expand Down Expand Up @@ -122,6 +124,7 @@ public int nextDoc() {
return docID;
} else {
docID++;
System.out.printf("docID=%d, counts[%d]=%d\n", docID, docID, counts[docID]);
if (counts[docID] > kgr.kthGreatest) {
numEmitted++;
return docID;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,28 @@ import org.apache.lucene.search.DocIdSetIterator
import org.scalatest.freespec.AnyFreeSpec
import org.scalatest.matchers.should.Matchers

import scala.collection.mutable.ArrayBuffer
import scala.util.Random

final class ArrayHitCounterSpec extends AnyFreeSpec with Matchers {

/**
 * A simple [[DocIdSetIterator]] backed by an in-memory array of doc IDs, used as the
 * reference implementation in these tests.
 *
 * Doc IDs are emitted in array order. NOTE(review): Lucene iterators are expected to emit
 * strictly increasing doc IDs, so `docIds` is assumed to be sorted ascending -- the visible
 * caller sorts before constructing this class; confirm for any new caller.
 *
 * @param docIds the doc IDs to iterate over
 */
final class ReferenceDocIdSetIterator(docIds: Array[Int]) extends DocIdSetIterator {

  // Position in docIds; -1 means iteration has not started yet.
  private var currentDocIdIndex = -1

  /**
   * @return -1 before the first call to nextDoc()/advance(), the current doc ID while
   *         iterating, and NO_MORE_DOCS once the array is exhausted.
   */
  override def docID(): Int =
    if (currentDocIdIndex < 0) -1 // Not positioned yet; indexing docIds(-1) would throw.
    else if (currentDocIdIndex < docIds.length) docIds(currentDocIdIndex)
    else DocIdSetIterator.NO_MORE_DOCS

  /** Moves to the next doc ID and returns it, or NO_MORE_DOCS when exhausted. */
  override def nextDoc(): Int = {
    // Stop incrementing once past the end so repeated calls cannot grow the index unboundedly.
    if (currentDocIdIndex < docIds.length) currentDocIdIndex += 1
    docID()
  }

  /**
   * Steps forward until the current doc ID is >= target (or the iterator is exhausted).
   * Does not move if the current doc ID already satisfies the target.
   */
  override def advance(target: Int): Int = {
    while (docID() < target) nextDoc()
    docID()
  }

  /** Upper bound on the number of docs this iterator can emit. */
  override def cost(): Long = docIds.length
}

final class Reference(referenceCapacity: Int) extends HitCounter {
private val counts = scala.collection.mutable.Map[Int, Short](
(0 until referenceCapacity).map(_ -> 0.toShort): _*
)
private val counts = scala.collection.mutable.Map[Int, Short]().withDefaultValue(0)

override def increment(key: Int): Unit = counts.update(key, (counts(key) + 1).toShort)

Expand All @@ -21,7 +35,25 @@ final class ArrayHitCounterSpec extends AnyFreeSpec with Matchers {

override def capacity(): Int = this.referenceCapacity

override def docIdSetIterator(k: Int): DocIdSetIterator = DocIdSetIterator.empty()
/**
 * A very naive/inefficient reference implementation of the top-k DocIdSetIterator.
 *
 * Selects every doc whose count is strictly greater than the k-th greatest count, then fills
 * the remaining slots (up to k in total) with docs whose count equals the k-th greatest,
 * preferring the smallest doc IDs. The selected doc IDs are emitted in ascending order.
 *
 * @param k the number of doc IDs to select; a non-positive k, or a counter with no
 *          recorded hits, yields an empty iterator
 * @return an iterator over at most k doc IDs, sorted ascending
 */
override def docIdSetIterator(k: Int): DocIdSetIterator =
  if (k <= 0 || counts.isEmpty) {
    // Nothing to select. Without this guard, take(k).last throws NoSuchElementException.
    new ReferenceDocIdSetIterator(Array.empty[Int])
  } else {
    val valuesSorted = counts.values.toArray.sorted(Ordering[Short].reverse)
    // When k exceeds the number of counted docs, fall back to the smallest count.
    val kthGreatest = valuesSorted(math.min(k, valuesSorted.length) - 1)
    val greaterDocIds = counts.iterator.collect {
      case (docId, count) if count > kthGreatest => docId
    }.toArray
    // Ties at the k-th greatest count are broken by ascending doc ID.
    val equalDocIds = counts.iterator.collect {
      case (docId, count) if count == kthGreatest => docId
    }.toArray.sorted.take(k - greaterDocIds.length)
    val selectedDocIds = (equalDocIds ++ greaterDocIds).sorted

    new ReferenceDocIdSetIterator(selectedDocIds)
  }
}

/**
 * Drains the given iterator, collecting every doc ID it reports until nextDoc() returns
 * NO_MORE_DOCS.
 *
 * @param disi the iterator to exhaust; it is advanced as a side effect
 * @return the doc IDs in the order they were emitted
 */
private def consumeDocIdSetIterator(disi: DocIdSetIterator): List[Int] = {
  val collected = ArrayBuffer.empty[Int]
  Iterator
    .continually(disi.nextDoc())
    .takeWhile(_ != DocIdSetIterator.NO_MORE_DOCS)
    // Record what docID() reports after each successful step, as the loop form did.
    .foreach(_ => collected += disi.docID())
  collected.toList
}

"reference examples" - {
Expand All @@ -33,27 +65,26 @@ final class ArrayHitCounterSpec extends AnyFreeSpec with Matchers {
c.increment(0)
c.get(0) shouldBe 1

c.get(5) shouldBe 0
c.increment(5, 5)
c.get(5) shouldBe 5
c.get(1) shouldBe 0
c.increment(1, 5)
c.get(1) shouldBe 5

c.get(9) shouldBe 0
c.increment(9)
c.get(9) shouldBe 1
c.increment(9)
c.get(9) shouldBe 2
c.get(2) shouldBe 0
c.increment(2)
c.get(2) shouldBe 1
c.increment(2)
c.get(2) shouldBe 2

// val kgr = c.kthGreatest(2)
// kgr.kthGreatest shouldBe 1
// kgr.numGreaterThan shouldBe 2
// kgr.numNonZero shouldBe 3
// The k=2 most frequent doc IDs are 1 and 2.
val docIds = consumeDocIdSetIterator(c.docIdSetIterator(2))
docIds shouldBe List(1, 2)
}
}

"randomized comparison to reference" in {
val seed = System.currentTimeMillis()
val seed = 0L // System.currentTimeMillis()
val rng = new Random(seed)
val numDocs = 60000
val numDocs = 10
val numMatches = numDocs / 2
info(s"Using seed $seed")
for (_ <- 0 until 99) {
Expand All @@ -69,12 +100,16 @@ final class ArrayHitCounterSpec extends AnyFreeSpec with Matchers {
ahc.increment(doc, count)
ahc.get(doc) shouldBe ref.get(doc)
}
// ahc.minKey() shouldBe ref.minKey()
// ahc.maxKey() shouldBe ref.maxKey()
// val k = rng.nextInt(numDocs)
// val ahcKgr = ahc.kthGreatest(k)
// val refKgr = ref.kthGreatest(k)
// ahcKgr shouldBe refKgr
val k = rng.nextInt(numDocs)
val actualDocIds = consumeDocIdSetIterator(ahc.docIdSetIterator(k))
val referenceDocIds = consumeDocIdSetIterator(ref.docIdSetIterator(k))

println(k)
println((actualDocIds.length, actualDocIds))
println((referenceDocIds.length, referenceDocIds))
println("---")

referenceDocIds shouldBe actualDocIds
}
}
}

0 comments on commit a0b0040

Please sign in to comment.