Skip to content

Commit

Permalink
minimize LBH repairs
Browse files Browse the repository at this point in the history
  • Loading branch information
breandan committed Dec 29, 2023
1 parent 6d9ee37 commit 2e5f2bb
Show file tree
Hide file tree
Showing 3 changed files with 167 additions and 43 deletions.
121 changes: 121 additions & 0 deletions src/commonMain/kotlin/ai/hypergraph/kaliningraph/repair/PatchUtils.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
package ai.hypergraph.kaliningraph.repair

import ai.hypergraph.kaliningraph.parsing.*
import ai.hypergraph.kaliningraph.sampling.choose
import ai.hypergraph.kaliningraph.types.*
import kotlin.math.*
import kotlin.time.TimeSource

/** The six single-character bracket tokens: round, square and curly, opening and closing. */
val COMMON_BRACKETS = setOf("(", ")", "[", "]", "{", "}")
/**
 * Splits the receiver at the zero-width positions directly before and directly after
 * every bracket character (`(`, `)`, `[`, `]`, `{`, `}`) and every `___` placeholder,
 * so that the delimiters are kept as tokens instead of being discarded.
 *
 * The boundaries are expressed with a lookbehind/lookahead pair, i.e. no characters
 * of the input are consumed by the split.
 */
fun Σᐩ.defaultTokenizer(): List<Σᐩ> {
  val delimiter = "[\\(\\)\\[\\]{}]|___"
  val boundary = Regex("((?<=($delimiter))|(?=($delimiter)))")
  return split(boundary)
}

/**
 * Shrinks a repair of [broke] to a smaller set of token edits that still passes [isValid].
 *
 * Both strings are tokenized with [tokenize], aligned into a [Patch] via [extractPatch], and the
 * changed positions of that patch are reduced with [deltaDebug]. Reduction stops early once more
 * than five whole seconds have elapsed since it started.
 *
 * @param broke the string to be repaired.
 * @param tokenize tokenizer applied to both [broke] and [fixed].
 * @param fixed a repaired version of [broke].
 * @param separator string placed between tokens when a candidate is rendered for [isValid] and
 *   when the minimized fix is rendered.
 * @param isValid acceptance test for a rendered candidate.
 * @return the triple ([broke], tokens of [fixed] joined with the empty string, minimized fix).
 */
fun minimizeFix(
  broke: Σᐩ,
  tokenize: Σᐩ.() -> List<Σᐩ>,
  fixed: Σᐩ,
  separator: Σᐩ = "",
  isValid: Σᐩ.() -> Boolean
): Π3A<Σᐩ> {
  val brokeTokens = broke.tokenize()
  val fixedTokens = fixed.tokenize()
  val patch: Patch = extractPatch(brokeTokens, fixedTokens)

  val clock = TimeSource.Monotonic.markNow()
  val keptEdits = deltaDebug(
    patch.changedIndices(),
    timeout = { clock.elapsedNow().inWholeSeconds > 5 }
  ) { candidate -> patch.apply(candidate, separator).isValid() }
  // A follow-up exhaustive pass (minimalSubpatch over keptEdits) is currently disabled.

  // NOTE(review): the original fix is joined with "" while the minimized fix uses [separator];
  // confirm this asymmetry is intended before relying on the second component.
  val originalFix = fixedTokens.joinToString("")
  val minimizedFix = patch.apply(keptEdits, separator)

  return broke to originalFix to minimizedFix
}

/** A single token-level edit, stored as the pair (old token, new token). */
typealias Edit = Π2A<Σᐩ>

/** An ordered list of [Edit]s. */
typealias Patch = List<Edit>

/** The token before the edit; the empty string marks an insertion. */
val Edit.old: Σᐩ
  get() = this.first

/** The token after the edit; the empty string marks a deletion. */
val Edit.new: Σᐩ
  get() = this.second

/**
 * Tells whether the patch mixes edit kinds.
 *
 * Every edit whose old and new token differ is classified as a deletion ("D", new token empty),
 * an insertion ("I", old token empty) or a change ("C", both non-empty).
 *
 * @return `true` when at least two of the three kinds occur in this patch.
 */
fun Patch.isInteresting(): Boolean =
  // Filter the edits directly instead of testing each index against a list of changed indices,
  // which needed a linear scan per element.
  filter { (before, after) -> before != after }
    .map { (before, after) ->
      when {
        after == "" -> "D"
        before == "" -> "I"
        else -> "C"
      }
    }
    .toSet().size > 1
