Skip to content

Commit

Permalink
move jautomata helpers into jvmMain
Browse files Browse the repository at this point in the history
  • Loading branch information
breandan committed Jun 24, 2024
1 parent b4f31e3 commit 04251bc
Show file tree
Hide file tree
Showing 3 changed files with 136 additions and 125 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -128,4 +128,7 @@ const val ANSI_YELLOW_BACKGROUND = "\u001B[43m"
const val ANSI_BLUE_BACKGROUND = "\u001B[44m"
const val ANSI_PURPLE_BACKGROUND = "\u001B[45m"
const val ANSI_CYAN_BACKGROUND = "\u001B[46m"
const val ANSI_WHITE_BACKGROUND = "\u001B[47m"
const val ANSI_WHITE_BACKGROUND = "\u001B[47m"

/**
 * Renders this character as a backslash-u escape: a literal backslash, the letter `u`,
 * then the character's code in lowercase hexadecimal, left-padded with zeros to four digits.
 */
fun Char.toUnicodeEscaped(): String {
  val hex = code.toString(16)
  return "\\u" + hex.padStart(4, '0')
}
/**
 * Applies every key -> value substitution in [tbl] to this string, one entry after another
 * in the map's iteration order. Each replacement operates on the result of the previous one,
 * so text produced by an earlier entry can be rewritten by a later entry.
 */
fun Σᐩ.replaceAll(tbl: Map<String, String>): Σᐩ {
  var rewritten = this
  for ((needle, substitute) in tbl) rewritten = rewritten.replace(needle, substitute)
  return rewritten
}
130 changes: 130 additions & 0 deletions src/jvmMain/kotlin/ai/hypergraph/kaliningraph/automata/JFSA.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
package ai.hypergraph.kaliningraph.automata

import ai.hypergraph.kaliningraph.graphs.LabeledGraph
import ai.hypergraph.kaliningraph.parsing.*
import ai.hypergraph.markovian.mcmc.MarkovChain
import dk.brics.automaton.Automaton.*
import dk.brics.automaton.Transition
import java.util.*
import kotlin.random.Random
import kotlin.time.*

typealias BState = dk.brics.automaton.State
typealias BAutomaton = dk.brics.automaton.Automaton
typealias JAutomaton<S, K> = net.jhoogland.jautomata.Automaton<S, K>

/**
 * Renders this weighted automaton as a Graphviz DOT string.
 *
 * States are discovered breadth-first from [initialStates]. Each vertex is named
 * `<state hash>#<initial weight>/<final weight>` and each edge is labelled
 * `<transition label>/<first four characters of the transition weight>`.
 * The DOT text is then post-processed so that vertices carry an empty label, vertices whose
 * name ends in `/1.0` are drawn as double circles, all others as circles, and every
 * occurrence of the text `null` (e.g. a null transition label) is shown as `ε`.
 *
 * @param processed states that have already been queued for expansion; the traversal adds
 *   every state it visits (including the initial states) to this set
 */
fun JAutomaton<String, Double>.toDot(processed: MutableSet<Any> = mutableSetOf()) =
  LabeledGraph {
    val stateQueue = kotlin.collections.ArrayDeque<Any>()
    // Mark the seeds as seen when queueing them; otherwise an initial state that is
    // reached again through a cycle would be queued and expanded a second time
    initialStates().forEach {
      processed.add(it)
      stateQueue.add(it)
    }
    while (stateQueue.isNotEmpty()) {
      val state = stateQueue.removeFirst()
      transitionsOut(state).forEach { transition ->
        val label = label(transition) + "/" + transitionWeight(transition).toString().take(4)
        // Qualified call: a bare `to` would not refer to the automaton's target-state accessor
        val next = this@toDot.to(transition)
        val initws = initialWeight(state)
        val finalws = finalWeight(state)
        val initwn = initialWeight(next)
        val finalwn = finalWeight(next)
        (state.hashCode().toString() + "#$initws/$finalws")[label] = next.hashCode().toString() + "#$initwn/$finalwn"
        if (next !in processed) {
          processed.add(next)
          stateQueue.add(next)
        }
      }
    }
  }.toDot()
    // States are typically unlabeled in FSA diagrams
    .replace("Mrecord\"", "Mrecord\", label=\"\"")
    // Final states are suffixed with /1.0 and drawn as double circles
    .replace("/1.0\" [\"shape\"=\"Mrecord\"", "/1.0\" [\"shape\"=\"doublecircle\"")
    .replace("Mrecord", "circle") // FSA states should be circular
    .replace("null", "ε") // null label = ε-transition

