Skip to content

Commit

Permalink
improve handling for AFSAs
Browse files Browse the repository at this point in the history
  • Loading branch information
breandan committed Jan 27, 2025
1 parent eed4652 commit a8a8a1a
Show file tree
Hide file tree
Showing 5 changed files with 176 additions and 19 deletions.
38 changes: 37 additions & 1 deletion src/commonMain/kotlin/ai/hypergraph/kaliningraph/CommonUtils.kt
Original file line number Diff line number Diff line change
Expand Up @@ -123,4 +123,40 @@ infix fun Any.isA(that: Any) = when {
}

// True when every element of the receiver satisfies `isA` against [that] (vacuously true when empty).
infix fun Collection<Any>.allAre(that: Any) = all { it isA that }
// True when at least one element of the receiver satisfies `isA` against [that] (false when empty).
infix fun Collection<Any>.anyAre(that: Any) = any { it isA that }
// NOTE(review): the next line repeats the previous declaration verbatim — presumably the
// before/after pair of a diff line; confirm that only one copy exists in the real file.
infix fun Collection<Any>.anyAre(that: Any) = any { it isA that }

/**
 * Minimal pure-Kotlin bit set over the indices `0..n-1`.
 *
 * Bits are packed 64 per [Long]. The set is mutable: [set], [clear], [or] and [and] all
 * modify the receiver in place and return nothing.
 *
 * @param n number of addressable bits; must be non-negative.
 * @throws IllegalArgumentException if [n] is negative.
 */
class KBitSet(private val n: Int) {
  init {
    // Without this guard a negative n would be turned into a huge array length by `ushr`.
    require(n >= 0) { "KBitSet size must be non-negative, was $n" }
  }

  // Each element of 'data' holds 64 bits, covering up to n bits total.
  private val data = LongArray((n + 63) ushr 6)

  /** Rejects indices outside `0..n-1`, including the unused padding bits of the last word. */
  private fun checkIndex(index: Int) {
    if (index < 0 || index >= n) throw IndexOutOfBoundsException("Bit index $index is outside [0, $n)")
  }

  /** Rejects operands of a different size, so bulk operations never truncate or stop half-way. */
  private fun requireSameSize(other: KBitSet) {
    require(n == other.n) { "KBitSet size mismatch: $n vs ${other.n}" }
  }

  /**
   * Turns on the bit at [index].
   *
   * @throws IndexOutOfBoundsException if [index] is not in `0..n-1`.
   */
  fun set(index: Int) {
    checkIndex(index)
    val word = index ushr 6
    val bit = index and 63
    data[word] = data[word] or (1L shl bit)
  }

  /**
   * Returns whether the bit at [index] is on.
   *
   * @throws IndexOutOfBoundsException if [index] is not in `0..n-1`.
   */
  fun get(index: Int): Boolean {
    checkIndex(index)
    val word = index ushr 6
    val bit = index and 63
    return (data[word] and (1L shl bit)) != 0L
  }

  /** Turns every bit off. */
  fun clear() { data.fill(0L) }

  /**
   * In-place union: after the call the receiver holds every bit that was on in either set.
   *
   * @throws IllegalArgumentException if [other] was created with a different size.
   */
  infix fun or(other: KBitSet) {
    requireSameSize(other)
    for (i in data.indices) data[i] = data[i] or other.data[i]
  }

  /**
   * In-place intersection: after the call the receiver holds only bits that were on in both sets.
   *
   * @throws IllegalArgumentException if [other] was created with a different size.
   */
  infix fun and(other: KBitSet) {
    requireSameSize(other)
    for (i in data.indices) data[i] = data[i] and other.data[i]
  }

  /** Returns the indices of all bits that are on, in ascending iteration order. */
  fun toSet(): Set<Int> {
    val result = mutableSetOf<Int>()
    for (w in data.indices) {
      var bits = data[w]
      // Visit only the set bits of this word: take the lowest one, then clear it.
      while (bits != 0L) {
        result.add((w shl 6) + bits.countTrailingZeroBits())
        bits = bits and (bits - 1)
      }
    }
    return result
  }
}
132 changes: 132 additions & 0 deletions src/commonMain/kotlin/ai/hypergraph/kaliningraph/automata/AFSA.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
package ai.hypergraph.kaliningraph.automata

import ai.hypergraph.kaliningraph.KBitSet
import ai.hypergraph.kaliningraph.parsing.Σᐩ

/**
 * Acyclic finite state automaton.
 *
 * Specializes [FSA] for automata whose transition relation [Q] contains no cycles: [stateLst]
 * is produced by a topological sort and [allPairs] is overridden with a bit-set based computation
 * that relies on that ordering.
 */
