forked from xiaojunxu/SQLNet
-
Notifications
You must be signed in to change notification settings - Fork 1
/
extract_vocab.py
65 lines (56 loc) · 1.74 KB
/
extract_vocab.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import json
import torch
from sqlnet.utils import *
import numpy as np
import datetime
# LOCAL_TEST switches between a small local configuration and the full one.
# N_word is the embedding dimension and B_word the corpus size (in billions of
# tokens) used to build the GloVe filename 'glove/glove.<B_word>B.<N_word>d.txt';
# USE_SMALL is forwarded as `use_small` to the dataset/embedding loaders.
LOCAL_TEST = False
N_word, B_word, USE_SMALL = (100, 6, True) if LOCAL_TEST else (300, 42, False)
# Load the train / dev / test question and table data. load_dataset comes from
# sqlnet.utils (wildcard import); the leading 0 is passed through unchanged and
# its meaning is defined there (NOTE(review): confirm against sqlnet.utils).
(sql_data, table_data, val_sql_data, val_table_data,
 test_sql_data, test_table_data, TRAIN_DB, DEV_DB, TEST_DB) = \
    load_dataset(0, use_small=USE_SMALL)

# Pre-trained GloVe vectors, e.g. 'glove/glove.42B.300d.txt' for the full config.
word_emb = load_word_emb('glove/glove.%dB.%dd.txt' % (B_word, N_word),
                         use_small=USE_SMALL)
# Single-argument print(...) produces the same output on Python 2 and 3.
print("Length of word vocabulary: %d" % len(word_emb))

# The vocabulary is seeded with three special tokens, each given an all-zero
# float32 vector of length N_word; word_num is the next free index.
word_to_idx = {'<UNK>': 0, '<BEG>': 1, '<END>': 2}
word_num = 3
embs = [np.zeros(N_word, dtype=np.float32) for _ in range(word_num)]
def check_and_add(tok):
    """Add ``tok`` to the vocabulary if it is new and has a pre-trained vector.

    Tokens already present in ``word_to_idx``, or missing from ``word_emb``,
    are ignored. Otherwise the token is given the next free index
    (``word_num``) and its vector from ``word_emb`` is appended to ``embs``,
    in the same order the indices are handed out.
    """
    global word_num
    already_known = tok in word_to_idx
    if already_known or tok not in word_emb:
        return
    word_to_idx[tok] = word_num
    embs.append(word_emb[tok])
    word_num += 1
# Grow the vocabulary from every split in a fixed order: train, dev, test.
# Within each split, question tokens come first, then the tokens of every
# column header. The order matters because check_and_add assigns indices in
# first-seen order, so it reproduces the original three copy-pasted loops.
for split_sqls, split_tables in ((sql_data, table_data),
                                 (val_sql_data, val_table_data),
                                 (test_sql_data, test_table_data)):
    for sql in split_sqls:
        for tok in sql['question_tok']:
            check_and_add(tok)
    for tab in split_tables.values():
        # 'header_tok' is iterated as a list of columns, each a list of tokens.
        for col in tab['header_tok']:
            for tok in col:
                check_and_add(tok)
print("Length of used word vocab: %s" % len(word_to_idx))
# Stack the per-token vectors into one 2-D array. Row i pairs with index i
# because indices and vectors are appended together, starting from the three
# special tokens.
emb_array = np.stack(embs, axis=0)
with open('glove/word2idx.json', 'w') as outf:
    json.dump(word_to_idx, outf)
# .npy is a binary format. Passing the path lets numpy open the file in
# binary mode and close it. The previous text-mode handle was never closed,
# failed on Python 3 and could corrupt the data on Windows.
np.save('glove/usedwordemb.npy', emb_array)