/*
 * Returns a sequence of trajectories through a DFA sampled using the Markov chain.
 * The DFA is expected to be deterministic. We use the Markov chain to steer the
 * random walk through the DFA by sampling the best transitions conditioned on the
 * previous n-1 transitions, i.e., q' ~ argmax_{q'} P(q' | q_{t-1}, ..., q_{t-n+1})
 */

/**
 * A path through a [BAutomaton] together with its accumulated score.
 *
 * @property toks tokens of the path, held in the reverse of their printed order
 *   ([toString] flips them back); `null` entries are skipped when printing
 * @property lastState the automaton state this path currently ends in
 * @property score the score accumulated along the path
 */
data class FSATrajectory(val toks: List<Σᐩ?>, val lastState: BState, val score: Double) {
  // Captured once, at construction time, from the end state's accept flag
  val isComplete: Boolean = lastState.isAccept

  // Walk the stored tokens back to front, drop the null entries, and join with single spaces
  override fun toString(): String =
    toks.asReversed().filterNotNull().joinToString(separator = " ")
}

/**
 * Compiles this parse forest into a character-level [BAutomaton], minimizes it in place, and
 * decodes strings from it with [steerableRandomWalk] under the guidance of [mc].
 *
 * Every leaf whose root does not contain `ε` becomes a single-character automaton; the character
 * is drawn from a [Random] seeded with the hash code of the leaf's root. The same derivation is
 * applied to `allTerminals` to build the table that translates characters back into terminals.
 * NOTE(review): two terminals whose seeds yield the same character would share one table
 * entry — confirm this cannot happen for the grammars in use.
 *
 * Side effects: prints the automaton size before and after minimization plus the minimization
 * time, and calls the static `BAutomaton.setMinimization(MINIMIZE_BRZOZOWSKI)`.
 *
 * @param mc Markov chain handed to the walk for scoring
 * @param topK number of results requested from the walk
 * @return the decoded strings, or an empty list when the propagator produces no automaton
 */
fun PTree.decodeDFA(mc: MarkovChain<Σᐩ>, topK: Int = 10_000_000): List<Σᐩ> {
  val automaton = propagator(
    both = { l, r ->
      when {
        l == null -> r
        r == null -> l
        else -> l.concatenate(r)
      }
    },
    either = { l, r ->
      when {
        l == null -> r
        r == null -> l
        else -> l.union(r)
      }
    },
    unit = { leaf ->
      if ("ε" in leaf.root) null
      else BAutomaton.makeChar(Random(leaf.root.hashCode()).nextInt().toChar())
//      EditableAutomaton<String, Double>(RealSemiring()).apply {
//        val s1 = addState(1.0, 0.0)
//        val s2 = addState(0.0, 1.0)
//        addTransition(s1, s2, a.root, 1.0)
//      }
    }
  )
  // Nothing to decode when no automaton could be assembled
  if (automaton == null) return emptyList()

  println("Original automata had ${automaton.numberOfStates} states and ${automaton.numberOfTransitions} transitions")

  val minimization = measureTimedValue {
    BAutomaton.setMinimization(MINIMIZE_BRZOZOWSKI)
    BAutomaton.minimize(automaton)
  }
  println("Minimization took ${minimization.duration}")
  // Debugging aid: minimized.toDot().replaceAll(stbl).alsoCopy()
  val minimized = minimization.value
  println("Minimal automata had ${minimized.numberOfStates} states and ${minimized.numberOfTransitions} transitions")

  // Exhaustive alternative: automaton.getFiniteStrings(-1)?.map { it.map { ctbl[it] }.joinToString(" ") }
  val decoder = allTerminals.associateBy { Random(it.hashCode()).nextInt().toChar() }
  return automaton.steerableRandomWalk(mc = mc, dec = decoder, topK = topK)
}

/**
 * Steers a walk through this automaton using the last n-1 transitions from the Markov chain.
 *
 * Partial trajectories are expanded in order of their length-normalized score
 * (`score / toks.size`, smallest first). Every trajectory that ends in an accepting state is
 * recorded as complete, and is additionally kept for further expansion when that state still
 * has outgoing transitions.
 *
 * Prints the ten best complete trajectories and the elapsed decoding time to standard output.
 *
 * @param mc Markov chain that scores each candidate token together with the preceding tokens
 * @param dec maps the automaton's Unicode characters back to string tokens
 * @param topK maximum number of results to return
 * @return at most [topK] decoded strings, ordered by ascending length-normalized score
 */