/** Positions, in ascending order, of the edits whose old and new token differ. */
fun Patch.changedIndices(): List<Int> =
  mapIndexedNotNull { position, edit -> if (edit.old != edit.new) position else null }

/**
 * Walks away from position [i] and returns the first non-empty token selected by [age].
 *
 * @param i the starting position; it is not inspected itself.
 * @param direction `true` scans rightwards (`i + 1` up to the last edit), `false` scans
 *   leftwards (`i - 1` down to `0`).
 * @param age selects which side of an [Edit] to look at.
 * @return the first non-empty selected token, or `null` when there is none in that direction.
 */
fun Patch.scan(i: Int, direction: Boolean, age: Edit.() -> Σᐩ): Σᐩ? {
  val positions: IntProgression = if (direction) i + 1 until size else i - 1 downTo 0
  val hit = positions.firstOrNull { this[it].age() != "" } ?: return null
  return this[hit].age()
}

// [s]can [l]eft / [r]ight of position i for the nearest non-empty [n]ew / [o]ld token.
// Each accessor throws NullPointerException when no such token exists in that direction.
fun Patch.sln(i: Int): String = scan(i, direction = false) { new }!!
fun Patch.srn(i: Int): String = scan(i, direction = true) { new }!!
fun Patch.slo(i: Int): String = scan(i, direction = false) { old }!!
fun Patch.sro(i: Int): String = scan(i, direction = true) { old }!!

/**
 * Sums `levenshtein(old, new)` over every edit whose two tokens differ; edits with equal
 * tokens contribute nothing.
 */
fun Patch.totalCharacterwiseEditDistance(): Int {
  var total = 0
  for ((before, after) in this) {
    if (before != after) total += levenshtein(before, after)
  }
  return total
}

/**
 * Searches for a smallest selection drawn from the receiver that is accepted by [filter].
 *
 * Candidate sizes `1..size` are tried in increasing order. For each size the candidates are
 * produced by `choose(size)` (presumably the subsets of the receiver with that many elements --
 * TODO confirm against `sampling.choose`) and converted to lists. The search stops at the first
 * size for which any candidate is accepted.
 *
 * @param filter predicate invoked with a candidate list as its receiver.
 * @return the first accepted candidate of the smallest accepted size, or the receiver itself
 *   when no candidate of any size is accepted.
 */
fun List<Int>.minimalSubpatch(filter: List<Int>.() -> Boolean): List<Int> =
(1..size).asSequence().map { choose(it).map { it.toList() } }
.map { it.filter { it.filter() } }.firstOrNull { it.any() }?.firstOrNull() ?: this

/**
 * Renders the patch as a string, taking the new token at every position listed in [indices]
 * and the old token everywhere else.
 *
 * @param indices positions of the edits to apply.
 * @param separator string placed between consecutive tokens.
 * @return the selected tokens joined with [separator].
 */
fun Patch.apply(indices: List<Int>, separator: Σᐩ = ""): Σᐩ {
  // Hash the indices once: this function is called for every delta-debugging probe, and a
  // membership test against the list would be a linear scan per token.
  val selected = indices.toHashSet()
  // NOTE(review): empty tokens (unapplied insertions, applied deletions) are joined like any
  // other token, so a non-empty separator can appear twice in a row -- confirm callers expect this.
  return mapIndexed { position, edit -> if (position in selected) edit.new else edit.old }
    .joinToString(separator)
}

/**
 * Builds a [Patch] from the alignment of two token lists produced by `levenshteinAlign`.
 *
 * A side that is absent (`null`) in an alignment pair is represented by the empty string, so an
 * insertion becomes `"" to token` and a deletion becomes `token to ""`.
 *
 * @param original the tokens before the change.
 * @param new the tokens after the change.
 * @return one [Edit] per alignment pair, in alignment order.
 */
fun extractPatch(original: List<Σᐩ>, new: List<Σᐩ>): Patch =
  // Elvis instead of `!!`: a pair with both sides absent now yields a no-op edit rather than a
  // NullPointerException, and the lambda no longer shadows the `new` parameter.
  levenshteinAlign(original, new).map { (before, after) -> (before ?: "") to (after ?: "") }

