Skip to content

Commit

Permalink
refactor and setup probabilistic LBH test
Browse files Browse the repository at this point in the history
  • Loading branch information
breandan committed Dec 29, 2023
1 parent fb6f5da commit 689e18b
Show file tree
Hide file tree
Showing 19 changed files with 5,913 additions and 91 deletions.
80 changes: 5 additions & 75 deletions src/commonMain/kotlin/ai/hypergraph/kaliningraph/StringUtils.kt
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,11 @@ import ai.hypergraph.kaliningraph.tensor.transpose
import ai.hypergraph.kaliningraph.types.*
import kotlin.math.*

// Compiled once instead of on every call; the pattern itself is unchanged.
private val WHITESPACE_RUN = Regex("\\s+")

/**
 * Splits the receiver on runs of whitespace and returns the non-blank pieces in order.
 *
 * Leading or trailing whitespace produces empty pieces from the split, which are
 * removed by the blank filter, so the result contains only actual tokens.
 */
fun Σᐩ.tokenizeByWhitespace(): List<Σᐩ> = split(WHITESPACE_RUN).filter { it.isNotBlank() }

// Zero-width boundaries immediately after or immediately before a whitespace character.
// Compiled once instead of on every call; the pattern itself is unchanged.
private val WHITESPACE_BOUNDARY = Regex("(?<=\\s)|(?=\\s)")

/**
 * Splits the receiver at every boundary adjacent to a whitespace character.
 *
 * Because the split points are zero-width, no characters are consumed: the
 * whitespace is kept in the result as tokens rather than discarded.
 */
fun Σᐩ.tokenizeByWhitespaceAndKeepDelimiters(): List<Σᐩ> =
  split(WHITESPACE_BOUNDARY)

infix fun Char.closes(that: Char) =
if (this == ')' && that == '(') true
else if (this == ']' && that == '[') true
Expand Down Expand Up @@ -68,81 +73,6 @@ fun Σᐩ.carveSeams(toRemove: Regex = Regex("\\s{2,}")): Σᐩ =
}
}

/**
 * Sums the token-level Levenshtein distance over every pair produced by `s1 * s2`.
 */
fun allPairsLevenshtein(s1: Set<Σᐩ>, s2: Set<Σᐩ>) =
  (s1 * s2).sumOf { (left, right) -> levenshtein(left, right) }

/**
 * Token-level Levenshtein distance between two strings.
 *
 * Both strings are tokenized on whitespace and the resulting token lists are
 * compared with the generic list-based `levenshtein`. The tokenizer already
 * returns a `List`, so it is passed through directly without an extra copy.
 */
fun levenshtein(s1: Σᐩ, s2: Σᐩ): Int =
  levenshtein(s1.tokenizeByWhitespace(), s2.tokenizeByWhitespace())

/**
 * Levenshtein (edit) distance between two lists, with unit cost for each
 * insertion, deletion and substitution. Elements are compared with `==`.
 *
 * Only two rows of the dynamic-programming table are alive at any time.
 *
 * @param o1 the source sequence
 * @param o2 the target sequence
 * @return the minimum number of single-element edits turning [o1] into [o2]
 */
fun <T> levenshtein(o1: List<T>, o2: List<T>): Int {
  val width = o2.size + 1
  // Row 0: turning an empty prefix of o1 into the first k elements of o2 costs k insertions.
  var above = IntArray(width) { it }
  for (row in 1..o1.size) {
    val current = IntArray(width)
    current[0] = row // deleting the first `row` elements of o1
    for (col in 1 until width) {
      val substitution = above[col - 1] + if (o1[row - 1] == o2[col - 1]) 0 else 1
      current[col] = minOf(above[col] + 1, current[col - 1] + 1, substitution)
    }
    above = current
  }
  return above[o2.size]
}

/**
 * Aligns two strings token by token: both are tokenized on whitespace and the
 * token lists are handed to the list-based `levenshteinAlign`.
 */
fun levenshteinAlign(a: String, b: String): List<Pair<String?, String?>> {
  val leftTokens = a.tokenizeByWhitespace()
  val rightTokens = b.tokenizeByWhitespace()
  return levenshteinAlign(leftTokens, rightTokens)
}

/**
 * Computes a minimum-cost edit alignment between [a] and [b].
 *
 * Each returned pair holds one element of [a] and one of [b]; a `null` on the
 * left marks an element present only in [b], a `null` on the right marks an
 * element present only in [a]. Pairs where both sides are non-null are either
 * matches or substitutions.
 *
 * When several alignments have the same cost, the backtrace prefers the
 * diagonal step first, then consuming from [a], then consuming from [b].
 */