fun BAutomaton.steerableRandomWalk(
  mc: MarkovChain<Σᐩ>,
  // BAutomata uses a Unicode alphabet, and the Markov Chain recognizes a
  // string-based alphabet, so we need a way to translate between the two
  dec: Map<Char, String>, // Maps unicode characters back to strings
  topK: Int // Total number of top-K results to return
): List<Σᐩ> {
  val startTime = TimeSource.Monotonic.markNow()
  val byNormalizedScore = compareBy<FSATrajectory> { it.score / it.toks.size }
  // Completed trajectories are collected in a plain list and ranked once at the end:
  // iterating a PriorityQueue does not visit its elements in priority order
  val fullTrajectories = mutableListOf<FSATrajectory>()
  val partTrajectories = PriorityQueue<FSATrajectory>(byNormalizedScore)
    .apply { add(FSATrajectory(List(mc.memory) { null }, initialState, 0.0)) }
  while (fullTrajectories.size < topK && partTrajectories.isNotEmpty()) {
    val partTraj = partTrajectories.remove()
    // toks are stored newest-first: take the most recent ones and restore reading order
    val lastToks = partTraj.toks.take(mc.memory - 1).reversed()
    partTraj.lastState.transitions.forEach { next: Transition ->
      (next.min..next.max).forEach { tok ->
        // NOTE(review): decTok is null when dec has no entry for tok; null tokens are
        // silently omitted by FSATrajectory.toString() — confirm this is intended
        val decTok = dec[tok]
        val nextToks = lastToks + decTok
        val nextScore = partTraj.score + mc.scoreChunk(nextToks)
        val traj = FSATrajectory(listOf(decTok) + partTraj.toks, next.dest, nextScore)
        if (!traj.isComplete) partTrajectories.add(traj)
        else {
          fullTrajectories.add(traj)
          // An accepting state with outgoing transitions can still be extended
          if (traj.lastState.transitions.isNotEmpty())
            partTrajectories.add(traj)
        }
      }
    }
  }

  // Rank best-first; the expansion loop may overshoot topK within one step, so cap the result
  fullTrajectories.sortWith(byNormalizedScore)
  val ranked = fullTrajectories.take(topK.coerceAtLeast(0))

  println("Top 10 trajectories:")
  ranked.take(10).forEach { println(it.score.toString().take(5) + ": $it") }
  println("Took ${startTime.elapsedNow()} to decode ${fullTrajectories.size} trajectories")

  return ranked.map { it.toString() }
}

126 changes: 2 additions & 124 deletions src/jvmTest/kotlin/ai/hypergraph/kaliningraph/automata/WFSATest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,16 @@ package ai.hypergraph.kaliningraph.automata

import Grammars
import Grammars.shortS2PParikhMap
import ai.hypergraph.kaliningraph.graphs.LabeledGraph
import ai.hypergraph.kaliningraph.parsing.*
import ai.hypergraph.markovian.mcmc.MarkovChain
import dk.brics.automaton.Transition
import net.jhoogland.jautomata.*
import net.jhoogland.jautomata.operations.Concatenation
import net.jhoogland.jautomata.semirings.RealSemiring
import java.io.File
import java.util.PriorityQueue
import kotlin.random.Random
import kotlin.test.*
import kotlin.time.*
import kotlin.time.measureTimedValue

typealias BState = dk.brics.automaton.State
typealias BAutomaton = dk.brics.automaton.Automaton
typealias JAutomaton<S, K> = net.jhoogland.jautomata.Automaton<S, K>

