Skip to content

Commit

Permalink
look into optimizing LBH solver
Browse files Browse the repository at this point in the history
  • Loading branch information
breandan committed Dec 29, 2023
1 parent 689e18b commit 6d9ee37
Show file tree
Hide file tree
Showing 5 changed files with 284 additions and 218 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,9 @@ fun CFG.makeLevGrammar(source: List<Σᐩ>, distance: Int) =
/**
 * Builds the Levenshtein grammar for [prompt] at the given edit [distance]
 * (via [makeLevGrammar]) and enumerates it against an all-holes template.
 *
 * The template consists of `prompt.size + distance` copies of the hole
 * symbol `"_"`.
 *
 * @param prompt the token sequence to repair.
 * @param distance the edit distance handed to [makeLevGrammar]; it also
 *   determines how many extra holes the template gets beyond the prompt length.
 * @return whatever `enumSeq` yields for that template.
 */
fun CFG.barHillelRepair(prompt: List<Σᐩ>, distance: Int) =
  makeLevGrammar(prompt, distance).let { levGrammar ->
    // The grammar is built first, then the template, as in the one-line form.
    val holes = List(prompt.size + distance) { "_" }
    levGrammar.enumSeq(holes)
  }

// Specialized Bar-Hillel construction for Levenshtein FSA
private infix fun CFG.intersectLevFSAP(fsa: FSA): CFG {
var clock = TimeSource.Monotonic.markNow()
val initFinal =
(fsa.init * fsa.final).map { (q, r) -> "START" to listOf("[$q,START,$r]") }

val transits =
fsa.Q.map { (q, a, r) -> "[$q,$a,$r]" to listOf(a) }

fun Triple<Σᐩ, Σᐩ, Σᐩ>.isCompatibleWith(nts: Triple<Σᐩ, Σᐩ, Σᐩ>): Boolean {
fun Σᐩ.coords(): Pair<Int, Int> =
(length / 2 - 1).let { substring(2, it + 2).toInt() to substring(it + 3).toInt() }
Expand Down Expand Up @@ -55,6 +50,12 @@ private infix fun CFG.intersectLevFSAP(fsa: FSA): CFG {
return isCompatible()
}

val initFinal =
(fsa.init * fsa.final).map { (q, r) -> "START" to listOf("[$q,START,$r]") }

val transits =
fsa.Q.map { (q, a, r) -> "[$q,$a,$r]" to listOf(a) }

// For every production A → σ in P, for every (p, σ, q) ∈ Q × Σ × Q
// such that δ(p, σ) = q we have the production [p, A, q] → σ in P′.
val unitProds = unitProdRules(fsa)
Expand Down Expand Up @@ -124,8 +125,9 @@ fun CFG.dropVestigialProductions(
return if (rw.size == size) this else rw.dropVestigialProductions(criteria)
}

infix fun FSA.intersect(cfg: CFG) = cfg.intersect(this)
infix fun FSA.intersect(cfg: CFG) = cfg.freeze().intersect(this)

// Generic Bar-Hillel construction for arbitrary FSA
infix fun CFG.intersect(fsa: FSA): CFG {
val clock = TimeSource.Monotonic.markNow()
val initFinal =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,15 +96,15 @@ val CFG.unitReachability by cache {

val CFG.noNonterminalStubs: CFG by cache {
println("Disabling nonterminal stubs!")
filter { it.RHS.none { it.isNonterminalStubIn(this) } }.toSet()
filter { it.RHS.none { it.isNonterminalStubIn(this) } }.toSet().freeze()
.also { rewriteHistory.put(it, freeze().let { rewriteHistory[it]!! + listOf(it)}) }
.also { it.blocked.addAll(blocked) }
}

val CFG.noEpsilonOrNonterminalStubs: CFG by cache {
println("Disabling nonterminal stubs!")
filter { it.RHS.none { it.isNonterminalStubIn(this) } }
.filter { "ε" !in it.toString() }.toSet()
.filter { "ε" !in it.toString() }.toSet().freeze()
.also { rewriteHistory.put(it, freeze().let { rewriteHistory[it]!! + listOf(it)}) }
.also { it.blocked.addAll(blocked) }
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package ai.hypergraph.kaliningraph.parsing

import ai.hypergraph.kaliningraph.*
import ai.hypergraph.kaliningraph.automata.*
import ai.hypergraph.kaliningraph.tokenizeByWhitespace
import ai.hypergraph.kaliningraph.types.*
import kotlin.math.*

Expand Down Expand Up @@ -155,7 +155,7 @@ fun <T> levenshtein(o1: List<T>, o2: List<T>): Int {
return prev[o2.size]
}

fun levenshteinAlign(a: String, b: String) =
fun levenshteinAlign(a: Σᐩ, b: Σᐩ): List<Pair<Σᐩ?, Σᐩ?>> =
levenshteinAlign(a.tokenizeByWhitespace(), b.tokenizeByWhitespace())

fun <T> levenshteinAlign(a: List<T>, b: List<T>): List<Pair<T?, T?>> {
Expand Down Expand Up @@ -204,4 +204,17 @@ fun <T> levenshteinAlign(a: List<T>, b: List<T>): List<Pair<T?, T?>> {
val revPathA = aPathRev.reversed()
val revPathB = bPathRev.reversed()
return revPathA.zip(revPathB)
}
}

/**
 * Renders an alignment (a list of pairs) as one space-separated string,
 * wrapping every non-matching position in an ANSI background colour.
 *
 * Colouring per pair `(left, right)`:
 *  - `left == null`  → `right` on a green background (insertion)
 *  - `right == null` → `left` on a red background (deletion)
 *  - `left != right` → `right` on an orange background (substitution)
 *  - otherwise       → `right` printed as-is, without any escape codes
 *
 * The checks are applied in that order, so a pair whose first component is
 * null is always treated as an insertion.
 */
fun <T> List<Pair<T?, T?>>.paintANSIColors(): Σᐩ =
  joinToString(separator = " ") { pair ->
    val left = pair.first
    val right = pair.second

    if (left == null) "${ANSI_GREEN_BACKGROUND}${right}${ANSI_RESET}"    // insertion
    else if (right == null) "${ANSI_RED_BACKGROUND}${left}${ANSI_RESET}" // deletion
    else if (left != right) "${ANSI_ORANGE_BACKGROUND}${right}${ANSI_RESET}" // substitution
    else right.toString()                                                // match
  }
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,8 @@ class BarHillelTest {
""".parseCFG().noEpsilonOrNonterminalStubs

val origStr = "1 + 1"
val levFSA = makeLevFSA(origStr, 2, simpleCFG.terminals)
val levDist = 2
val levFSA = makeLevFSA(origStr, levDist, simpleCFG.terminals)

val levCFG = levFSA.intersectLevFSA(simpleCFG)

Expand All @@ -122,8 +123,13 @@ class BarHillelTest {
val testFail = "2 * 2"
assertFalse(testFail in levCFG.language)

val template = List(5) { "_" }
val solutions = levCFG.enumSeq(template).toList().onEach { println(it) }
val template = List(origStr.tokenizeByWhitespace().size + levDist) { "_" }
val solutions = levCFG.enumSeq(template).toList().onEach {
val actDist = levenshtein(origStr, it)
val levAlgn = levenshteinAlign(origStr, it).paintANSIColors()
assertTrue(actDist <= levDist)
println(levAlgn)
}
println("Found ${solutions.size} solutions within Levenshtein distance 2 of \"$origStr\"")
}

Expand Down Expand Up @@ -188,14 +194,40 @@ class BarHillelTest {
"${(lbhSet + efset) - (lbhSet intersect efset)}")
}

/*
./gradlew jvmTest --tests "ai.hypergraph.kaliningraph.parsing.BarHillelTest.testIfThenBarHillel"
*/
/**
 * Intersects the if-then grammar with a Levenshtein automaton built around a
 * fixed input string, enumerates the distinct sequences of the intersection
 * grammar, and checks for each one that it
 *  - is in the language of the original grammar,
 *  - is recognized by the Levenshtein automaton, and
 *  - lies within the configured edit distance of the input string.
 *
 * Each sequence is printed as a colourised alignment against the input
 * string; the number of sequences and the elapsed enumeration time are
 * printed at the end.
 */
@Test
fun testIfThenBarHillel() {
  val gram = Grammars.ifThen
  val origStr = "if ( true or false then true else 1"
  val levDist = 3
  val tokenCount = origStr.tokenizeByWhitespace().size

  val levBall = makeLevFSA(origStr, levDist, gram.terminals)
  val intGram = gram.intersectLevFSA(levBall)

  // Started after the intersection is built, so only enumeration is timed.
  val clock = TimeSource.Monotonic.markNow()
  val template = List(tokenCount + levDist) { "_" }

  val repairs = intGram.enumSeq(template).distinct().onEach { repair ->
    val painted = levenshteinAlign(origStr, repair).paintANSIColors()
    val actDist = levenshtein(origStr, repair)
    // Print before asserting so a failing candidate is still shown.
    println(painted)

    assertTrue(repair in gram.language)
    assertTrue(levBall.recognizes(repair))
    assertTrue(actDist <= levDist)
  }.toList()

  println("Found ${repairs.size} solutions using Levenshtein/Bar-Hillel in ${clock.elapsedNow()}")
}

/*
./gradlew jvmTest --tests "ai.hypergraph.kaliningraph.parsing.BarHillelTest.testPythonBarHillel"
*/
@Test
fun testPythonBarHillel() {
val gram = Grammars.seq2parsePythonCFG.noEpsilonOrNonterminalStubs
val toRepair = "NAME = ( NAME . NAME ( NAME NEWLINE".tokenizeByWhitespace()
val levBall = makeLevFSA(toRepair, 3, gram.terminals)
val origStr = "NAME = ( NAME . NAME ( NAME NEWLINE"
val toRepair = origStr.tokenizeByWhitespace()
val levDist = 3
val levBall = makeLevFSA(toRepair, levDist, gram.terminals)
// println(levBall.toDot())
// throw Exception("")
val intGram = gram.intersectLevFSA(levBall)
Expand All @@ -209,17 +241,19 @@ class BarHillelTest {
// println(intGram.prettyPrint())
val clock = TimeSource.Monotonic.markNow()

val template = List(toRepair.size + 2) { "_" }
val template = List(toRepair.size + levDist - 1) { "_" }

val lbhSet = intGram.enumSeq(template).distinct().onEachIndexed { i, it ->
if (i < 10) {
println(it)
if (i < 100) {
val levAlign = levenshteinAlign(origStr, it).paintANSIColors()
println(levAlign)
val pf = intGram.enumTree(it.tokenizeByWhitespace()).toList()
println("Found " + pf.size + " parse trees")
println(pf.first().prettyPrint())
println("\n\n")
}

assertTrue(levenshtein(origStr, it) <= levDist)
assertTrue(it in gram.language)
assertTrue(levBall.recognizes(it))
}.toList()
Expand All @@ -238,10 +272,10 @@ class BarHillelTest {
}

/*
./gradlew jvmTest --tests "ai.hypergraph.kaliningraph.parsing.BarHillelTest.realisticTest"
./gradlew jvmTest --tests "ai.hypergraph.kaliningraph.parsing.BarHillelTest.semiRealisticTest"
*/
// @Test
fun realisticTest() {
fun semiRealisticTest() {
val gram = Grammars.seq2parsePythonCFG.noEpsilonOrNonterminalStubs
val toRepair = "NAME = NAME . NAME ( [ NUMBER , NUMBER , NUMBER ]".tokenizeByWhitespace()
val levBall = makeLevFSA(toRepair, 1, gram.terminals)
Expand Down
Loading

0 comments on commit 6d9ee37

Please sign in to comment.