fun <T> levenshteinAlign(a: List<T>, b: List<T>): List<Pair<T?, T?>> {
  // table[r][c] = edit distance between the first r elements of a and the first c of b
  val table = Array(a.size + 1) { IntArray(b.size + 1) }
  for (c in 0..b.size) table[0][c] = c
  for (r in 1..a.size) {
    table[r][0] = r
    for (c in 1..b.size) {
      val diagonal = table[r - 1][c - 1] + if (a[r - 1] == b[c - 1]) 0 else 1
      val gap = 1 + minOf(table[r - 1][c], table[r][c - 1])
      table[r][c] = minOf(gap, diagonal)
    }
  }

  // Walk back from the bottom-right corner, recording steps in reverse order.
  val trace = mutableListOf<Pair<T?, T?>>()
  var r = a.size
  var c = b.size
  while (r > 0 && c > 0) {
    val here = table[r][c]
    val diagonal = table[r - 1][c - 1] + if (a[r - 1] == b[c - 1]) 0 else 1
    if (here == diagonal) {
      r--
      c--
      trace += a[r] to b[c]
    } else if (here == 1 + table[r - 1][c]) {
      r--
      trace += a[r] to null
    } else if (here == 1 + table[r][c - 1]) {
      c--
      trace += null to b[c]
    }
  }

  // Whatever remains on one side has no counterpart on the other.
  while (r > 0) trace += a[--r] to null
  while (c > 0) trace += null to b[--c]

  return trace.reversed()
}

