levenshtein blanket
breandan committed Jan 3, 2024
1 parent 7cd51c8 commit 683b86f
Showing 6 changed files with 189 additions and 19 deletions.
@@ -309,4 +309,38 @@ fun CFG.fastRepairSeq(tokens: List<String>, spacing: Int = 2, holes: Int = 6): S
// ifEmpty {...} is a hack to ensure the sequence emits values at a steady frequency
.flatMap { enumSWOR(it).take(100).ifEmpty { sequenceOf("ε") } }
.map { it.removeEpsilon() }
}.map { if (it.isEmpty()) it else minimizeFix(tokens, it.tokenizeByWhitespace()) { this in language } }

fun CFG.fasterRepairSeq(tokens: List<String>, spacing: Int = 2, holes: Int = 6): Sequence<String> {
var levenshteinBlanket = tokens
var blanketSeq = emptySequence<String>().iterator()
val uniformSeq = tokens.intersperse(spacing, "ε").let { prompt ->
prompt.indices.toSet().choose(minOf(holes, prompt.size - 1))
.map { prompt.substituteIndices(it) { _, _ -> "_" } }
// ifEmpty {...} is a hack to ensure the sequence emits values at a steady frequency
.flatMap { enumSeq(it).take(100).ifEmpty { sequenceOf("ε") } }
}.iterator()

return generateSequence {
if (blanketSeq.hasNext() && Random.nextBoolean()) blanketSeq.next()//.also { println("Blanket: $it") }
else if (uniformSeq.hasNext()) uniformSeq.next()//.also { println("Uniform: $it") }
else null
}.map { it.removeEpsilon() }
.filter { it.isNotEmpty() }
.distinct()
.map { minimizeFix(tokens, it.tokenizeByWhitespace()) { this in language } }
.distinct()
.onEach {
val newBlanket = updateLevenshteinBlanket(levenshteinBlanket, it.tokenizeByWhitespace())
if (newBlanket != levenshteinBlanket && "_" in newBlanket) {
levenshteinBlanket = newBlanket
blanketSeq = enumSeqSmart(levenshteinBlanket).iterator()
println("New blanket: ${levenshteinBlanket.joinToString(" ")}")
}
}
}

fun updateLevenshteinBlanket(oldBlanket: List<String>, newRepair: List<String>) =
levenshteinAlign(oldBlanket, newRepair).map { (old, new) ->
if (old == null || new == null || old != new) "_" else old
}
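
Aside: the blanket is effectively an agreement mask over every valid repair seen so far. Tokens on which all repairs agree stay pinned; any position where two repairs disagree becomes a hole ("_") that enumSeqSmart is free to resample. Below is a minimal self-contained sketch of that masking step, using a positional zip over equal-length repairs as a stand-in for levenshteinAlign (which also handles insertions and deletions by emitting null on one side):

// Illustrative stand-in: positional comparison instead of a true Levenshtein alignment.
fun maskDisagreements(blanket: List<String>, repair: List<String>): List<String> =
  blanket.zip(repair).map { (old, new) -> if (old == new) old else "_" }

fun main() {
  var blanket = listOf("NAME", "=", "NAME", "(", "NUMBER", "NEWLINE")
  // Two hypothetical valid repairs of the same broken snippet
  val repairs = listOf(
    listOf("NAME", "=", "NAME", "(", "NUMBER", ")"),
    listOf("NAME", "=", "NAME", "(", "STRING", ")")
  )
  for (repair in repairs) {
    blanket = maskDisagreements(blanket, repair)
    println(blanket.joinToString(" "))
  }
  // Prints:
  //   NAME = NAME ( NUMBER _
  //   NAME = NAME ( _ _
  // Only positions every repair agrees on stay fixed; the holes mark where
  // fasterRepairSeq should concentrate its search.
}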
@@ -1,3 +1,19 @@
package ai.hypergraph.kaliningraph.repair

import kotlin.math.pow

val contextCSV by lazy { pythonContext.lines().readContextCSV() }

fun List<String>.readContextCSV(diversity: Double = 1.0) =
drop(1).map { it.split(", ") }.associate {
ContextEdit(
type = EditType.valueOf(it[0].trim()),
context = Context(it[1], it[2], it[3]),
newMid = it[4]
) to it[5].trim().toDouble().pow(diversity).toInt().coerceAtLeast(1)
}.let { CEADist(it) }

val pythonContext = """
Type , Left , Old Mid , Right , New Mid , Frequency
INS , NAME , , NAME , '(' , 1293
INS , NEWLINE , , NAME , 99 , 1212
@@ -5721,4 +5737,5 @@ DEL , 98 , NAME , ',' , , 1
DEL , 98 , ',' , NEWLINE , , 1
INS , NAME , , '(' , ',' , 1
DEL , NAME , 'await' , NAME , , 1
INS , '+' , , ':' , NUMBER , 1
""".trimIndent()
@@ -1,5 +1,8 @@
package ai.hypergraph.kaliningraph.repair

import ai.hypergraph.kaliningraph.parsing.*
import kotlin.random.Random

enum class EditType { INS, DEL, SUB }
data class ContextEdit(val type: EditType, val context: Context, val newMid: String) {
override fun toString(): String = context.run {
@@ -14,6 +17,7 @@ data class ContextEdit(val type: EditType, val context: Context, val newMid: Str
} + " ))"
}
}

data class CEAProb(val cea: ContextEdit, val idx: Int, val frequency: Int) {
override fun equals(other: Any?): Boolean = when (other) {
is CEAProb -> cea == other.cea && idx == other.idx
@@ -22,6 +26,7 @@ data class CEAProb(val cea: ContextEdit, val idx: Int, val frequency: Int) {
override fun hashCode(): Int = 31 * cea.hashCode() + idx
override fun toString(): String = "[[ $cea, $idx, $frequency ]]"
}

data class Context(val left: String, val mid: String, val right: String) {
override fun equals(other: Any?) = when (other) {
is Context -> left == other.left && mid == other.mid && right == other.right
@@ -41,4 +46,105 @@ data class CEADist(val allProbs: Map<ContextEdit, Int>) {
val P_insert = allProbs.filter { it.key.type == EditType.INS }
val P_delSubOnCtx = P_delSub.keys.groupBy { it.context }
val P_insertOnCtx = P_insert.keys.groupBy { it.context }
}
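
CEADist just partitions the raw counts into insertions versus deletions/substitutions and indexes both partitions by context, which is what lets relevantEditActions below do per-window map lookups. A small usage sketch with invented counts:

// Hypothetical counts, purely to show the lookup pattern.
val dist = CEADist(mapOf(
  ContextEdit(EditType.INS, Context("NAME", "", "NAME"), "'('") to 1293,
  ContextEdit(EditType.DEL, Context("98", "NAME", "','"), "") to 1
))
// All insertions recorded between two adjacent NAME tokens, with their frequencies:
dist.P_insertOnCtx[Context("NAME", "", "NAME")]?.forEach { println("$it -> ${dist.P_insert[it]}") }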

fun CFG.contextualRepair(broken: List<String>): Sequence<List<String>> {
val initREAs: List<CEAProb> = contextCSV.relevantEditActions(broken)
// Bonuses for previously sampled edits that produced a valid repair
val bonusProbs = mutableMapOf<ContextEdit, Int>()

// println("Total relevant edit actions: ${initREAs.size}\n${initREAs.take(5).joinToString("\n")}\n...")
val samplerTimeout = 10000L
var (total, uniqueValid) = 0 to 0

return generateSequence { broken }.map {
try { it.sampleEditTrajectoryV0(contextCSV, initREAs, bonusProbs) }
catch (e: Exception) {
println(broken.joinToString(" ")); e.printStackTrace(); listOf<String>() to listOf()
}
}.mapNotNull { (finalSeq, edits) ->
if (finalSeq in language) {
edits.forEach { bonusProbs[it.cea] = (bonusProbs[it.cea] ?: 0) + 1 }

uniqueValid++
finalSeq
}
else null
}.distinct()
}

fun List<String>.sampleEditTrajectoryV0(
ceaDist: CEADist,
initREAs: List<CEAProb>,
// Bonuses for previously sampled edits that produced a valid repair
bonusProbs: Map<ContextEdit, Int>? = null,
lengthCDF: List<Double> = listOf(0.5, 0.8, 1.0)
): Pair<List<String>, List<CEAProb>> {
// First sample the length of the edit trajectory from the length distribution
val rand = Random.nextDouble()
val length = lengthCDF.indexOfFirst { rand < it } + 1

if (initREAs.isEmpty()) return this to listOf()
val ceaProbs = mutableListOf<CEAProb>()
// Now sample an edit trajectory of that length from the edit distribution
var listPrime =
initREAs.normalizeAndSampleV0(bonusProbs)
.also { ceaProbs.add(it) }
.let { applyEditAction(it.cea, it.idx + 1) }

for (i in 1..length) {
val relevantEditActions = ceaDist.relevantEditActions(listPrime)
if (relevantEditActions.isEmpty()) {
// println("$i-th iteration, no relevant edit actions for: ${listPrime.joinToString(" ") { it.toPyRuleName() }}")
return listPrime to ceaProbs
}
val sampledEdit = relevantEditActions.normalizeAndSampleV0(bonusProbs)
.also { ceaProbs.add(it) }
listPrime = listPrime.applyEditAction(sampledEdit.cea, sampledEdit.idx + 1)
}
return listPrime to ceaProbs
}
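
The trajectory length is drawn by inverting the cumulative distribution lengthCDF: a uniform draw compared against [0.5, 0.8, 1.0] yields lengths 1, 2, 3 with probability 0.5, 0.3, 0.2. A quick standalone check of just that step:

import kotlin.random.Random

fun sampleLength(lengthCDF: List<Double> = listOf(0.5, 0.8, 1.0)): Int {
  val rand = Random.nextDouble()                    // uniform in [0, 1)
  return lengthCDF.indexOfFirst { rand < it } + 1   // e.g. rand = 0.63 -> index 1 -> length 2
}

fun main() {
  val counts = IntArray(4)
  repeat(100_000) { counts[sampleLength()]++ }
  // Expect roughly 50% / 30% / 20% for lengths 1 / 2 / 3.
  println(counts.drop(1).joinToString(" / "))
}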

// Faster than the above
fun List<CEAProb>.normalizeAndSampleV0(bonusProbs: Map<ContextEdit, Int>?): CEAProb {
val cdf: List<Int> = (if (bonusProbs == null) map { it.frequency }
else map { it.frequency + bonusProbs.getOrElse(it.cea) { 0 } * 100 })
.let { freqs ->
val cdf = mutableListOf<Int>()
var sum = 0
for (i in freqs.indices) {
sum += freqs[i]
cdf.add(sum)
}
cdf
}
val sample: Int = Random.nextInt(cdf.last())
return this[cdf.binarySearch(sample).let { if (it < 0) -it - 1 else it }.coerceIn(indices)]
}
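
In other words, sampling runs directly on unnormalized counts: build a prefix-sum CDF, draw a uniform integer below the total, and binary-search for its bucket; edits that already produced a valid repair get +100 weight per prior success via bonusProbs. A self-contained sketch of the same weighted draw (the edit labels and counts are made up):

import kotlin.random.Random

// Weighted sampling by prefix sums + binary search, mirroring normalizeAndSampleV0.
fun <T> sampleByWeight(items: List<T>, weights: List<Int>): T {
  val cdf = mutableListOf<Int>()
  var sum = 0
  for (w in weights) { sum += w; cdf.add(sum) }
  val draw = Random.nextInt(cdf.last())
  // binarySearch returns (-insertionPoint - 1) when the draw is not an exact cumulative
  // sum; negating it recovers the first bucket whose cumulative weight exceeds the draw.
  val idx = cdf.binarySearch(draw).let { if (it < 0) -it - 1 else it }
  return items[idx.coerceIn(items.indices)]
}

fun main() {
  val edits = listOf("INS '('", "INS ':'", "DEL NAME")
  val base = listOf(1293, 12, 40)
  val bonus = mapOf("INS ':'" to 2)  // hypothetical: this edit already yielded 2 valid repairs
  val boosted = edits.mapIndexed { i, e -> base[i] + (bonus[e] ?: 0) * 100 }
  println(sampleByWeight(edits, boosted))
}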

fun CEADist.relevantEditActions(snippet: List<String>): List<CEAProb> {
val relevantEditActions = mutableListOf<CEAProb>()
for (i in 0 until snippet.size - 2) {
val ctx = Context(snippet[i], snippet[i + 1], snippet[i + 2])
P_insertOnCtx[Context(ctx.left, "", ctx.mid)]?.forEach {
relevantEditActions.add(CEAProb(it, i, P_insert[it]!!))
}
if (i == snippet.size - 3)
P_insertOnCtx[Context(ctx.mid, "", ctx.right)]?.forEach {
relevantEditActions.add(CEAProb(it, i, P_insert[it]!!))
}
P_delSubOnCtx[ctx]?.forEach {
relevantEditActions.add(CEAProb(it, i, P_delSub[it]!!))
}
}
return relevantEditActions
}
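
relevantEditActions slides a trigram window over the snippet: insertions are looked up with an empty middle between adjacent tokens (plus one extra lookup at the right edge), while deletions and substitutions use the full trigram. A tiny sketch that just prints which contexts get consulted:

// Purely illustrative: show the contexts queried for a short snippet,
// matching the window logic of relevantEditActions above.
fun main() {
  val snippet = listOf("NAME", "=", "NAME", "NEWLINE")
  for (i in 0 until snippet.size - 2) {
    val (l, m, r) = Triple(snippet[i], snippet[i + 1], snippet[i + 2])
    println("INS lookup: ($l, '', $m)   DEL/SUB lookup: ($l, $m, $r)")
    if (i == snippet.size - 3) println("INS lookup: ($m, '', $r)   (final window, right edge)")
  }
}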

fun List<String>.applyEditAction(cea: ContextEdit, idx: Int): List<String> =
when (cea.type) { // 6409ms, 20%
EditType.INS -> subList(0, idx) + cea.newMid + subList(idx + 1, size) // 17937ms, 55%
EditType.DEL -> subList(0, idx) + subList(idx + 1, size) // 2607ms, 8%
EditType.SUB -> subList(0, idx) + cea.newMid + subList(idx + 1, size) // 5552ms, 17%
}//.also { println("Start:$this\n${cea.type}/${cea.context}/${cea.newMid}/${idx}\nAfter:$it") }

@@ -377,6 +377,34 @@ class BarHillelTest {
println("Enumerative solver took ${clock.elapsedNow().inWholeMilliseconds}ms")
}

/*
./gradlew jvmTest --tests "ai.hypergraph.kaliningraph.parsing.BarHillelTest.levenshteinBlanketTest"
*/
@Test
fun levenshteinBlanketTest() {
val gram = Grammars.seq2parsePythonCFG.noEpsilonOrNonterminalStubs
val origStr = "NAME = NAME . NAME ( [ NUMBER , NUMBER , NUMBER ] NEWLINE"
val toRepair = origStr.tokenizeByWhitespace()
val levDist = 2
val levBall = makeLevFSA(toRepair, levDist, gram.terminals)
val clock = TimeSource.Monotonic.markNow()

val s2pg = Grammars.seq2parsePythonCFG
s2pg.fasterRepairSeq(toRepair, 1, 2)
.onEachIndexed { i, it ->
val levDistance = levenshtein(origStr, it)
if (levDistance <= levDist) {
println("Found ($levDistance): " + levenshteinAlign(origStr, it).paintANSIColors())
assertTrue(it in s2pg.language)
assertTrue(levBall.recognizes(it))
}
}.takeWhile { clock.elapsedNow().inWholeSeconds < 30 }.toList()
.also { println("Found ${it.size} minimal solutions using " +
"Probabilistic repair in ${clock.elapsedNow()}") }

println("Enumerative solver took ${clock.elapsedNow().inWholeMilliseconds}ms")
}

/*
./gradlew jvmTest --tests "ai.hypergraph.kaliningraph.parsing.BarHillelTest.testHammingBallRepair"
*/

This file was deleted.

@@ -12,7 +12,6 @@ import kotlin.time.TimeSource
./gradlew jvmTest --tests "ai.hypergraph.kaliningraph.repair.ProbabilisticLBH"
*/
class ProbabilisticLBH {
val ceaDist by lazy { File("src/jvmTest/resources/context_edits.csv").readTrigramStats() }
/*
./gradlew jvmTest --tests "ai.hypergraph.kaliningraph.repair.ProbabilisticLBH.testProbabilisticLBH"
*/
@@ -83,7 +82,7 @@ class ProbabilisticLBH {
}

val topTerms by lazy {
ceaDist.allProbs.entries
contextCSV.allProbs.entries
.filter { it.key.type != EditType.DEL }
.groupingBy { Grammars.seq2parsePythonCFG.getS2PNT(it.key.newMid) }
.aggregate { _, acc: Int?, it, _ -> (acc ?: 0) + it.value }