class AFSA(override val Q: TSA, override val init: Set<Σᐩ>, override val final: Set<Σᐩ>): FSA(Q, init, final) {
  /**
   * Orders the states so that every transition goes from an earlier state to a later one,
   * by repeatedly emitting states whose remaining in-degree is zero (Kahn's algorithm).
   *
   * States that never reach in-degree zero (i.e. states on or behind a cycle) are not emitted,
   * so on a cyclic input the returned list is shorter than `states`.
   */
  fun topSort(): List<Σᐩ> {
    // 1) Build adjacency lists (only next-states) from `transit`.
    //    We also need to track in-degrees of each state.
    val adjacency = mutableMapOf<Σᐩ, MutableList<Σᐩ>>()
    val inDegree = mutableMapOf<Σᐩ, Int>()

    // Initialize adjacency and inDegree for all states
    for (s in states) {
      adjacency[s] = mutableListOf()
      inDegree[s] = 0
    }

    // Fill adjacency and in-degree. Parallel arcs (same endpoints, different symbols) are
    // counted once per arc here and decremented once per arc below, so they balance out.
    for ((fromState, outEdges) in transit) {
      // outEdges is a list of (symbol, toState) pairs
      for ((_, toState) in outEdges) {
        adjacency[fromState]!!.add(toState)
        inDegree[toState] = inDegree[toState]!! + 1
      }
    }

    // 2) Collect all states with in-degree 0 into a queue
    val zeroQueue = ArrayDeque<Σᐩ>()
    for ((st, deg) in inDegree) if (deg == 0) zeroQueue.add(st)

    // 3) Repeatedly pop from queue, and decrement in-degree of successors
    val result = mutableListOf<Σᐩ>()
    while (zeroQueue.isNotEmpty()) {
      val s = zeroQueue.removeFirst()
      result.add(s)

      for (next in adjacency[s]!!) {
        val d = inDegree[next]!! - 1
        inDegree[next] = d
        if (d == 0) zeroQueue.add(next)
      }
    }

    // 4) The 'result' is our topological ordering.
    return result
  }

  // Since the FSA is acyclic, we can use a more efficient topological ordering.
  // NOTE(review): the completeness check below is commented out, so an ordering that misses
  // states (see topSort) is not detected here.
  override val stateLst by lazy {
    topSort()
//      .also {
//        if (it.size != states.size)
//          throw Exception("Contained ${states.size} but ${it.size} topsorted indices:\n" +
//              "T:${Q.joinToString("") { (a, b, c) -> ("($a -[$b]-> $c)") }}\n" +
//              "V:${graph.vertices.map { it.label }.sorted().joinToString(",")}\n" +
//              "Q:${Q.states().sorted().joinToString(",")}\n" +
//              "S:${states.sorted().joinToString(",")}"
//          )
//      }
  }

  /**
   * For every ordered pair of state indices `(i, j)`, the indices of the states that are both
   * reachable from `i` and able to reach `j` (endpoints included when `j` is reachable from `i`).
   * The entry is empty when `i == j`, when `j` is not reachable from `i`, and for every `j > i`
   * in the reverse direction `(j, i)`.
   *
   * Assumes stateLst is already in topological order, i.e. every arc goes from a lower index to
   * a higher one — presumably `stateMap` maps each label to its position in [stateLst]; verify
   * against [FSA].
   */
  override val allPairs: Map<Pair<Int, Int>, Set<Int>> by lazy {
    val fwdAdj = Array(numStates) { mutableListOf<Int>() }
    val revAdj = Array(numStates) { mutableListOf<Int>() }

    for ((fromLabel, _, toLabel) in Q) {
      val i = stateMap[fromLabel]!!
      val j = stateMap[toLabel]!!
      fwdAdj[i].add(j)
      revAdj[j].add(i)
    }

    // 1) Prepare KBitSets for post[] (states reachable from i) and pre[] (states reaching i)
    val post = Array(numStates) { KBitSet(numStates) }
    val pre = Array(numStates) { KBitSet(numStates) }

    // 2) Compute post[i] in reverse topological order, so every successor is already complete
    for (i in (numStates - 1) downTo 0) {
      post[i].set(i)
      for (k in fwdAdj[i]) post[i].or(post[k])
    }

    // 3) Compute pre[i] in forward topological order, so every predecessor is already complete
    for (i in 0 until numStates) {
      pre[i].set(i)
      for (p in revAdj[i]) pre[i].or(pre[p])
    }

    // 4) Build allPairs by intersecting post[i] and pre[j].
    //    The intersection is skipped when j is not reachable from i (post[i].get(j) == false).
    val result = mutableMapOf<Pair<Int, Int>, Set<Int>>()

    // Single scratch bit set reused for every intersection; toSet() copies the bits out,
    // so clearing it for the next pair does not affect results already stored.
    val tmp = KBitSet(numStates)

    for (i in 0 until numStates) {
      // The diagonal is stored as empty: a zero-length path has no states recorded for it.
      result[i to i] = emptySet()

      for (j in (i + 1) until numStates) {
        val between: Set<Int> =
          if (!post[i].get(j)) emptySet() // i < j, but j is not actually reachable from i
          else {
            tmp.clear()
            tmp.or(post[i])
            tmp.and(pre[j])
            tmp.toSet()
          }
        result[i to j] = between

        // In a DAG indexed in topological order, j -> i is unreachable whenever j > i.
        result[j to i] = emptySet()
      }
    }

    result
  }
}
20 changes: 4 additions & 16 deletions src/commonMain/kotlin/ai/hypergraph/kaliningraph/automata/FSA.kt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package ai.hypergraph.kaliningraph.automata

