attribute_TIL.py
# -*- coding: utf-8 -*-
# Match each TIL reddit submission to the Wikipedia paragraph that best supports
# it, then store the matched (and unmatched) paragraphs in a local SQLite database.
# NOTE: Python 2 code (urllib2, itertools.imap, print statement).
import glob, json, os, codecs
import urllib2, itertools
from unidecode import unidecode
from parsing import tokenize, frequency_table

os.system("mkdir -p db")
os.system("mkdir -p db/model")

PARALLEL_FLAG = True
#PARALLEL_FLAG = False
def reddit_json():
    F_REDDIT = sorted(glob.glob("data/reddit/*.json"))[::-1]
    for f_json in F_REDDIT:
        with codecs.open(f_json, 'r', 'utf-8') as FIN:
            js = json.load(FIN)
            js["filename"] = f_json

        f_wiki = "data/wikipedia/{}.txt".format(js["name"])
        with open(f_wiki) as FIN:
            js["wiki"] = FIN.read()

        yield js
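
# Assumed input layout, inferred from reddit_json above (the repo's actual data
# dump may differ): each data/reddit/*.json file holds a single submission with
# at least the "name", "title", "url" and "score" fields used below, and the
# text of the linked article lives in data/wikipedia/<name>.txt, keyed by the
# submission's "name" field.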
def split_by_sections(text):
    # Works on mediawiki documents
    text = text.decode('utf-8')

    data = {}
    name = u""
    buffer = []

    for line in text.split(u'\n'):
        if line[:2] == "==":
            data[name] = u'\n'.join(buffer)
            name = line.strip().strip("=").strip().lower()
            buffer = []
        else:
            buffer.append(line)

    # Add the last section
    data[name] = u'\n'.join(buffer)

    bad_sections = ["see also", "references", "external links"]
    for key in bad_sections:
        if key in data:
            data.pop(key)

    return data
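
# Illustrative sketch (not called anywhere in this script): on a toy
# mediawiki-style string, split_by_sections should file the lead text under the
# empty key u"" and each "== Heading ==" block under its lower-cased title,
# with "see also"/"references"/"external links" dropped. The sample text below
# is made up.
def _demo_split_by_sections():
    sample = ("Lead paragraph about the topic.\n"
              "== History ==\n"
              "A paragraph about the history.\n"
              "== See also ==\n"
              "This section is discarded.\n")
    sections = split_by_sections(sample)
    # Expected keys: u"" (the lead) and u"history"
    return sections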
def split_by_paragraph(full_text):
    sections = split_by_sections(full_text)
    for section, text in sections.items():
        for paragraph in text.split('\n'):
            paragraph = paragraph.strip()
            if paragraph:
                yield paragraph
def find_TIL_match(js):
    text = js["wiki"]
    url = js["url"]

    wiki_title = url.split('/')[-1].replace('_', ' ').lower()
    wiki_title = wiki_title.split('#')[0]
    wiki_title = unidecode(urllib2.unquote(wiki_title))
    wiki_title_tokens = set(wiki_title.split())

    TIL_text = js["title"]
    TIL_tokens = set(tokenize(TIL_text))

    # Remove special tokens from the TIL title
    TIL_tokens = [x for x in TIL_tokens if len(x) > 2 and "TOKEN_" not in x]

    paragraphs = list(split_by_paragraph(text))
    tokens = map(tokenize, paragraphs)
    freq = frequency_table(tokens)

    # Find words from the TIL title that are used in the article text
    matching_columns = list(set(freq.columns).intersection(TIL_tokens))

    # Idea: score each paragraph by how strongly it matches those words
    df = freq[matching_columns]

    # Normalize each token column over the paragraphs, so words that appear
    # in only a few paragraphs count for more
    df /= df.sum(axis=0)
    df.fillna(0, inplace=True)

    # Find the top-scoring paragraph
    score = df.sum(axis=1)
    top_idx = score.argmax()
    match_text = paragraphs[top_idx]

    # Normalize the full frequency table the same way to get the entropy weights
    freq /= freq.sum(axis=0)
    freq.fillna(0, inplace=True)

    tokens = list(freq.columns[freq.ix[top_idx] > 0])
    weights = freq[tokens].ix[top_idx]

    # Convert them into SQL-able formats
    w_str = '[{}]'.format(','.join(map("{:0.2f}".format, weights)))

    d_out = {
        "reddit_idx": js["name"],
        "TIL": TIL_text,
        "unprocessed_wikitext": match_text,
        "tokens": ' '.join(tokens),
        "url": url,
        "score": js["score"],
        "weights": w_str,
    }

    key_order = ["reddit_idx", "TIL",
                 "unprocessed_wikitext", "tokens",
                 "url", "score", "weights"]
    data_match = [d_out[key] for key in key_order]

    # Save the remaining paragraphs as false-positive examples
    data_unmatch = []
    for n in range(len(paragraphs)):
        if n != top_idx:
            tokens = list(freq.columns[freq.ix[n] > 0])
            weights = freq[tokens].ix[n]
            assert len(tokens) == len(weights)

            if len(tokens) > 3:
                # Convert them into SQL-able formats
                w_str = '[{}]'.format(','.join(map("{:0.2f}".format, weights)))
                t_str = ' '.join(tokens)
                data_unmatch.append([t_str, w_str])

    return data_match, data_unmatch
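
# Illustrative sketch of the scoring step above, with made-up toy numbers and
# token names: the term-frequency table is normalized column-wise so tokens
# that appear in few paragraphs carry more weight, then each paragraph (row)
# is scored by summing the columns it shares with the TIL title.
def _demo_paragraph_scoring():
    import pandas as pd
    # rows = paragraphs, columns = tokens
    freq = pd.DataFrame({"whale": [2, 0, 0],
                         "sperm": [1, 1, 0],
                         "the":   [5, 7, 6]})
    df = freq[["whale", "sperm"]]     # tokens shared with the TIL title
    df = df / df.sum(axis=0)          # column-normalize: rarer tokens weigh more
    df = df.fillna(0)
    score = df.sum(axis=1)            # per-paragraph score
    return score.argmax()             # index of the best-matching paragraph, 0 here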
def data_iterator():
    if PARALLEL_FLAG:
        import multiprocessing
        P = multiprocessing.Pool()
        ITR = P.imap(find_TIL_match, reddit_json())
    else:
        ITR = itertools.imap(find_TIL_match, reddit_json())

    for k, (data_match, data_unmatch) in enumerate(ITR):
        print k, data_match[1]
        yield data_match, data_unmatch
import sqlite3
conn = sqlite3.connect("db/training.db")

cmd_template = '''
CREATE TABLE IF NOT EXISTS training (
    training_idx INTEGER PRIMARY KEY AUTOINCREMENT,
    reddit_idx STRING,
    TIL STRING,
    unprocessed_wikitext STRING,
    tokens STRING,
    weights STRING, -- awful way to do it, but easy enough!
    url STRING,
    score INTEGER
);

CREATE TABLE IF NOT EXISTS false_positives (
    idx INTEGER PRIMARY KEY AUTOINCREMENT,
    tokens STRING,
    weights STRING -- awful way to do it, but easy enough!
);
'''
conn.executescript(cmd_template)

cmd_insert_match = u'''
INSERT INTO training (reddit_idx, TIL, unprocessed_wikitext,
                      tokens, url, score, weights)
VALUES (?,?,?,?,?,?,?)
'''

cmd_insert_unmatch = u'''
INSERT INTO false_positives (tokens, weights) VALUES (?,?)
'''

ITR = data_iterator()
for data_match, data_unmatch in ITR:
    conn.execute(cmd_insert_match, data_match)
    conn.executemany(cmd_insert_unmatch, data_unmatch)
    conn.commit()
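
# Illustrative sketch (not part of the pipeline above): once the script has run,
# the stored matches can be inspected with a query against the training table
# created earlier; sorting by reddit score is just one convenient view.
def _demo_inspect_training_db(db_path="db/training.db"):
    inspect_conn = sqlite3.connect(db_path)
    rows = inspect_conn.execute(
        "SELECT reddit_idx, TIL, score FROM training "
        "ORDER BY score DESC LIMIT 5").fetchall()
    inspect_conn.close()
    return rows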