class WFSATest {
val MARKOV_MEMORY = 4
Expand Down Expand Up @@ -75,85 +69,6 @@ class WFSATest {
println(BAutomaton.minimize(ag.also { it.determinize() }).toDot())
}

/**
 * Renders this weighted automaton as a Graphviz DOT string.
 *
 * States are discovered breadth-first from [initialStates]. Each vertex is named
 * `<state hash>#<initial weight>/<final weight>` and each edge is labelled
 * `<transition label>/<first four characters of the transition weight>`.
 * The DOT text is then post-processed so that vertices carry an empty label, vertices whose
 * name ends in `/1.0` are drawn as double circles, all others as circles, and every
 * occurrence of the text `null` (e.g. a null transition label) is shown as `ε`.
 *
 * @param processed states that have already been queued for expansion; the traversal adds
 *   every state it visits (including the initial states) to this set
 */
fun JAutomaton<String, Double>.toDot(processed: MutableSet<Any> = mutableSetOf()) =
  LabeledGraph {
    val stateQueue = kotlin.collections.ArrayDeque<Any>()
    // Mark the seeds as seen when queueing them; otherwise an initial state that is
    // reached again through a cycle would be queued and expanded a second time
    initialStates().forEach {
      processed.add(it)
      stateQueue.add(it)
    }
    while (stateQueue.isNotEmpty()) {
      val state = stateQueue.removeFirst()
      transitionsOut(state).forEach { transition ->
        val label = label(transition) + "/" + transitionWeight(transition).toString().take(4)
        // Qualified call: a bare `to` would not refer to the automaton's target-state accessor
        val next = this@toDot.to(transition)
        val initws = initialWeight(state)
        val finalws = finalWeight(state)
        val initwn = initialWeight(next)
        val finalwn = finalWeight(next)
        (state.hashCode().toString() + "#$initws/$finalws")[label] = next.hashCode().toString() + "#$initwn/$finalwn"
        if (next !in processed) {
          processed.add(next)
          stateQueue.add(next)
        }
      }
    }
  }.toDot()
    // States are typically unlabeled in FSA diagrams
    .replace("Mrecord\"", "Mrecord\", label=\"\"")
    // Final states are suffixed with /1.0 and drawn as double circles
    .replace("/1.0\" [\"shape\"=\"Mrecord\"", "/1.0\" [\"shape\"=\"doublecircle\"")
    .replace("Mrecord", "circle") // FSA states should be circular
    .replace("null", "ε") // null label = ε-transition

/*
 * Returns a sequence of trajectories through a DFA sampled using the Markov chain.
 * The DFA is expected to be deterministic. We use the Markov chain to steer the
 * random walk through the DFA by sampling the best transitions conditioned on the
 * previous n-1 transitions, i.e., q' ~ argmax_{q'} P(q' | q_{t-1}, ..., q_{t-n+1})
 */

/**
 * A path through a [BAutomaton] together with its accumulated score.
 *
 * @property toks tokens of the path, held in the reverse of their printed order
 *   ([toString] flips them back); `null` entries are skipped when printing
 * @property lastState the automaton state this path currently ends in
 * @property score the score accumulated along the path
 */
data class FSATrajectory(val toks: List<Σᐩ?>, val lastState: BState, val score: Double) {
  // Captured once, at construction time, from the end state's accept flag
  val isComplete: Boolean = lastState.isAccept