/**
 * Delta-debugging style reduction of [elements] to a smaller list that still satisfies
 * [checkValid].
 *
 * The list is cut into roughly [n] equally sized chunks. For each chunk, first its complement
 * and then the chunk itself is tested; the first candidate that passes is reduced recursively,
 * restarting at granularity 2. When nothing passes, the granularity is doubled (capped at one
 * element per chunk) until single elements have been tried.
 *
 * @param elements the list to reduce; the caller is expected to know that it is valid.
 * @param n the number of chunks to start with.
 * @param timeout polled before each round of checks; once it returns `true` the current list is
 *   returned as is.
 * @param checkValid decides whether a candidate sub-list is acceptable. It is assumed to be
 *   deterministic: a candidate is not re-tested within the same round.
 * @return a sub-list of [elements], in original order. Lists with fewer than [n] elements are
 *   returned unchanged, so in particular the empty list is never tried as a candidate.
 */
fun <T> deltaDebug(
  elements: List<T>,
  n: Int = 2,
  timeout: () -> Boolean,
  checkValid: (List<T>) -> Boolean
): List<T> {
  // If the granularity exceeds the number of elements we are finished: return the input as is.
  if (elements.size < n || timeout()) return elements

  // Cut the elements into about n equal chunks; the last chunk may be shorter.
  val chunkSize = (elements.size.toDouble() / n).roundToInt()
  val chunks = elements.chunked(chunkSize)

  // With exactly two chunks each one is the complement of the other, so a second iteration
  // would only repeat the two (potentially expensive) checks of the first.
  val rounds = if (chunks.size == 2) 1 else chunks.size

  for (index in 0 until rounds) {
    if (timeout()) break
    val chunk = chunks[index]
    val complement = elements.subList(0, index * chunkSize) +
      elements.subList(min((index + 1) * chunkSize, elements.size), elements.size)

    // Try the complement first, on the theory that the needed elements sit closer to the end.
    if (checkValid(complement)) return deltaDebug(complement, 2, timeout, checkValid)

    // Then check whether this chunk alone is enough.
    if (checkValid(chunk)) return deltaDebug(chunk, 2, timeout, checkValid)
  }

  // One element per chunk was already tried: the list cannot be reduced any further.
  if (elements.size == n) return elements

  // Neither a chunk nor a complement worked: increase the granularity and try again.
  val finer = if (elements.size < n * 2) elements.size else n * 2
  return deltaDebug(elements, finer, timeout, checkValid)
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package ai.hypergraph.kaliningraph.parsing
import Grammars
import ai.hypergraph.kaliningraph.*
import ai.hypergraph.kaliningraph.automata.parseFSA
import ai.hypergraph.kaliningraph.repair.*
import ai.hypergraph.kaliningraph.sampling.all
import kotlin.test.*
import kotlin.time.*
Expand Down Expand Up @@ -207,15 +208,23 @@ class BarHillelTest {
val intGram = gram.intersectLevFSA(levBall)
val clock = TimeSource.Monotonic.markNow()
val template = List(tokens.size + levDist) { "_" }
intGram.enumSeq(template).distinct().onEach {
val levAlign = levenshteinAlign(origStr, it).paintANSIColors()
val actDist= levenshtein(origStr, it)
println(levAlign)
intGram.enumSeq(template)
.distinct().map {
minimizeFix(origStr, { tokenizeByWhitespace() }, it,
" ", { this in gram.language })
}.distinctBy { it.third }
.onEachIndexed { i, (_, _, it) ->
val levAlign = levenshteinAlign(origStr, it).paintANSIColors()
val actDist= levenshtein(origStr, it)
println(levAlign)

assertTrue(it in gram.language)
assertTrue(levBall.recognizes(it))
assertTrue(actDist <= levDist)
}.toList().also { println("Found ${it.size} solutions using Levenshtein/Bar-Hillel in ${clock.elapsedNow()}") }
assertTrue(it in gram.language)
assertTrue(levBall.recognizes(it))
assertTrue(actDist <= levDist)
}.toList()
// Found 221 minimal solutions using Levenshtein/Bar-Hillel in 23.28s
.also { println("Found ${it.size} minimal solutions using " +
"Levenshtein/Bar-Hillel in ${clock.elapsedNow()}") }
}

/*
Expand Down Expand Up @@ -243,7 +252,12 @@ class BarHillelTest {

val template = List(toRepair.size + levDist - 1) { "_" }

val lbhSet = intGram.enumSeq(template).distinct().onEachIndexed { i, it ->
val lbhSet = intGram.enumSeq(template)
.distinct().map {
minimizeFix(origStr, { tokenizeByWhitespace() }, it,
" ", { this in gram.language })
}.distinctBy { it.third }
.onEachIndexed { i, (_, _, it) ->
if (i < 100) {
val levAlign = levenshteinAlign(origStr, it).paintANSIColors()
println(levAlign)
Expand All @@ -258,11 +272,9 @@ class BarHillelTest {
assertTrue(levBall.recognizes(it))
}.toList()

// Found 19433 solutions using Levenshtein/Bar-Hillel
// Enumerative solver took 320485ms
// Found 6987 minimal solutions using Levenshtein/Bar-Hillel
// Enumerative solver took 360184ms

println("Found ${lbhSet.size} solutions using Levenshtein/Bar-Hillel")
println("Enumerative solver took ${clock.elapsedNow().inWholeMilliseconds}ms")

// val totalParticipatingNonterminals =
// lbhSet.map { intGram.parseTable(it).data.map { it.map { it.root } } }.flatten().flatten().toSet()
Expand All @@ -274,50 +286,40 @@ class BarHillelTest {
/*
./gradlew jvmTest --tests "ai.hypergraph.kaliningraph.parsing.BarHillelTest.semiRealisticTest"
*/
// @Test
@Test
fun semiRealisticTest() {
val gram = Grammars.seq2parsePythonCFG.noEpsilonOrNonterminalStubs
val toRepair = "NAME = NAME . NAME ( [ NUMBER , NUMBER , NUMBER ]".tokenizeByWhitespace()
val levBall = makeLevFSA(toRepair, 1, gram.terminals)
val origStr= "NAME = NAME . NAME ( [ NUMBER , NUMBER , NUMBER ] NEWLINE"
val toRepair = origStr.tokenizeByWhitespace()
val levDist = 2
val levBall = makeLevFSA(toRepair, levDist, gram.terminals)
// println(levBall.toDot())
// throw Exception("")
val intGram = gram.intersectLevFSA(levBall)
// val part= intGram.nonterminals.map { it.substringAfter(',')
// .substringBefore(',') }.toSet().filter { it in gram.nonterminals }
//
// println("Part: $part")
// println("Nopart: ${gram.nonterminals - part}")

// .also { println("LEV ∩ CFG grammar:\n${it.pretty}") }
// println(intGram.prettyPrint())
val clock = TimeSource.Monotonic.markNow()

val template = List(toRepair.size + 2) { "_" }
val template = List(toRepair.size + levDist) { "_" }

val lbhSet = intGram.enumSeq(template).distinct().onEachIndexed { i, it ->
if (i < 10) {
println(it)
val pf = intGram.enumTree(it.tokenizeByWhitespace()).toList()
println("Found " + pf.size + " parse trees")
println(pf.first().prettyPrint())
println("\n\n")
}
val lbhSet = intGram.enumSeq(template)
.distinct().map {
minimizeFix(origStr, { tokenizeByWhitespace() }, it,
" ", { this in gram.language })
}.distinctBy { it.third }
.onEachIndexed { i, (_, _, it) ->

assertTrue(it in gram.language)
assertTrue(levBall.recognizes(it))
}.toList()
if (i < 100) println(levenshteinAlign(origStr, it).paintANSIColors())

// Found 19346 solutions using Levenshtein/Bar-Hillel
// Enumerative solver took 382737ms
assertTrue(levenshtein(origStr, it) <= levDist)
assertTrue(it in gram.language)
assertTrue(levBall.recognizes(it))
}.toList()

// Found 657 solutions using Levenshtein/Bar-Hillel
// Enumerative solver took 113329ms

println("Found ${lbhSet.size} solutions using Levenshtein/Bar-Hillel")
println("Enumerative solver took ${clock.elapsedNow().inWholeMilliseconds}ms")

// val totalParticipatingNonterminals =
// lbhSet.map { intGram.parseTable(it).data.map { it.map { it.root } } }.flatten().flatten().toSet()
//
// println("Participation ratio: " + totalParticipatingNonterminals.size + "/" + intGram.nonterminals.size)
// println(intGram.depGraph.'toDot())
}

/*
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import ai.hypergraph.kaliningraph.*
import ai.hypergraph.kaliningraph.parsing.*
import ai.hypergraph.kaliningraph.parsing.Edit
import ai.hypergraph.kaliningraph.repair.*
import ai.hypergraph.kaliningraph.sampling.*
import java.util.concurrent.*
Expand Down

0 comments on commit 2e5f2bb

Please sign in to comment.