-
Notifications
You must be signed in to change notification settings - Fork 153
/
gib_detect_train.py
75 lines (57 loc) · 2.66 KB
/
gib_detect_train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
#!/usr/bin/python
import math
import pickle
# The model's alphabet: lowercase ASCII letters plus the space character.
accepted_chars = 'abcdefghijklmnopqrstuvwxyz '

# Maps each accepted character to its position in accepted_chars.
pos = {char: idx for idx, char in enumerate(accepted_chars)}
def normalize(line):
    """Lowercase ``line`` and keep only the characters in accepted_chars.

    Dropping punctuation, digits and other infrequent symbols keeps the
    model relatively small.

    Returns a list of the lowercased characters, in their original order.
    """
    kept = []
    for ch in line:
        lowered = ch.lower()
        if lowered in accepted_chars:
            kept.append(lowered)
    return kept
def ngram(n, l):
    """Yield every run of ``n`` consecutive characters of ``l``.

    ``l`` is passed through normalize() first, so the windows slide over
    the filtered, lowercased characters only. Each n-gram is yielded as a
    string; nothing is yielded when fewer than ``n`` characters remain
    after normalizing.
    """
    chars = normalize(l)
    last_start = len(chars) - n
    start = 0
    while start <= last_start:
        yield ''.join(chars[start:start + n])
        start += 1
def train():
    """Train the character-bigram model and write it as a pickle file.

    Steps:
      1. Count character-to-character transitions in ``big.txt``.
      2. Convert each row of counts into log probabilities.
      3. Score every line of ``good.txt`` and ``bad.txt`` with
         avg_transition_prob() and pick a threshold halfway between the
         lowest good score and the highest bad score.
      4. Pickle ``{'mat': <log-prob matrix>, 'thresh': <threshold>}`` to
         ``gib_model.pki``.

    All file names are relative to the current working directory.

    Raises:
        AssertionError: if some bad phrase scores at least as high as some
            good phrase, i.e. the model cannot separate the two sets.
        IOError / OSError: if any of the input files cannot be opened or
            the model file cannot be written.
    """
    k = len(accepted_chars)
    # Assume we have seen 10 of each character pair. This acts as a kind of
    # prior or smoothing factor. This way, if we see a character transition
    # live that we've never observed in the past, we won't assume the entire
    # string has 0 probability.
    # range (not xrange) keeps this runnable on both Python 2 and Python 3.
    counts = [[10 for _ in range(k)] for _ in range(k)]

    # Count transitions from big text file, taken
    # from http://norvig.com/spell-correct.html
    with open('big.txt') as corpus:
        for line in corpus:
            for a, b in ngram(2, line):
                counts[pos[a]][pos[b]] += 1

    # Normalize the counts so that they become log probabilities.
    # We use log probabilities rather than straight probabilities to avoid
    # numeric underflow issues with long texts.
    # This contains a justification:
    # http://squarecog.wordpress.com/2009/01/10/dealing-with-underflow-in-joint-probability-calculations/
    for row in counts:
        # float() forces true division on Python 2 as well.
        s = float(sum(row))
        for j, count in enumerate(row):
            row[j] = math.log(count / s)

    # Find the probability of generating a few arbitrarily chosen good and
    # bad phrases.
    with open('good.txt') as good_file:
        good_probs = [avg_transition_prob(l, counts) for l in good_file]
    with open('bad.txt') as bad_file:
        bad_probs = [avg_transition_prob(l, counts) for l in bad_file]

    # Assert that we actually are capable of detecting the junk.
    assert min(good_probs) > max(bad_probs), \
        'model cannot separate good.txt phrases from bad.txt phrases'

    # And pick a threshold halfway between the worst good and best bad inputs.
    thresh = (min(good_probs) + max(bad_probs)) / 2

    # Close the model file explicitly so the pickle is fully flushed to disk.
    with open('gib_model.pki', 'wb') as model_file:
        pickle.dump({'mat': counts, 'thresh': thresh}, model_file)
def avg_transition_prob(l, log_prob_mat):
    """Return the average per-transition probability of ``l``.

    Every pair of adjacent (normalized) characters in ``l`` is looked up in
    ``log_prob_mat``, a matrix of log probabilities indexed by the ``pos``
    of the first and second character. The log probabilities are averaged
    and the average is exponentiated back into a probability.

    A string with no transitions yields ``1.0``.
    """
    total_log_prob = 0.0
    pair_count = 0
    for bigram in ngram(2, l):
        first, second = bigram
        row = log_prob_mat[pos[first]]
        total_log_prob += row[pos[second]]
        pair_count += 1

    # Guard against dividing by zero when there were no transitions at all.
    divisor = pair_count if pair_count else 1
    mean_log_prob = total_log_prob / divisor

    # The exponentiation translates from log probs to probs.
    return math.exp(mean_log_prob)
# Build and save the model only when executed as a script, not on import.
if __name__ == "__main__":
    train()