  // Walk the stored tokens back to front, drop the null entries, and join with single spaces
  override fun toString(): String =
    toks.asReversed().filterNotNull().joinToString(separator = " ")
}

/**
 * Steers a walk through this automaton using the last n-1 transitions from the Markov chain.
 *
 * Partial trajectories are expanded in order of their length-normalized score
 * (`score / toks.size`, smallest first). Every trajectory that ends in an accepting state is
 * recorded as complete, and is additionally kept for further expansion when that state still
 * has outgoing transitions.
 *
 * Prints the ten best complete trajectories and the elapsed decoding time to standard output.
 *
 * @param mc Markov chain that scores each candidate token together with the preceding tokens
 * @param dec maps the automaton's Unicode characters back to string tokens
 * @param topK maximum number of results to return
 * @return at most [topK] decoded strings, ordered by ascending length-normalized score
 */
fun BAutomaton.steerableRandomWalk(
  mc: MarkovChain<Σᐩ>,
  // BAutomata uses a Unicode alphabet, and the Markov Chain recognizes a
  // string-based alphabet, so we need a way to translate between the two
  dec: Map<Char, String>, // Maps unicode characters back to strings
  topK: Int = 10_000_000 // Total number of top-K results to return
): List<Σᐩ> {
  val startTime = TimeSource.Monotonic.markNow()
  val byNormalizedScore = compareBy<FSATrajectory> { it.score / it.toks.size }
  // Completed trajectories are collected in a plain list and ranked once at the end:
  // iterating a PriorityQueue does not visit its elements in priority order
  val fullTrajectories = mutableListOf<FSATrajectory>()
  val partTrajectories = PriorityQueue<FSATrajectory>(byNormalizedScore)
    .apply { add(FSATrajectory(List(mc.memory) { null }, initialState, 0.0)) }
  while (fullTrajectories.size < topK && partTrajectories.isNotEmpty()) {
    val partTraj = partTrajectories.remove()
    // toks are stored newest-first: take the most recent ones and restore reading order
    val lastToks = partTraj.toks.take(mc.memory - 1).reversed()
    partTraj.lastState.transitions.forEach { next: Transition ->
      (next.min..next.max).forEach { tok ->
        // NOTE(review): decTok is null when dec has no entry for tok; null tokens are
        // silently omitted by FSATrajectory.toString() — confirm this is intended
        val decTok = dec[tok]
        val nextToks = lastToks + decTok
        val nextScore = partTraj.score + mc.scoreChunk(nextToks)
        val traj = FSATrajectory(listOf(decTok) + partTraj.toks, next.dest, nextScore)
        if (!traj.isComplete) partTrajectories.add(traj)
        else {
          fullTrajectories.add(traj)
          // An accepting state with outgoing transitions can still be extended
          if (traj.lastState.transitions.isNotEmpty())
            partTrajectories.add(traj)
        }
      }
    }
  }

  // Rank best-first; the expansion loop may overshoot topK within one step, so cap the result
  fullTrajectories.sortWith(byNormalizedScore)
  val ranked = fullTrajectories.take(topK.coerceAtLeast(0))

  println("Top 10 trajectories:")
  ranked.take(10).forEach { println(it.score.toString().take(5) + ": $it") }
  println("Took ${startTime.elapsedNow()} to decode ${fullTrajectories.size} trajectories")

  return ranked.map { it.toString() }
}

/*
./gradlew jvmTest --tests "ai.hypergraph.kaliningraph.automata.WFSATest.testPTreeVsWFSA"
*/
Expand All @@ -167,50 +82,13 @@ class WFSATest {
val groundTr = "+ NAME : True NEWLINE NAME = STRING NEWLINE NAME = NAME . NAME ( STRING ) NEWLINE"
val radius = 2
val pt = Grammars.seq2parsePythonCFG.makeLevPTree(toRepair, radius, shortS2PParikhMap)
fun Char.toUnicodeEscaped() = "\\u${code.toString(16).padStart(4, '0')}"

val ctbl = Grammars.seq2parsePythonCFG.terminals.associateBy { Random(it.hashCode()).nextInt().toChar() }
val stbl = Grammars.seq2parsePythonCFG.terminals.associateBy { Random(it.hashCode()).nextInt().toChar().toUnicodeEscaped() }
fun Σᐩ.replaceAll(tbl: Map<String, String>) = tbl.entries.fold(this) { acc, (k, v) -> acc.replace(k, v) }

println("Total trees: " + pt.totalTrees.toString())
val maxResults = 10_000
val ptreeRepairs = measureTimedValue {
pt.sampleStrWithoutReplacement().distinct().take(maxResults).toSet()
}
measureTimedValue {
pt.propagator(
both = { a, b -> if (a == null) b else if (b == null) a else a.concatenate(b) },
either = { a, b -> if (a == null) b else if (b == null) a else a.union(b) },
unit = { a ->
if ("ε" in a.root) null
else BAutomaton.makeChar(Random(a.root.hashCode()).nextInt().toChar())
// EditableAutomaton<String, Double>(RealSemiring()).apply {
// val s1 = addState(1.0, 0.0)
// val s2 = addState(0.0, 1.0)
// addTransition(s1, s2, a.root, 1.0)
// }
}
)
// ?.also { println("\n" + Operations.determinizeER(it).toDot().alsoCopy() + "\n") }
// .also { println("Total: ${Automata.transitions(it).size} arcs, ${Automata.states(it).size}") }
// .let { WAutomata.bestStrings(it, maxResults).map { it.label.joinToString(" ") }.toSet() }
?.also { println("Original automata had ${it
.let { "${it.numberOfStates} states and ${it.numberOfTransitions} transitions"}}") }
?.also {
measureTimedValue { BAutomaton.minimize(it) }
.also { println("Minimization took ${it.duration}") }.value
// .also { it.toDot().replaceAll(stbl).alsoCopy() }
.also {
// Minimal automata had 92 states and 707 transitions
println("Minimal automata had ${
it.let { "${it.numberOfStates} states and ${it.numberOfTransitions} transitions" }
}")
}
}
// ?.getFiniteStrings(-1)?.map { it.map { ctbl[it] }.joinToString(" ") } ?: emptySet()
?.steerableRandomWalk(P_BIFI_PY150, ctbl) ?: emptyList()
}.also {
measureTimedValue { pt.decodeDFA(P_BIFI_PY150) }.also {
assertTrue(groundTr in it.value, "Ground truth not found in ${it.value.size} repairs")
println("Index: ${it.value.indexOf(groundTr)}")
// // Print side by side comparison of repairs
Expand Down

0 comments on commit 04251bc

Please sign in to comment.