Skip to content

Commit

Permalink
specialize grammar to terminal subset preimages
Browse files Browse the repository at this point in the history
  • Loading branch information
breandan committed Jan 1, 2024
1 parent ecd33dd commit 5094b5c
Show file tree
Hide file tree
Showing 7 changed files with 294 additions and 55 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@ package ai.hypergraph.kaliningraph.parsing

import ai.hypergraph.kaliningraph.automata.FSA
import ai.hypergraph.kaliningraph.types.*
import ai.hypergraph.kaliningraph.types.times
import kotlin.time.TimeSource

infix fun FSA.intersectLevFSA(cfg: CFG) = cfg.intersectLevFSA(this)
// http://www.cs.umd.edu/~gasarch/BLOGPAPERS/cfg.pdf#page=2
// https://browse.arxiv.org/pdf/2209.06809.pdf#page=5

infix fun CFG.intersectLevFSA(fsa: FSA): CFG = freeze().intersectLevFSAP(fsa)
// Specializes the grammar to the FSA's terminal alphabet (its preimage subset)
// before intersecting, so the Bar-Hillel construction only sees productions
// that can actually generate terminals the automaton recognizes.
infix fun CFG.intersectLevFSA(fsa: FSA): CFG =
  subgrammar(fsa.alphabet).intersectLevFSAP(fsa)