import ai.hypergraph.kaliningraph.KBitSet
import ai.hypergraph.kaliningraph.graphs.LabeledGraph
import ai.hypergraph.kaliningraph.parsing.*
import ai.hypergraph.kaliningraph.repair.MAX_RADIUS
Expand All @@ -17,21 +18,6 @@ fun Σᐩ.coords(): Pair<Int, Int> =
typealias STC = Triple<Int, Int, Int>
fun STC.coords() = π2 to π3

/**
 * [FSA] subclass for acyclic automata whose [stateLst] comes from a topological sort of `graph`.
 */
class ACYC_FSA constructor(override val Q: TSA, override val init: Set<Σᐩ>, override val final: Set<Σᐩ>): FSA(Q, init, final) {
// Since the FSA is acyclic, we can use a more efficient topological ordering
override val stateLst by lazy {
// Map the sorted vertices to their labels, then fail loudly if the number of sorted labels
// differs from the number of states; the message dumps the transitions, vertex labels and states.
graph.topSort().map { it.label }.also {
if (it.size != states.size)
throw Exception("Contained ${states.size} but ${it.size} topsorted indices:\n" +
"T:${Q.joinToString("") { (a, b, c) -> ("($a -[$b]-> $c)") }}\n" +
"V:${graph.vertices.map { it.label }.sorted().joinToString(",")}\n" +
"Q:${Q.states().sorted().joinToString(",")}\n" +
"S:${states.sorted().joinToString(",")}"
)
}
}
}

// TODO: Add support for incrementally growing the FSA by adding new transitions
open class FSA constructor(open val Q: TSA, open val init: Set<Σᐩ>, open val final: Set<Σᐩ>) {
open val alphabet by lazy { Q.map { it.π2 }.toSet() }
Expand All @@ -41,6 +27,7 @@ open class FSA constructor(open val Q: TSA, open val init: Set<Σᐩ>, open val
// Outgoing arcs per source state: π1 of each triple in Q is the key, (π2, π3) the stored pair.
// States without outgoing arcs have no entry.
val transit: Map<Σᐩ, List<Pair<Σᐩ, Σᐩ>>> by lazy {
  Q.groupBy(keySelector = { arc -> arc.π1 }, valueTransform = { arc -> arc.π2 to arc.π3 })
}

// Incoming arcs per target state: π3 of each triple in Q is the key, (π2, π1) the stored pair.
// States without incoming arcs have no entry.
val revtransit: Map<Σᐩ, List<Pair<Σᐩ, Σᐩ>>> by lazy {
  Q.groupBy(keySelector = { arc -> arc.π3 }, valueTransform = { arc -> arc.π2 to arc.π1 })
}
Expand Down Expand Up @@ -188,7 +175,8 @@ open class FSA constructor(open val Q: TSA, open val init: Set<Σᐩ>, open val
val dp: Array<Array<Array<PTree?>>> = Array(nStates) { Array(nStates) { Array(width) { null } } }

// 2) Initialize terminal productions A -> a
for ((p, σ, q) in levFSA.allIndexedTxs1(cfg)) {
val aitx = levFSA.allIndexedTxs1(cfg)
for ((p, σ, q) in aitx) {
val Aidxs = bimap.TDEPS[σ]!!.map { bindex[it] }
for (Aidx in Aidxs) {
val newLeaf = PTree(root = "[$p~${bindex[Aidx]}~$q]", branches = PSingleton(σ))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ fun makeExactLevCFL(
val initialStates = setOf("q_" + pd(0, digits).let { "$it/$it" })
val finalStates = Q.states().filter { it.unpackCoordinates().let { (i, j) -> ((str.size - i + j).absoluteValue == radius) } }

ACYC_FSA(Q, initialStates, finalStates)
AFSA(Q, initialStates, finalStates)
.also { it.height = radius; it.width = str.size; it.levString = str }
.also { println("Levenshtein-${str.size}x$radius automaton had ${Q.size} arcs!") }
}
Expand Down Expand Up @@ -121,7 +121,7 @@ fun makeLevFSA(
val finalStates =
Q.states().filter { it.unpackCoordinates().let { (i, j) -> ((str.size - i + j).absoluteValue <= maxRad) } }

ACYC_FSA(Q, initialStates, finalStates)
AFSA(Q, initialStates, finalStates)
.also { it.height = maxRad; it.width = str.size; it.levString = str }
// .nominalize()
.also { println("Reduced L-NFA(${str.size}, $maxRad) from $initSize to ${Q.size} arcs in ${clock.elapsedNow()}") }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ data class Segmentation(
else -> cfg.parseInvalidWithMaximalFragments(line)
.map { it.span }.filter { 2 < (it.last - it.first) }.flatten()
.let { it to tokens.indices.filterNot { i -> i in it } }
.let { if (it.second.isEmpty() ) it.second to it.first else it }
}.let {
Segmentation(
valid = it.first,
Expand Down

0 comments on commit a8a8a1a

Please sign in to comment.