-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathhmm.py
98 lines (86 loc) · 2.45 KB
/
hmm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
'''
STUDENT CODE:
'''
from bisect import bisect_right
from random import random, randrange, uniform
class MarkovModel:
    """A first-order Markov chain over a list of emissions (words).

    Attributes:
        p: transition table; ``p[i]`` is the row for emission ``e[i]``.
            ``run`` treats each row as a cumulative distribution over ``e``
            (non-decreasing, last entry close to 1).
        e: list of emissions; position ``i`` is the state index.
        indices: mapping from an emission to its index in ``e``.
    """

    def __init__(self, probs=None, emissions=None):
        """Create a model from a transition table and its emissions.

        Args:
            probs: list of rows, one per emission. Defaults to a new empty list.
            emissions: list of emissions, same length as ``probs``.
                Defaults to a new empty list.

        Raises:
            AssertionError: if either argument is not a list, or their
                lengths differ.
        """
        # None sentinels: a ``[]`` default is a single list shared by every
        # instance, so mutating one model's table would leak into the next.
        if probs is None:
            probs = []
        if emissions is None:
            emissions = []
        assert type(probs) is list and type(emissions) is list
        assert len(probs) == len(emissions)
        self.p = probs
        self.e = emissions
        self.indices = {}

    def run(self, endfunc):
        """Run the Markov model until ``endfunc`` returns anything other than False.

        ``endfunc`` must be a function that takes a string (the last emission)
        and an int (the number of transitions taken so far, 0 for the starting
        emission). It must return False to continue, or anything else (e.g.
        True) to stop.

        The walk starts from a uniformly random emission. It then repeatedly
        samples the next emission from the current state's row of ``self.p``,
        treated as a cumulative distribution.

        Returns:
            list of the emissions produced, starting with the random start.

        Raises:
            IndexError: if the model has no emissions.
        """
        if not self.e:
            raise IndexError("cannot run an empty MarkovModel; train it first")
        steps = 0
        # randrange never returns len(self.e). int(uniform(0, n)) can, because
        # uniform() may return its upper bound after floating-point rounding.
        state = randrange(len(self.e))
        word = self.e[state]
        emitted = [word]
        # Deliberately ``== False``: the contract is "False means continue".
        while endfunc(word, steps) == False:  # noqa: E712
            state = self._next_index(state)
            word = self.e[state]
            emitted.append(word)
            steps += 1
        return emitted

    def _next_index(self, state):
        """Sample the index of the emission that follows ``state``."""
        row = self.p[state]
        # The count of cumulative entries <= r is the sampled index. Clamp it
        # because rounding can leave the last entry slightly below 1.
        return min(bisect_right(row, random()), len(row) - 1)
'''
PROVIDED CODE:
'''
def train(mm, data):
    """Fit ``mm`` in place as a word-bigram Markov chain over ``data``.

    Each sentence string is split on single spaces and terminated with a
    newline token. ``mm`` is reset, then filled in:

    * ``mm.e`` - each distinct word, in first-seen order;
    * ``mm.indices`` - word -> its position in ``mm.e``;
    * ``mm.p`` - one row per word, holding the cumulative distribution
      (inclusive prefix sums) of the word that follows it, with add-one
      smoothing. Each row is non-decreasing and ends at approximately 1.

    Transitions are counted across sentence boundaries too, so the newline
    token's row records which words start the next sentence.

    Args:
        mm: the MarkovModel to (re)train.
        data: iterable of sentence strings.

    Raises:
        AssertionError: if ``mm`` is not exactly a MarkovModel.
    """
    assert type(mm) is MarkovModel
    # Reset the model, just good practice.
    mm.e = []
    mm.p = []
    mm.indices = {}
    sentences = [d.split(' ') + ['\n'] for d in data]

    # Vocabulary: give each new word the next free index.
    for sent in sentences:
        for word in sent:
            if word not in mm.indices:
                mm.indices[word] = len(mm.e)
                mm.e.append(word)

    # Count bigram transitions over the whole token stream.
    n = len(mm.e)
    mm.p = [[0] * n for _ in range(n)]
    prev = None
    for sent in sentences:
        for word in sent:
            if prev is not None:
                mm.p[mm.indices[prev]][mm.indices[word]] += 1
            prev = word

    # Turn each row of counts into a smoothed cumulative distribution.
    # The previous code added mm.p[j][j-1] (the wrong row) and skipped
    # column 1; the running sum must stay within the same row.
    for row in mm.p:
        denom = sum(row) + n  # plus-one smoothing adds 1 to each of n cells
        running = 0.0
        for j, count in enumerate(row):
            running += (count + 1) / denom
            row[j] = running
def endOnString(st, i, endstr="\n"):
    """Stop condition: report whether ``endstr`` occurs in the last emission ``st``.

    ``i`` (the iteration count) is accepted only to match the endfunc
    signature; it is ignored.
    """
    return endstr in st
def endAfterN(st, i, n=100):
    """Stop condition: stop once ``i`` iterations have been run.

    Args:
        st: the last emission (ignored; present to match the endfunc signature).
        i: number of iterations run so far.
        n: iteration budget.

    Returns:
        True when ``i >= n``, otherwise False.
    """
    # >= rather than ==: a negative budget, or an i that starts past n,
    # would never match exactly and the run would never stop.
    return i >= n
# Parse Project Gutenberg csv
def dataFromNovel(filename, encoding=None):
    """Read a quoted-line CSV export and return the unquoted line contents.

    Each useful line is expected to hold a single double-quoted value, e.g.
    ``"Some sentence."``. Lines whose content (without the line break) is
    2 characters or shorter are skipped as extraneous. For the rest, the
    first and last characters (the surrounding quotes) are removed.

    Args:
        filename: path of the file to read.
        encoding: text encoding passed to ``open``; ``None`` (the default)
            uses the platform default.

    Returns:
        list of the unquoted line contents, in file order.
    """
    data = []
    # The context manager closes the file even if reading fails.
    with open(filename, encoding=encoding) as f:
        for line in f:
            # Strip the line break first, so a final line without a
            # trailing newline is not cut one character short.
            text = line.rstrip('\n')
            # Remove extraneous lines (blank, or too short to hold a quoted value).
            if len(text) > 2:
                # Drop the opening and closing quote.
                data.append(text[1:-1])
    return data
def search(l, val, j=0):
    """Binary-search ``l`` for ``val`` and return a position offset by ``j``.

    For a list sorted in ascending order, the result is ``j + k``, where ``k``
    is the largest index with ``l[k] <= val``. If no element is <= ``val``
    (or ``l`` is empty), the result is ``j``.

    Raises:
        AssertionError: if ``l`` is not a list.
    """
    assert type(l) is list
    # Narrow the half-open window [lo, hi) by index instead of slicing.
    # Slicing copied the list at every level, making each search O(n).
    lo, hi = 0, len(l)
    while hi - lo > 1:
        mid = lo + (hi - lo) // 2
        if val < l[mid]:
            hi = mid
        else:
            lo = mid
    return j + lo