#!/usr/bin/env python2.7
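#
# model_tester.py evaluates a pomegranate HiddenMarkovModel tagger on a corpus
# of sentences given as one JSON object per line. The summary below is inferred
# from the code itself rather than from any upstream documentation:
#   * each sentence carries "id", "text" and "words"; each word has candidate
#     analyses in "v" and what appears to be a resolved (gold) analysis in "a";
#   * for every run of consecutive ambiguous words, all variant combinations
#     are scored with the HMM and the most probable one is kept;
#   * per-sentence error counts (overall and per tag) are written as JSON lines.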
import click
import logging
import pandas as pd
import pomegranate as pg
import ujson as json
from tqdm import tqdm
from itertools import product, repeat, izip
from multiprocessing import Pool
logging.basicConfig(level=logging.INFO, format="%(levelname)s %(asctime)s : %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
logger = logging.getLogger()


def get_answer(sentence, model):
    # Skip sentences that still contain ambiguous words with no resolved analysis.
    if any(len(w["v"]) > 1 and len(w["a"]) == 0 for w in sentence["words"]):
        logger.debug("Sentence %d has unresolved words, skip", sentence["id"])
        return
    result = []
    word_ids = []
    parts = []          # maximal runs of consecutive ambiguous words
    current_part = []
    unambig_count = 0   # counts ambiguous words; never read afterwards
    for w in sentence["words"]:
        # Gold analysis: the single variant if unambiguous, otherwise the resolved answer.
        result.append(w["v"][0] if len(w["v"]) == 1 else w["a"][0])
        word_ids.append(w["id"])
        if len(w["v"]) > 1:
            unambig_count += 1
            current_part.append(w)
        else:
            if len(current_part) > 0:
                parts.append(current_part)
                current_part = []
    if len(current_part) > 0:
        parts.append(current_part)
    result = pd.DataFrame(result, index=word_ids)
    answers = []
    for part in parts:
        word_ids = [w["id"] for w in part]
        temp = result.copy()
        max_prob, max_answer = None, None
        # Try every combination of variants for this run of ambiguous words
        # and keep the combination the model scores highest.
        for answer in product(*[w["v"] for w in part]):
            answer = pd.DataFrame(list(answer), index=word_ids)
            temp.loc[answer.index, :] = answer
            prob = model.log_probability(temp.values)
            if max_prob is None or prob > max_prob:
                max_prob = prob
                max_answer = answer.copy()
        answers.append(max_answer)
    # Apply the model's choices and mark every cell where they differ from the gold analyses.
    temp = result.copy()
    for answer in answers:
        temp.loc[answer.index, :] = answer
    return result != temp


def process(args):
    (line, model) = args
    sentence = json.loads(line)
    errors = get_answer(sentence, model)
    if errors is None:
        return
    results = dict()
    results["words"] = errors.shape[0]
    results["errors_by_tag"] = list(errors.sum(axis=0).values)
    results["errors"] = errors.any(axis=1).sum()
    # Despite the "unambig" naming, these count the ambiguous words and, per tag,
    # how many of them actually differ between their variants.
    sent_unambig = []
    sent_unambig_all = 0
    for word in sentence["words"]:
        if len(word["v"]) > 1:
            sent_unambig_all += 1
            variants = pd.DataFrame(word["v"])
            sent_unambig.append(variants.nunique() > 1)
    if sent_unambig_all > 0:
        results["unambig"] = sent_unambig_all
        results["unambig_by_tag"] = list(pd.concat(sent_unambig, axis=1).sum(axis=1).values)
    results["id"] = sentence["id"]
    results["text"] = sentence["text"]
    return results


@click.command()
@click.argument("sentences", type=click.File("rt", encoding="utf8"))
@click.argument("model", type=click.File("rt", encoding="utf8"))
@click.argument("output", type=click.File("wt"))
def main(sentences, model, output):
    model = pg.HiddenMarkovModel.from_json(model.read())
    pool = Pool(processes=5)
    # total=114883 is hard-coded, presumably the number of sentences in the input corpus.
    for result in tqdm(pool.imap_unordered(process, izip(sentences, repeat(model))), total=114883):
        if result is not None:
            output.write(json.dumps(result))
            output.write("\n")
    output.close()


if __name__ == "__main__":
    main()
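

# Example invocation (file names are placeholders, not taken from the original repository):
#   python2.7 model_tester.py sentences.jsonl hmm_model.json errors.jsonl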