fun <T> List<Pair<T?, T?>>.paintDiffs(): String =
joinToString(" ") { (a, b) ->
when {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package ai.hypergraph.kaliningraph.parsing
package ai.hypergraph.kaliningraph.automata

import ai.hypergraph.kaliningraph.graphs.*
import ai.hypergraph.kaliningraph.parsing.Σᐩ
import ai.hypergraph.kaliningraph.tokenizeByWhitespace
import ai.hypergraph.kaliningraph.types.*

typealias Arc = Π3A<Σᐩ>
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package ai.hypergraph.kaliningraph.parsing

import ai.hypergraph.kaliningraph.automata.FSA
import ai.hypergraph.kaliningraph.types.*
import kotlin.time.TimeSource

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package ai.hypergraph.kaliningraph.parsing

import ai.hypergraph.kaliningraph.graphs.LabeledGraph
import ai.hypergraph.kaliningraph.sampling.choose
import ai.hypergraph.kaliningraph.tokenizeByWhitespace
import ai.hypergraph.kaliningraph.types.*
import kotlin.jvm.JvmName
import kotlin.time.*
Expand Down Expand Up @@ -63,6 +64,9 @@ val CFG.normalForm: CFG by cache { normalize() }
// Result of dependencyGraph() for this grammar, held through the `cache` delegate.
// NOTE(review): presumably `cache` memoizes per receiver — confirm against its definition.
val CFG.depGraph: LabeledGraph by cache { dependencyGraph() }
// Result of revDependencyGraph() for this grammar, held through the `cache` delegate.
val CFG.revDepGraph: LabeledGraph by cache { revDependencyGraph() }

// Terminals which are blocked from being synthesized by a solver
// The initializer produces an empty mutable set.
val CFG.blocked: MutableSet<Σᐩ> by cache { mutableSetOf() }

// First entry of this grammar's rewrite history; falls back to the grammar itself
// when `rewriteHistory` has no entry for it.
val CFG.originalForm: CFG by cache { rewriteHistory[this]?.get(0) ?: this }
// Second entry (index 1) of this grammar's rewrite history. The `!!` throws a
// NullPointerException when `rewriteHistory` has no entry for this grammar.
val CFG.nonparametricForm: CFG by cache { rewriteHistory[this]!![1] }
//val CFG.originalForm by cache { rewriteHistory[this]!![0] }
Expand Down Expand Up @@ -239,8 +243,7 @@ class BiMap(cfg: CFG) {
.map { it.value.map { v -> v to it.key[0] to it.key[1] } }.flatten()
}
val X2WZ: Map<Σᐩ, List<Triple<Σᐩ, Σᐩ, Σᐩ>>> by lazy {
TRIPL.groupBy { it.second }
.mapValues { it.value.map { it.first to it.second to it.third } }
TRIPL.groupBy { it.second }.mapValues { it.value }
}
val UNITS by lazy {
cfg.filter { it.RHS.size == 1 && it.RHS[0] !in cfg.nonterminals }
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
package ai.hypergraph.kaliningraph.parsing

import ai.hypergraph.kaliningraph.automata.*
import ai.hypergraph.kaliningraph.tokenizeByWhitespace
import ai.hypergraph.kaliningraph.types.*
import kotlin.math.absoluteValue
import kotlin.math.*

// Only accept states that are within radius dist of (strLen, 0)
fun acceptStates(strLen: Int, dist: Int) =
Expand Down Expand Up @@ -127,4 +129,79 @@ fun List<Π5<Int, Int, Σᐩ, Int, Int>>.postProc(digits: Int) =
pd(a, digits) to pd(b, digits) to s to pd(d, digits) to pd(e, digits)
}.map { (a, b, s, d, e) ->
"q_$a/$b" to s to "q_$d/$e"
}.toSet()
}.toSet()

/**
 * Total token-level Levenshtein distance, summed over each pair yielded by `s1 * s2`.
 */
fun allPairsLevenshtein(s1: Set<Σᐩ>, s2: Set<Σᐩ>) =
  (s1 * s2).sumOf { (first, second) -> levenshtein(first, second) }

/**
 * Token-level Levenshtein distance between two strings.
 *
 * Each string is tokenized on whitespace, then the token lists are compared by
 * the generic list-based `levenshtein`. The tokenized lists are passed as-is;
 * copying them first adds allocation without changing the result.
 */
fun levenshtein(s1: Σᐩ, s2: Σᐩ): Int =
  levenshtein(s1.tokenizeByWhitespace(), s2.tokenizeByWhitespace())

/**
 * Unit-cost Levenshtein distance between two lists, comparing elements with `==`.
 *
 * The table is folded row by row over [o1]; each step builds the next row from
 * the previous one, so only two rows exist at a time.
 *
 * @return the minimum number of insertions, deletions and substitutions that
 *   turn [o1] into [o2]
 */
fun <T> levenshtein(o1: List<T>, o2: List<T>): Int =
  o1.foldIndexed(IntArray(o2.size + 1) { it }) { rowIdx, prevRow, item ->
    IntArray(o2.size + 1).also { row ->
      row[0] = rowIdx + 1 // cost of deleting every element of o1 seen so far
      for (col in 1..o2.size) {
        val swap = prevRow[col - 1] + if (item == o2[col - 1]) 0 else 1
        row[col] = minOf(prevRow[col] + 1, row[col - 1] + 1, swap)
      }
    }
  }.last() // the final cell of the final row is the full distance

/**
 * Token-level alignment of two strings: tokenizes both on whitespace and
 * delegates to the list-based `levenshteinAlign`.
 */
fun levenshteinAlign(a: String, b: String): List<Pair<String?, String?>> {
  val tokensA = a.tokenizeByWhitespace()
  val tokensB = b.tokenizeByWhitespace()
  return levenshteinAlign(tokensA, tokensB)
}

/**
 * Builds a minimum-cost edit alignment of [a] against [b].
 *
 * The result lists aligned pairs in order. A pair with `null` on the right is
 * an element of [a] with no counterpart; a pair with `null` on the left is an
 * element of [b] with no counterpart; otherwise the pair is a match or a
 * substitution.
 *
 * Ties between equally cheap alignments are broken during the backtrace in a
 * fixed order: diagonal step, then a step in [a], then a step in [b].
 */
fun <T> levenshteinAlign(a: List<T>, b: List<T>): List<Pair<T?, T?>> {
  // 0 when the i-th element of a equals the j-th element of b (1-based), else 1
  fun mismatch(i: Int, j: Int) = if (a[i - 1] == b[j - 1]) 0 else 1

  val dist = Array(a.size + 1) { IntArray(b.size + 1) }
  for (col in 0..b.size) dist[0][col] = col
  for (row in 1..a.size) {
    dist[row][0] = row
    for (col in 1..b.size) {
      val viaDiagonal = dist[row - 1][col - 1] + mismatch(row, col)
      val viaGap = 1 + minOf(dist[row - 1][col], dist[row][col - 1])
      dist[row][col] = minOf(viaGap, viaDiagonal)
    }
  }

  // Backtrace from the last cell; steps are collected end-to-start.
  val reversedAlignment = mutableListOf<Pair<T?, T?>>()
  var row = a.size
  var col = b.size
  while (row > 0 && col > 0) {
    val cell = dist[row][col]
    when {
      cell == dist[row - 1][col - 1] + mismatch(row, col) ->
        reversedAlignment.add(a[--row] to b[--col])
      cell == 1 + dist[row - 1][col] ->
        reversedAlignment.add(a[--row] to null)
      cell == 1 + dist[row][col - 1] ->
        reversedAlignment.add(null to b[--col])
    }
  }
  // One side is exhausted; the rest of the other side is unpaired.
  while (row > 0) reversedAlignment.add(a[--row] to null)
  while (col > 0) reversedAlignment.add(null to b[--col])

  return reversedAlignment.reversed()
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package ai.hypergraph.kaliningraph.parsing

import ai.hypergraph.kaliningraph.*
import ai.hypergraph.kaliningraph.repair.substituteIndices
import ai.hypergraph.kaliningraph.sampling.choose
import ai.hypergraph.kaliningraph.tensor.UTMatrix
import ai.hypergraph.kaliningraph.types.*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

package ai.hypergraph.kaliningraph.parsing

import ai.hypergraph.kaliningraph.levenshtein
import ai.hypergraph.kaliningraph.*
import ai.hypergraph.kaliningraph.tensor.UTMatrix
import ai.hypergraph.kaliningraph.types.*

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package ai.hypergraph.kaliningraph.parsing
package ai.hypergraph.kaliningraph.repair

import ai.hypergraph.kaliningraph.*
import ai.hypergraph.kaliningraph.parsing.*
import ai.hypergraph.kaliningraph.sampling.choose
import ai.hypergraph.kaliningraph.types.powerset
import kotlin.math.absoluteValue
Expand All @@ -18,9 +19,6 @@ typealias Reconstructor = MutableList<Π2A<Σᐩ>>
// Takes a string and a set of invariant indices and returns mutated strings
typealias Mutator = (Σᐩ, Set<Int>) -> Sequence<Σᐩ>

// Terminals which are blocked from being synthesized by a solver
// The initializer produces an empty mutable set, held through the `cache` delegate.
// NOTE(review): presumably `cache` keeps one set per CFG receiver — confirm against its definition.
val CFG.blocked: MutableSet<Σᐩ> by cache { mutableSetOf() }

fun repair(
prompt: Σᐩ,
cfg: CFG,
Expand Down Expand Up @@ -355,11 +353,6 @@ fun List<Σᐩ>.substituteIndices(idxs: Set<Int>, sub: (Σᐩ, Int) -> Σᐩ): L
// Applies substituteIndices with the given index set and substitution function,
// then joins the resulting tokens with single spaces and trims the ends.
private fun List<Σᐩ>.substitute(idxs: Set<Int>, sub: (Σᐩ, Int) -> Σᐩ): Σᐩ {
  val replaced = substituteIndices(idxs, sub)
  return replaced.joinToString(" ").trim()
}

// Splits on runs of whitespace and drops the blank pieces that leading or
// trailing whitespace would otherwise leave behind.
fun Σᐩ.tokenizeByWhitespace(): List<Σᐩ> =
  split(Regex("\\s+")).filterNot { piece -> piece.isBlank() }

// Splits at zero-width boundaries next to whitespace characters, so the
// whitespace stays in the result as tokens instead of being discarded.
fun Σᐩ.tokenizeByWhitespaceAndKeepDelimiters(): List<Σᐩ> =
  Regex("(?<=\\s)|(?=\\s)").split(this)

// MUCH faster than above (but incorrect)
//fun Σᐩ.tokenizeByWhitespace(): List<Σᐩ> =
// mutableListOf<Σᐩ>().also { list ->
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package ai.hypergraph.kaliningraph.parsing

import Grammars
import ai.hypergraph.kaliningraph.*
import ai.hypergraph.kaliningraph.automata.parseFSA
import ai.hypergraph.kaliningraph.sampling.all
import kotlin.test.*
import kotlin.time.*
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package ai.hypergraph.kaliningraph.parsing

import ai.hypergraph.kaliningraph.repair.multiTokenSubstitutionsAndInsertions
import kotlin.test.Test
import kotlin.time.*

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import ai.hypergraph.kaliningraph.*
import ai.hypergraph.kaliningraph.parsing.*
import ai.hypergraph.kaliningraph.repair.*
import ai.hypergraph.kaliningraph.sampling.*
import java.util.concurrent.*
import java.util.concurrent.atomic.AtomicInteger
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
package ai.hypergraph.kaliningraph.repair

import java.io.File
import kotlin.math.pow

/** Kind of edit carried by a [ContextEdit]: insertion, deletion or substitution. */
enum class EditType {
  INS,
  DEL,
  SUB
}
/**
 * An edit of kind [type] applied to the middle element of [context], with
 * [newMid] as the inserted or substituted element.
 *
 * [toString] prints the edit type followed by the same rendering of the edit
 * twice, separated by " // ".
 */
data class ContextEdit(val type: EditType, val context: Context, val newMid: String) {
  override fun toString(): String {
    val (left, mid, right) = context
    // Rendered once and reused on both sides of the " // " separator.
    val rendered = when (type) {
      EditType.INS -> "$left [$newMid] $right"
      EditType.DEL -> "$left ~$mid~ $right"
      EditType.SUB -> "$left [$mid -> $newMid] $right"
    }
    return "$type, (( $rendered // $rendered ))"
  }
}
/**
 * A [ContextEdit] paired with an index and a frequency count.
 *
 * Equality and hashing deliberately consider only [cea] and [idx]; two
 * instances that differ only in [frequency] are equal.
 */
data class CEAProb(val cea: ContextEdit, val idx: Int, val frequency: Int) {
  override fun equals(other: Any?): Boolean =
    other is CEAProb && idx == other.idx && cea == other.cea

  override fun hashCode(): Int = cea.hashCode() * 31 + idx

  override fun toString(): String = "[[ $cea, $idx, $frequency ]]"
}
/**
 * A three-part context: the element in the middle and its left and right neighbours.
 *
 * Equality and hashing are defined explicitly over all three fields.
 */
data class Context(val left: String, val mid: String, val right: String) {
  override fun equals(other: Any?): Boolean =
    other is Context && left == other.left && mid == other.mid && right == other.right

  // Standard 31-multiplier combination of the three field hashes, in declaration order.
  override fun hashCode(): Int =
    (left.hashCode() * 31 + mid.hashCode()) * 31 + right.hashCode()
}

/**
 * A table of [ContextEdit]s with integer counts, plus views derived from it at construction.
 *
 * - [P_delSub]: entries whose edit type is not `INS`
 * - [P_insert]: entries whose edit type is `INS`
 * - [P_delSubOnCtx] / [P_insertOnCtx]: the edits of the respective view, grouped by their context
 */
data class CEADist(val allProbs: Map<ContextEdit, Int>) {
  val P_delSub = allProbs.filterKeys { edit -> edit.type != EditType.INS }
  val P_insert = allProbs.filterKeys { edit -> edit.type == EditType.INS }
  val P_delSubOnCtx = P_delSub.keys.groupBy(ContextEdit::context)
  val P_insertOnCtx = P_insert.keys.groupBy(ContextEdit::context)
}

/**
 * Reads edit statistics from this file into a [CEADist].
 *
 * The first line is skipped. Every remaining non-blank line is split on ", "
 * into six fields: edit type, left context, middle, right context, new middle,
 * and a numeric count. Blank lines are ignored rather than failing the load.
 * When the same edit appears on several lines, the last one wins.
 *
 * Diversity: lower is more diverse, higher is less diverse, 1.0 is natural frequencies.
 * Each count is raised to the power [diversity], truncated to an Int, and
 * clamped to at least 1.
 *
 * @throws IllegalArgumentException if a line's first field is not an [EditType] name
 * @throws IndexOutOfBoundsException if a non-blank line has fewer than six fields
 * @throws NumberFormatException if the count field is not a number
 */
fun File.readTrigramStats(diversity: Double = 1.0): CEADist =
  readLines().drop(1)
    .filter { it.isNotBlank() }
    .map { it.split(", ") }
    .associate {
      ContextEdit(
        type = EditType.valueOf(it[0].trim()),
        context = Context(it[1], it[2], it[3]),
        newMid = it[4]
      ) to it[5].trim().toDouble().pow(diversity).toInt().coerceAtLeast(1)
    }.let { CEADist(it) }
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package ai.hypergraph.kaliningraph.sat

import ai.hypergraph.kaliningraph.parsing.*
import ai.hypergraph.kaliningraph.repair.*
import ai.hypergraph.kaliningraph.tensor.*
import ai.hypergraph.kaliningraph.types.*
import org.logicng.formulas.*
Expand Down
Loading

0 comments on commit 689e18b

Please sign in to comment.