fun CFG.makeLevGrammar(source: List<Σᐩ>, distance: Int) =
intersectLevFSA(makeLevFSA(source, distance, terminals))
Expand Down Expand Up @@ -63,7 +65,8 @@ private infix fun CFG.intersectLevFSAP(fsa: FSA): CFG {
// For each production A → BC in P, for every p, q, r ∈ Q,
// we have the production [p,A,r] → [p,B,q] [q,C,r] in P′.
val binaryProds =
nonterminalProductions.map {
nonterminalProductions.mapIndexed { i, it ->
if (i % 10 == 0) println("Finished ${i}/${nonterminalProductions.size} productions")
val triples = fsa.states * fsa.states * fsa.states
val (A, B, C) = it.π1 to it.π2[0] to it.π2[1]
triples
Expand Down Expand Up @@ -143,7 +146,7 @@ infix fun CFG.intersect(fsa: FSA): CFG {
// For each production A → BC in P, for every p, q, r ∈ Q,
// we have the production [p,A,r] → [p,B,q] [q,C,r] in P′.
val binaryProds =
nonterminalProductions.map {
nonterminalProductions.mapIndexed { i, it ->
val triples = fsa.states * fsa.states * fsa.states
val (A, B, C) = it.π1 to it.π2[0] to it.π2[1]
triples.map { (p, q, r) ->
Expand Down
79 changes: 79 additions & 0 deletions src/commonMain/kotlin/ai/hypergraph/kaliningraph/parsing/CFG.kt
Original file line number Diff line number Diff line change
Expand Up @@ -251,8 +251,87 @@ class BiMap(cfg: CFG) {
}
operator fun get(p: List<Σᐩ>): Set<Σᐩ> = R2LHS[p] ?: emptySet()
operator fun get(p: Σᐩ): Set<List<Σᐩ>> = L2RHS[p] ?: emptySet()
// Reverse TDEPS lookup: all keys whose dependency set equals p exactly.
// NOTE: linear scan over TDEPS entries — O(|TDEPS|) per call.
operator fun get(p: Set<Σᐩ>): Set<Σᐩ> = TDEPS.entries.filter { it.value == p }.map { it.key }.toSet()
}

// Cached fixpoint of inevitableSymbols: symbol -> symbols it must generate.
val CFG.mustGenerate by cache { inevitableSymbols() }

/**
 * For every nonterminal, computes the set of symbols it inevitably generates:
 * `smb` is recorded as inevitable for `nt` when every RHS of `nt` contains
 * either `smb` or `nt` itself (i.e. any derivation of `nt` must eventually
 * pass through `smb`). Inevitable sets are merged transitively (`smb`'s own
 * inevitable set is folded into `nt`'s), and the map is recomputed
 * recursively until it stops changing (fixpoint).
 */
fun CFG.inevitableSymbols(map: Map<Σᐩ, Set<Σᐩ>> = emptyMap()): Map<Σᐩ, Set<Σᐩ>> {
  val newMap = map.toMutableMap()
  symbols.forEach { smb ->
//    println("Testing $smb")
    bimap.TDEPS[smb]?.forEach { nt ->
//      println("Testing $smb -> $nt")
      if (bimap[nt].all { smb in it || nt in it }) {
//        println("Worked! $nt => $smb")
        newMap[nt] = newMap.getOrPut(nt) { emptySet() } + smb +
          newMap.getOrPut(smb) { emptySet() }
      }
//      else {
//        if (smb == "NEWLINE")
//          println("Failed! $nt !=> $smb, first ${bimap[nt].first { smb !in it }}")
//      }
    }
  }
  // Recurse until no new inevitabilities are discovered.
  return if (newMap == map) map else inevitableSymbols(newMap)
}

// Diagnostic passthrough for removal predicates: when the receiver Boolean is
// true, prints why the production is being removed (and warns if it was the
// LHS's last remaining production), then returns the Boolean unchanged so it
// can be used inline inside removeAll-style chains.
fun Bln.explain(cfg: CFG, prod: Production, reason: String = "") = also { removed ->
  if (removed) {
    println("Removed [${prod.LHS} -> ${prod.RHS.joinToString(" ")}] because $reason")
    val remaining = cfg.count { it.first == prod.LHS }
    if (remaining == 1) println("And no other productions were left for `${prod.LHS}`!")
  }
}

/**
 * Like [removeTerminals], but prints an explanation (via [explain]) for every
 * production it drops. Removes productions whose LHS or some RHS symbol must
 * generate (per [mustGenerate]) an original terminal outside [allowed], marks
 * nonterminals left without productions as dead, drops productions that
 * reference them, and recurses until the production set stops shrinking.
 *
 * @param allowed terminals to keep
 * @param otps terminal unit productions of the original grammar
 * @param origTerms terminal set of the original grammar
 * @param mustGenerate inevitable-symbol map of the original grammar
 */
fun CFG.removeTerminalsVerbose(allowed: Set<Σᐩ>, otps: Set<Production> = this.terminalUnitProductions, origTerms: Set<Σᐩ> = this.terminals, mustGenerate: Map<Σᐩ, Set<Σᐩ>> = this.mustGenerate): CFG {
  val deadNTs = mutableSetOf<Σᐩ>()
  val next = toMutableSet().apply { removeAll { prod ->
    (
//      (prod in otps && (prod.RHS.first() !in allowed))
//        .explain(this, prod, "the terminal `${prod.RHS.first()}` is not allowed") ||
      (mustGenerate[prod.LHS]?.any { (it in origTerms && it !in allowed)
        .explain(this, prod, "LHS value `${prod.LHS}` must generate `$it` and `$it` was not allowed") } == true) ||
      prod.RHS.any { rhs -> mustGenerate[rhs]?.any { (it in origTerms && it !in allowed)
        .explain(this, prod, "RHS value `$rhs` must generate `$it` and `$it` was not allowed") } == true }
    ).also { removed -> if (removed && this.count { it.first == prod.first } == 1) {
      println("Added `${prod.first}` to deadNTs!")
      deadNTs.add(prod.LHS) }
    }
  } }

  // Second pass: drop productions referencing a dead NT, or a symbol that is a
  // terminal of the pruned grammar but was not an original terminal (a chopped NT).
  // Fix: guard with `rhs in next.terminals` to match removeTerminals — the
  // previous bare `rhs !in origTerms` removed every production whose RHS
  // contained ANY nonterminal, which guts the grammar.
  next.removeAll { prod ->
    prod.RHS.any { rhs ->
      (rhs in deadNTs).explain(next, prod, "the RHS value `$rhs` is a dead NT!") ||
      (rhs in next.terminals && rhs !in origTerms).explain(next, prod, "the RHS terminal `$rhs` was a chopped NT")
    }
  }

  // Recurse to fixpoint: stop when nothing more was removed.
  return if (next.size == size) this else next.removeTerminalsVerbose(allowed, otps, origTerms, mustGenerate)
}

/**
 * Prunes the grammar to the terminal subset [allowed]: removes terminal unit
 * productions whose terminal is disallowed, plus any production whose LHS or
 * some RHS symbol must generate (per [mustGenerate]) an original terminal
 * outside [allowed]. Nonterminals whose last production was removed are marked
 * dead; a second pass drops productions that reference a dead NT or a symbol
 * that became a terminal of the pruned grammar without being an original
 * terminal. Recurses until the production set stops shrinking (fixpoint).
 *
 * @param allowed terminals to keep
 * @param otps terminal unit productions of the original grammar
 * @param origTerms terminal set of the original grammar
 * @param mustGenerate inevitable-symbol map of the original grammar
 */
fun CFG.removeTerminals(allowed: Set<Σᐩ>, otps: Set<Production> = this.terminalUnitProductions, origTerms: Set<Σᐩ> = this.terminals, mustGenerate: Map<Σᐩ, Set<Σᐩ>> = this.mustGenerate): CFG {
  val deadNTs = mutableSetOf<Σᐩ>()
  val next = toMutableSet().apply {
    removeAll { prod ->
      (
        (prod in otps && (prod.RHS.first() !in allowed)) ||
        mustGenerate[prod.LHS]?.any { (it in origTerms && it !in allowed) } == true ||
        prod.RHS.any { rhs -> mustGenerate[rhs]?.any { (it in origTerms && it !in allowed) } == true }
      // If this was the LHS's only remaining production, the NT is now dead.
      ).also { if (it && count { it.first == prod.first } == 1) deadNTs.add(prod.LHS) }
    }
  }

  // Drop productions referencing dead NTs or chopped nonterminals.
  next.removeAll { prod -> prod.RHS.any { rhs -> rhs in deadNTs || (rhs in next.terminals && rhs !in origTerms) } }

  return if (next.size == size) this else next.removeTerminals(allowed, otps, origTerms, mustGenerate)
}

/**
 * Specializes this grammar to the preimage of the terminal subset [image]:
 * every production that cannot avoid generating a terminal outside [image] is
 * removed (see [removeTerminals]), the rewrite history is extended, and the
 * result is frozen.
 */
fun CFG.subgrammar(image: Set<Σᐩ>): CFG =
  removeTerminals(image)
    // NOTE(review): the inner `it` of `let` shadows the outer `also` receiver,
    // so `listOf(it)` appends the frozen ORIGINAL grammar, not the subgrammar —
    // confirm this is intended. Also `rewriteHistory[it]!!` will NPE if the
    // frozen original has no recorded history.
    .also { rewriteHistory.put(it, freeze().let { rewriteHistory[it]!! + listOf(it)}) }
    .freeze()
    .also { println("All terminals: ${it.terminals}") }

// Hash of the structure of s's parse forest (cheap parse-equivalence check).
fun CFG.forestHash(s: Σᐩ) = parseForest(s).map { it.structureEncode() }.hashCode()
// Hash of the per-token nonterminal preimages of s.
fun CFG.nonterminalHash(s: Σᐩ) = s.tokenizeByWhitespace().map { preimage(it) }.hashCode()
// All LHS nonterminals that can rewrite to exactly the RHS sequence nts.
fun CFG.preimage(vararg nts: Σᐩ): Set<Σᐩ> = bimap.R2LHS[nts.toList()] ?: emptySet()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package ai.hypergraph.kaliningraph.parsing

import ai.hypergraph.kaliningraph.*
import ai.hypergraph.kaliningraph.automata.*
import ai.hypergraph.kaliningraph.repair.Patch
import ai.hypergraph.kaliningraph.types.*
import kotlin.math.*

Expand Down Expand Up @@ -74,10 +75,16 @@ fun makeLevFSA(
if ((str.size - i + j).absoluteValue <= dist) finalStates.add(it)
}

FSA(Q, initialStates, finalStates)
FSA(Q, initialStates, finalStates).also { println("Levenshtein automata size: ${Q.size}") }
}

fun pd(i: Int, digits: Int) = i.toString().padStart(digits, '0')
// Zero-pads the decimal rendering of i to at least `digits` characters,
// e.g. pd(7, 3) == "007". Padding is prepended to the rendered string, so
// pd(-5, 3) == "0-5" — identical to String.padStart(digits, '0') semantics.
private fun pd(i: Int, digits: Int): String {
  val rendered = i.toString()
  val padding = "0".repeat(maxOf(0, digits - rendered.length))
  return padding + rendered
}

/**
TODO: upArcs and diagArcs are the most expensive operations taking ~O(2n|Σ|) to construct.
We can probably do much better by only creating arcs that are contextually probable.
See: [ai.hypergraph.kaliningraph.repair.CEAProb]
*/

/*
s∈Σ i∈[0,n] j∈[1,k]
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package ai.hypergraph.kaliningraph.repair

// The three Levenshtein edit operations: insertion, deletion, substitution.
enum class EditType { INS, DEL, SUB }
// A single edit applied within a trigram context: the operation [type], the
// surrounding [context], and the replacement/inserted token [newMid].
data class ContextEdit(val type: EditType, val context: Context, val newMid: String) {
  // Renders e.g. `INS, (( left [new] right // left [new] right ))`.
  // The edit description appears twice, separated by " // ", exactly as the
  // original two identical `when` blocks produced.
  override fun toString(): String = context.run {
    val edit = when (type) {
      EditType.INS -> "$left [${newMid}] $right"
      EditType.DEL -> "$left ~${mid}~ $right"
      EditType.SUB -> "$left [${mid} -> ${newMid}] $right"
    }
    "$type, (( $edit // $edit ))"
  }
}
// A context edit [cea] observed at token position [idx] with an empirical
// [frequency]. Identity deliberately ignores frequency: two CEAProbs are the
// same edit-at-position regardless of how often each was observed.
data class CEAProb(val cea: ContextEdit, val idx: Int, val frequency: Int) {
  override fun equals(other: Any?): Boolean =
    other is CEAProb && cea == other.cea && idx == other.idx
  override fun hashCode(): Int = cea.hashCode() * 31 + idx
  override fun toString(): String = "[[ $cea, $idx, $frequency ]]"
}
/**
 * A token trigram context: the tokens immediately to the [left] and [right]
 * of [mid]. The hand-written equals/hashCode were removed because they
 * duplicated exactly the members a data class already generates for its
 * primary-constructor properties (structural equality over left/mid/right).
 */
data class Context(val left: String, val mid: String, val right: String)

// Empirical distribution over context edits, pre-partitioned for fast lookup.
data class CEADist(val allProbs: Map<ContextEdit, Int>) {
  // Split the raw counts by operation: deletions/substitutions vs. insertions.
  val P_delSub = allProbs.filterKeys { it.type != EditType.INS }
  val P_insert = allProbs.filterKeys { it.type == EditType.INS }
  // The same splits, indexed by trigram context.
  val P_delSubOnCtx = P_delSub.keys.groupBy { it.context }
  val P_insertOnCtx = P_insert.keys.groupBy { it.context }
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,49 +3,6 @@ package ai.hypergraph.kaliningraph.repair
import java.io.File
import kotlin.math.pow

// NOTE(review): this block is duplicated verbatim by the newly added
// ai.hypergraph.kaliningraph.repair.CEAProb declarations earlier in this
// commit; this is the copy being deleted by the diff.
enum class EditType { INS, DEL, SUB }
data class ContextEdit(val type: EditType, val context: Context, val newMid: String) {
  override fun toString(): String = context.run {
    "$type, (( " + when (type) {
      EditType.INS -> "$left [${newMid}] $right"
      EditType.DEL -> "$left ~${mid}~ $right"
      EditType.SUB -> "$left [${mid} -> ${newMid}] $right"
    } + " // " + when (type) {
      EditType.INS -> "$left [${newMid}] $right"
      EditType.DEL -> "$left ~${mid}~ $right"
      EditType.SUB -> "$left [${mid} -> ${newMid}] $right"
    } + " ))"
  }
}
data class CEAProb(val cea: ContextEdit, val idx: Int, val frequency: Int) {
  override fun equals(other: Any?): Boolean = when (other) {
    is CEAProb -> cea == other.cea && idx == other.idx
    else -> false
  }
  override fun hashCode(): Int = 31 * cea.hashCode() + idx
  override fun toString(): String = "[[ $cea, $idx, $frequency ]]"
}
data class Context(val left: String, val mid: String, val right: String) {
  override fun equals(other: Any?) = when (other) {
    is Context -> left == other.left && mid == other.mid && right == other.right
    else -> false
  }

  override fun hashCode(): Int {
    var result = left.hashCode()
    result = 31 * result + mid.hashCode()
    result = 31 * result + right.hashCode()
    return result
  }
}

data class CEADist(val allProbs: Map<ContextEdit, Int>) {
  val P_delSub = allProbs.filter { it.key.type != EditType.INS }
  val P_insert = allProbs.filter { it.key.type == EditType.INS }
  val P_delSubOnCtx = P_delSub.keys.groupBy { it.context }
  val P_insertOnCtx = P_insert.keys.groupBy { it.context }
}

// Diversity: lower is more diverse, higher is less diverse, 1.0 is natural frequencies
fun File.readTrigramStats(diversity: Double = 1.0): CEADist =
readLines().drop(1).map { it.split(", ") }.associate {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,17 +1,12 @@
package ai.hypergraph.kaliningraph.repair

import Grammars
import ai.hypergraph.kaliningraph.parsing.*
import ai.hypergraph.kaliningraph.repair.TIMEOUT_MS
import ai.hypergraph.kaliningraph.sampling.*
import ai.hypergraph.kaliningraph.tokenizeByWhitespace
import org.junit.jupiter.api.Test
import parallelize
import repairInParallel
import java.io.File
import java.util.stream.*
import kotlin.streams.*
import kotlin.test.*
import kotlin.time.*
import kotlin.time.TimeSource

/*
./gradlew jvmTest --tests "ai.hypergraph.kaliningraph.repair.ProbabilisticLBH"
Expand All @@ -23,5 +18,131 @@ class ProbabilisticLBH {
*/
@Test
fun testProbabilisticLBH() {
  // Print each symbol alongside the set of symbols it must generate.
  for ((symbol, inevitable) in Grammars.seq2parsePythonCFG.mustGenerate)
    println(symbol + " -> " + inevitable)
}

/*
./gradlew jvmTest --tests "ai.hypergraph.kaliningraph.repair.ProbabilisticLBH.testSubgrammarEquivalence"
*/
@Test
fun testSubgrammarEquivalence() {
  // Restrict the grammar to NEWLINE plus every token occurring in the corpus.
  val terminalImage = setOf<String>() + "NEWLINE" + validPythonStatements.tokenizeByWhitespace().toSet()
  val s2pg = Grammars.seq2parsePythonCFG.noEpsilonOrNonterminalStubs
  val subgrammar = s2pg.subgrammar(terminalImage)

  // Fix: split each corpus into lines BEFORE combining. Concatenating the raw
  // strings fused the last valid statement with the first invalid one whenever
  // validPythonStatements lacked a trailing newline — TODO confirm against the
  // corpus definitions.
  (validPythonStatements.lines() + invalidPythonStatements.lines())
    .forEach { assertEquals(s2pg.parse(it), subgrammar.parse(it)) }
}

/*
./gradlew jvmTest --tests "ai.hypergraph.kaliningraph.repair.ProbabilisticLBH.testSubgrammar"
*/
@Test
fun testSubgrammar() {
  // Terminal image: NEWLINE plus every token appearing in the valid corpus.
  val terminalImage = setOf<String>() + "NEWLINE" + validPythonStatements.tokenizeByWhitespace().toSet()
  val s2pg = Grammars.seq2parsePythonCFG.noEpsilonOrNonterminalStubs
  val subgrammar = s2pg.subgrammar(terminalImage)
  println("Original size: ${s2pg.size}")
  println("Subgrammar size: ${subgrammar.size}")

//  println("Must generate:\n${s2pg.mustGenerate.filter { it.value.isNotEmpty() }.entries.joinToString("\n") { "${it.key} -> ${it.value}" }}")
//  println("::::::::::::")
//  subgrammar.forEach { println("${it.LHS} ->" + it.RHS.joinToString(" ")) }

  // Renders a forest one tree per line: root-[child roots].
  fun Forest.summarize() = joinToString("\n") { it.root + "-[${it.children.joinToString(","){it.root}}]" }


  // Every valid statement must parse under both the full grammar and the subgrammar.
  validPythonStatements.lines().forEach {
    val pp ="$it NEWLINE" .also { println(it) }

//    val z1= subgrammar.initialUTMatrix(pp.tokenizeByWhitespace()).seekFixpoint().diagonals
//    val z2 = s2pg.initialUTMatrix(pp.tokenizeByWhitespace()).seekFixpoint().diagonals
////      .zip(s2pg.initialUTMatrix(pp.tokenizeByWhitespace()).diagonals)
//    println(z1.size)
//    println(z2.size)
//    val lastGoodDiag = z1.indexOfLast { it.any { it.summarize().isNotEmpty() } }
//    println(lastGoodDiag)
//    val lastGood = z1.last { it.any { it.summarize().isNotEmpty() } }
//    println(lastGood.map { it.summarize() }.joinToString("\n"))
//    println(z2[lastGoodDiag].map { it.summarize() }.sorted().joinToString("\n"))
////      .first { (a, b) -> a != b }.let { (sgd, s2gd) ->
////        sgd.zip(s2gd).forEach { (f1, f2) -> println(f1.summarize() + "\n" + f2.summarize()) }
////      }
//
////    subgrammar.parseInvalidWithMaximalFragments(pp).forEach { println(it.prettyPrint() + "\n\n") }
//    println(s2pg.parse(pp)!!.prettyPrint())
//    println(lastGood.first { it.isNotEmpty() }.first().prettyPrint())
    assertTrue(pp in s2pg.language, "$it\nnot in Grammars.seq2parsePythonCFG!")
    assertTrue(pp in subgrammar.language, "$it\nnot in subgrammar!")
  }
  // Samples drawn from the subgrammar must be accepted by both grammars.
  subgrammar.sampleSeq(List(20) {"_"}).take(100).forEach { pp ->
    assertTrue(pp in Grammars.seq2parsePythonCFG.language, "$pp\nnot in Grammars.seq2parsePythonCFG!")
    assertTrue(pp in subgrammar.language, "$pp\nnot in subgrammar!")
  }
}

// Terminals ranked by empirical edit frequency: aggregates insertion and
// substitution counts from ceaDist per seq2parse nonterminal, sorts by count
// descending, then maps each nonterminal to a representative terminal.
val topTerms by lazy {
  ceaDist.allProbs.entries
    .filter { it.key.type != EditType.DEL }
    .groupingBy { Grammars.seq2parsePythonCFG.getS2PNT(it.key.newMid) }
    .aggregate { _, acc: Int?, it, _ -> (acc ?: 0) + it.value }
    .map { (k, v) -> k to v }
    .sortedBy { -it.second }
//    .onEach { println("${it.first}≡${Grammars.seq2parsePythonCFG.bimap[it.first]}: ${it.second}") }
    .mapNotNull { Grammars.seq2parsePythonCFG.bimap[it.first].firstOrNull() }
    .toSet()
}

/*
./gradlew jvmTest --tests "ai.hypergraph.kaliningraph.repair.ProbabilisticLBH.testCompleteness"
*/
@Test
fun testCompleteness() {
  println(Grammars.seq2parsePythonCFG.terminals.size)

  // Sanity-check the corpora against the grammar's language membership.
  invalidPythonStatements.lines().forEach {
    assertTrue("$it NEWLINE" !in Grammars.seq2parsePythonCFG.language)
  }
  validPythonStatements.lines().forEach {
    assertTrue("$it NEWLINE" in Grammars.seq2parsePythonCFG.language)
  }

  val gram = Grammars.seq2parsePythonCFG.noEpsilonOrNonterminalStubs
  val origStr = "NAME = ( NAME . NAME ( NAME NEWLINE"
//  invalidPythonStatements.lines().drop(1).forEach {
  val clock = TimeSource.Monotonic.markNow()
//    val gram = Grammars.seq2parsePythonCFG.noEpsilonOrNonterminalStubs
//    val origStr = "$it NEWLINE"
  val toRepair = origStr.tokenizeByWhitespace()
  val levDist = 2
  println("Top terms: ${topTerms.joinToString(", ")}")
  // Levenshtein ball restricted to the empirically most frequent edit terminals.
  val levBall = makeLevFSA(toRepair, levDist, topTerms)
  println("Total transitions in FSA: ${levBall.Q.size}")
  println("Prompt: $origStr")
  println("Alphabet: ${levBall.alphabet}")
  val intGram = gram.intersectLevFSA(levBall)
  println("Finished intersection in ${clock.elapsedNow()}")

  val template = List(toRepair.size + levDist) { "_" }

  // Every enumerated repair must stay within the edit distance, be
  // grammatical, and be accepted by the Levenshtein automaton.
  val lbhSet = intGram.enumSeq(template)
    .distinct()
    .map { minimizeFix(toRepair, it.tokenizeByWhitespace()) { this in gram.language } }
    .distinct()
    .onEachIndexed { i, it ->
      if (i < 100) println(levenshteinAlign(origStr, it).paintANSIColors())

      assertTrue(levenshtein(origStr, it) <= levDist)
      assertTrue(it in gram.language)
      assertTrue(levBall.recognizes(it))
    }.toList()
    .also { println("TOTAL REPAIRS (${clock.elapsedNow()}): ${it.size}\n\n") }
}

// Maps a (possibly single-quoted) token back to its nonterminal preimage via
// the bimap; surrounding quotes are stripped before the lookup.
fun CFG.getS2PNT(string: String) =
  string.trim().let { token ->
    if (token.startsWith("'") && token.endsWith("'"))
      bimap[listOf(token.drop(1).dropLast(1))]
    else bimap[listOf(token)]
  }
}
Loading

0 comments on commit 5094b5c

Please sign in to comment.