diff --git a/src/commonMain/kotlin/ai/hypergraph/kaliningraph/StringUtils.kt b/src/commonMain/kotlin/ai/hypergraph/kaliningraph/StringUtils.kt
index c37e89e2..d1cbf710 100644
--- a/src/commonMain/kotlin/ai/hypergraph/kaliningraph/StringUtils.kt
+++ b/src/commonMain/kotlin/ai/hypergraph/kaliningraph/StringUtils.kt
@@ -128,4 +128,7 @@ const val ANSI_YELLOW_BACKGROUND = "\u001B[43m"
 const val ANSI_BLUE_BACKGROUND = "\u001B[44m"
 const val ANSI_PURPLE_BACKGROUND = "\u001B[45m"
 const val ANSI_CYAN_BACKGROUND = "\u001B[46m"
-const val ANSI_WHITE_BACKGROUND = "\u001B[47m"
\ No newline at end of file
+const val ANSI_WHITE_BACKGROUND = "\u001B[47m"
+
+fun Char.toUnicodeEscaped() = "\\u${code.toString(16).padStart(4, '0')}"
+fun Σᐩ.replaceAll(tbl: Map<String, String>) = tbl.entries.fold(this) { acc, (k, v) -> acc.replace(k, v) }
\ No newline at end of file
diff --git a/src/jvmMain/kotlin/ai/hypergraph/kaliningraph/automata/JFSA.kt b/src/jvmMain/kotlin/ai/hypergraph/kaliningraph/automata/JFSA.kt
new file mode 100644
index 00000000..34acc946
--- /dev/null
+++ b/src/jvmMain/kotlin/ai/hypergraph/kaliningraph/automata/JFSA.kt
@@ -0,0 +1,164 @@
+package ai.hypergraph.kaliningraph.automata
+
+import ai.hypergraph.kaliningraph.graphs.LabeledGraph
+import ai.hypergraph.kaliningraph.parsing.*
+import ai.hypergraph.markovian.mcmc.MarkovChain
+import dk.brics.automaton.Automaton.*
+import dk.brics.automaton.Transition
+import java.util.*
+import kotlin.random.Random
+import kotlin.time.*
+
+typealias BState = dk.brics.automaton.State
+typealias BAutomaton = dk.brics.automaton.Automaton
+typealias JAutomaton<S, K> = net.jhoogland.jautomata.Automaton<S, K>
+
+/**
+ * Renders this automaton as a DOT string by walking it breadth-first from its
+ * initial states. Each node is named `<hashCode>#<initialWeight>/<finalWeight>`
+ * and each edge is labeled `<label>/<weight>` (weight cut to four characters).
+ * States discovered during the walk are recorded in [processed].
+ */
+fun JAutomaton<String, Double>.toDot(processed: MutableSet<Any> = mutableSetOf()) =
+  LabeledGraph {
+    val stateQueue = mutableListOf<Any>()
+    // Mark the initial states as seen up front, otherwise a cycle leading back
+    // to one of them would enqueue and expand it a second time
+    initialStates().forEach { processed.add(it); stateQueue.add(it) }
+    while (stateQueue.isNotEmpty()) {
+      val state = stateQueue.removeAt(0)
+      transitionsOut(state).forEach {
+        val label = label(it) + "/" + transitionWeight(it).toString().take(4)
+        val next = this@toDot.to(it)
+        val initws = initialWeight(state)
+        val finalws = finalWeight(state)
+        val initwn = initialWeight(next)
+        val finalwn = finalWeight(next)
+        (state.hashCode().toString() + "#$initws/$finalws")[label] = next.hashCode().toString() + "#$initwn/$finalwn"
+        if (next !in processed) {
+          processed.add(next)
+          stateQueue.add(next)
+        }
+      }
+    }
+  }.toDot()
+    // States are typically unlabeled in FSA diagrams
+    .replace("Mrecord\"", "Mrecord\", label=\"\"")
+    // Final states are suffixed with /1.0 and drawn as double circles
+    .replace("/1.0\" [\"shape\"=\"Mrecord\"", "/1.0\" [\"shape\"=\"doublecircle\"")
+    .replace("Mrecord", "circle") // FSA states should be circular
+    .replace("null", "ε") // null label = ε-transition
+
+/*
+ * Returns a sequence of trajectories through a DFA sampled using the Markov chain.
+ * The DFA is expected to be deterministic. We use the Markov chain to steer the
+ * random walk through the DFA by sampling the best transitions conditioned on the
+ * previous n-1 transitions, i.e., q' ~ argmax_{q'} P(q' | q_{t-1}, ..., q_{t-n+1})
+ */
+
+/**
+ * A (possibly partial) path through a [BAutomaton]. [toks] holds the decoded
+ * tokens most-recent-first, [lastState] is the state the path currently ends
+ * in, and [score] is the sum of the Markov chain scores of every step taken.
+ */
+data class FSATrajectory(val toks: List<Σᐩ?>, val lastState: BState, val score: Double) {
+  val isComplete: Boolean = lastState.isAccept
+  override fun toString() = toks.reversed().filterNotNull().joinToString(" ")
+}
+
+/**
+ * Compiles this [PTree] into a [BAutomaton] (one character per terminal,
+ * concatenation for `both`, union for `either`, nothing for roots containing
+ * ε), runs Brzozowski minimization on it, then decodes strings from it with
+ * [steerableRandomWalk]. Returns an empty list when no automaton is produced.
+ *
+ * NOTE(review): terminals are mapped to characters via a seeded hash, so two
+ * terminals could in principle share a character - confirm this cannot occur.
+ */
+fun PTree.decodeDFA(mc: MarkovChain<Σᐩ>, topK: Int = 10_000_000): List<Σᐩ> = propagator(
+  both = { a, b -> if (a == null) b else if (b == null) a else a.concatenate(b) },
+  either = { a, b -> if (a == null) b else if (b == null) a else a.union(b) },
+  unit = { a ->
+    if ("ε" in a.root) null
+    else BAutomaton.makeChar(Random(a.root.hashCode()).nextInt().toChar())
+//    EditableAutomaton(RealSemiring()).apply {
+//      val s1 = addState(1.0, 0.0)
+//      val s2 = addState(0.0, 1.0)
+//      addTransition(s1, s2, a.root, 1.0)
+//    }
+  }
+)
+//  ?.also { println("\n" + Operations.determinizeER(it).toDot().alsoCopy() + "\n") }
+//  .also { println("Total: ${Automata.transitions(it).size} arcs, ${Automata.states(it).size}") }
+//  .let { WAutomata.bestStrings(it, maxResults).map { it.label.joinToString(" ") }.toSet() }
+  ?.also { println("Original automata had ${it
+    .let { "${it.numberOfStates} states and ${it.numberOfTransitions} transitions"}}") }
+  ?.also {
+    // The return value of minimize is discarded here, so the walk below only
+    // sees the minimized automaton if minimize mutates its argument
+    measureTimedValue { BAutomaton.setMinimization(MINIMIZE_BRZOZOWSKI); BAutomaton.minimize(it) }
+      .also { println("Minimization took ${it.duration}") }.value
+//      .also { it.toDot().replaceAll(stbl).alsoCopy() }
+      .also {
+        // Minimal automata had 92 states and 707 transitions
+        println("Minimal automata had ${
+          it.let { "${it.numberOfStates} states and ${it.numberOfTransitions} transitions" }
+        }")
+      }
+  }
+//  ?.getFiniteStrings(-1)?.map { it.map { ctbl[it] }.joinToString(" ") } ?: emptySet()
+  ?.steerableRandomWalk(
+    mc = mc,
+    dec = allTerminals.associateBy { Random(it.hashCode()).nextInt().toChar() },
+    topK = topK
+  ) ?: emptyList()
+
+/**
+ * Steers a random walk using the last n-1 transitions from the Markov Chain.
+ * Trajectories are expanded in order of their length-normalized score
+ * (smallest first) until at least [topK] accepting trajectories are collected
+ * or no partial trajectory is left; the result is ordered by that same score.
+ */
+fun BAutomaton.steerableRandomWalk(
+  mc: MarkovChain<Σᐩ>,
+  // BAutomata uses a Unicode alphabet, and the Markov Chain recognizes a
+  // string-based alphabet, so we need a way to translate between the two
+  dec: Map<Char, Σᐩ>, // Maps unicode characters back to strings
+  topK: Int // Total number of top-K results to return
+): List<Σᐩ> {
+  val startTime = TimeSource.Monotonic.markNow()
+  // Both queues rank by length-normalized score, smallest first
+  val byAvgScore = compareBy<FSATrajectory> { it.score / it.toks.size }
+  val fullTrajectories = PriorityQueue<FSATrajectory>(byAvgScore)
+  val partTrajectories = PriorityQueue<FSATrajectory>(byAvgScore)
+    .apply { add(FSATrajectory(List(mc.memory) { null }, initialState, 0.0)) }
+  while (fullTrajectories.size < topK && partTrajectories.isNotEmpty()) {
+    val partTraj = partTrajectories.remove()
+    val lastToks = partTraj.toks.take(mc.memory - 1).reversed()
+    partTraj.lastState.transitions.forEach { next: Transition ->
+      (next.min..next.max).forEach { tok ->
+        val decTok = dec[tok]
+        val nextToks = lastToks + decTok
+        val nextScore = partTraj.score + mc.scoreChunk(nextToks)
+        val traj = FSATrajectory(listOf(decTok) + partTraj.toks, next.dest, nextScore)
+        if (!traj.isComplete) partTrajectories.add(traj)
+        else {
+          fullTrajectories.add(traj)
+          if (traj.lastState.transitions.isNotEmpty())
+            partTrajectories.add(traj)
+        }
+      }
+    }
+  }
+
+  // PriorityQueue iteration order is unspecified (only its head is guaranteed
+  // to be the minimum), so rank explicitly before printing and returning
+  val ranked = fullTrajectories.sortedWith(byAvgScore)
+
+  println("Top 10 trajectories:")
+  ranked.take(10).forEach { println(it.score.toString().take(5) + ": $it") }
+  println("Took ${startTime.elapsedNow()} to decode ${ranked.size} trajectories")
+
+  return ranked.map { it.toString() }
+}
+
diff --git a/src/jvmTest/kotlin/ai/hypergraph/kaliningraph/automata/WFSATest.kt b/src/jvmTest/kotlin/ai/hypergraph/kaliningraph/automata/WFSATest.kt index dff87f42..745002e6 100644 --- a/src/jvmTest/kotlin/ai/hypergraph/kaliningraph/automata/WFSATest.kt +++ b/src/jvmTest/kotlin/ai/hypergraph/kaliningraph/automata/WFSATest.kt @@ -2,22 +2,16 @@ package ai.hypergraph.kaliningraph.automata import Grammars import Grammars.shortS2PParikhMap -import ai.hypergraph.kaliningraph.graphs.LabeledGraph import ai.hypergraph.kaliningraph.parsing.* import ai.hypergraph.markovian.mcmc.MarkovChain -import dk.brics.automaton.Transition import net.jhoogland.jautomata.* import net.jhoogland.jautomata.operations.Concatenation import net.jhoogland.jautomata.semirings.RealSemiring import java.io.File -import java.util.PriorityQueue import kotlin.random.Random import kotlin.test.* -import kotlin.time.* +import kotlin.time.measureTimedValue -typealias BState = dk.brics.automaton.State -typealias BAutomaton = dk.brics.automaton.Automaton -typealias JAutomaton = net.jhoogland.jautomata.Automaton class WFSATest { val MARKOV_MEMORY = 4 @@ -75,85 +69,6 @@ class WFSATest { println(BAutomaton.minimize(ag.also { it.determinize() }).toDot()) } - fun JAutomaton.toDot(processed: MutableSet = mutableSetOf()) = - LabeledGraph { - val stateQueue = mutableListOf() - initialStates().forEach { stateQueue.add(it) } - while (true) { - if (stateQueue.isEmpty()) 
break - val state = stateQueue.removeAt(0) - transitionsOut(state).forEach { - val label = label(it) + "/" + transitionWeight(it).toString().take(4) - val next = this@toDot.to(it) - val initws = initialWeight(state) - val finalws = finalWeight(state) - val initwn = initialWeight(next) - val finalwn = finalWeight(next) - (state.hashCode().toString() + "#$initws/$finalws")[label] = next.hashCode().toString() + "#$initwn/$finalwn" - if (next !in processed) { - processed.add(next) - stateQueue.add(next) - } - } - } - }.toDot() - // States are typically unlabeled in FSA diagrams - .replace("Mrecord\"", "Mrecord\", label=\"\"") - // Final states are suffixed with /1.0 and drawn as double circles - .replace("/1.0\" [\"shape\"=\"Mrecord\"", "/1.0\" [\"shape\"=\"doublecircle\"") - .replace("Mrecord", "circle") // FSA states should be circular - .replace("null", "ε") // null label = ε-transition - - /* - * Returns a sequence trajectories through a DFA sampled using the Markov chain. - * The DFA is expected to be deterministic. 
We use the Markov chain to steer the - * random walk through the DFA by sampling the best transitions conditioned on the - * previous n-1 transitions, i.e., q' ~ argmax_{q'} P(q' | q_{t-1}, ..., q_{t-n+1}) - */ - - data class FSATrajectory(val toks: List<Σᐩ?>, val lastState: BState, val score: Double) { - val isComplete: Boolean = lastState.isAccept - override fun toString() = toks.reversed().filterNotNull().joinToString(" ") - } - - // Steers a random walk using the last n-1 transitions from the Markov Chain - fun BAutomaton.steerableRandomWalk( - mc: MarkovChain<Σᐩ>, - // BAutomata uses a Unicode alphabet, and the Markov Chain recognizes a - // string-based alphabet, so we need a way to translate between the two - dec: Map, // Maps unicode characters back to strings - topK: Int = 10_000_000 // Total number of top-K results to return - ): List<Σᐩ> { - val startTime = TimeSource.Monotonic.markNow() - val fullTrajectories = PriorityQueue(compareBy { it.score / it.toks.size }) - val partTrajectories = PriorityQueue(compareBy { it.score / it.toks.size }) - .apply { add(FSATrajectory(List(mc.memory) { null }, initialState, 0.0)) } - while (fullTrajectories.size < topK && partTrajectories.isNotEmpty()) { - val partTraj = partTrajectories.remove() - val lastToks = partTraj.toks.take(mc.memory - 1).reversed() - partTraj.lastState.transitions.forEach { next: Transition -> - (next.min..next.max).forEach { tok -> - val decTok = dec[tok] - val nextToks = lastToks + decTok - val nextScore = partTraj.score + mc.scoreChunk(nextToks) - val traj = FSATrajectory(listOf(decTok) + partTraj.toks, next.dest, nextScore) - if (!traj.isComplete) partTrajectories.add(traj) - else { - fullTrajectories.add(traj) - if (traj.lastState.transitions.isNotEmpty()) - partTrajectories.add(traj) - } - } - } - } - - println("Top 10 trajectories:") - fullTrajectories.take(10).forEach { println(it.score.toString().take(5) + ": $it") } - println("Took ${startTime.elapsedNow()} to decode 
${fullTrajectories.size} trajectories") - - return fullTrajectories.map { it.toString() } - } - /* ./gradlew jvmTest --tests "ai.hypergraph.kaliningraph.automata.WFSATest.testPTreeVsWFSA" */ @@ -167,50 +82,13 @@ class WFSATest { val groundTr = "+ NAME : True NEWLINE NAME = STRING NEWLINE NAME = NAME . NAME ( STRING ) NEWLINE" val radius = 2 val pt = Grammars.seq2parsePythonCFG.makeLevPTree(toRepair, radius, shortS2PParikhMap) - fun Char.toUnicodeEscaped() = "\\u${code.toString(16).padStart(4, '0')}" - - val ctbl = Grammars.seq2parsePythonCFG.terminals.associateBy { Random(it.hashCode()).nextInt().toChar() } - val stbl = Grammars.seq2parsePythonCFG.terminals.associateBy { Random(it.hashCode()).nextInt().toChar().toUnicodeEscaped() } - fun Σᐩ.replaceAll(tbl: Map) = tbl.entries.fold(this) { acc, (k, v) -> acc.replace(k, v) } println("Total trees: " + pt.totalTrees.toString()) val maxResults = 10_000 val ptreeRepairs = measureTimedValue { pt.sampleStrWithoutReplacement().distinct().take(maxResults).toSet() } - measureTimedValue { - pt.propagator( - both = { a, b -> if (a == null) b else if (b == null) a else a.concatenate(b) }, - either = { a, b -> if (a == null) b else if (b == null) a else a.union(b) }, - unit = { a -> - if ("ε" in a.root) null - else BAutomaton.makeChar(Random(a.root.hashCode()).nextInt().toChar()) -// EditableAutomaton(RealSemiring()).apply { -// val s1 = addState(1.0, 0.0) -// val s2 = addState(0.0, 1.0) -// addTransition(s1, s2, a.root, 1.0) -// } - } - ) -// ?.also { println("\n" + Operations.determinizeER(it).toDot().alsoCopy() + "\n") } -// .also { println("Total: ${Automata.transitions(it).size} arcs, ${Automata.states(it).size}") } -// .let { WAutomata.bestStrings(it, maxResults).map { it.label.joinToString(" ") }.toSet() } - ?.also { println("Original automata had ${it - .let { "${it.numberOfStates} states and ${it.numberOfTransitions} transitions"}}") } - ?.also { - measureTimedValue { BAutomaton.minimize(it) } - .also { 
println("Minimization took ${it.duration}") }.value -// .also { it.toDot().replaceAll(stbl).alsoCopy() } - .also { - // Minimal automata had 92 states and 707 transitions - println("Minimal automata had ${ - it.let { "${it.numberOfStates} states and ${it.numberOfTransitions} transitions" } - }") - } - } -// ?.getFiniteStrings(-1)?.map { it.map { ctbl[it] }.joinToString(" ") } ?: emptySet() - ?.steerableRandomWalk(P_BIFI_PY150, ctbl) ?: emptyList() - }.also { + measureTimedValue { pt.decodeDFA(P_BIFI_PY150) }.also { assertTrue(groundTr in it.value, "Ground truth not found in ${it.value.size} repairs") println("Index: ${it.value.indexOf(groundTr)}") // // Print side by side comparison of repairs