doc2vec.py
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
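"""Train a gensim Doc2Vec model on Aozora Bunko texts.

Walks INPUT_DOC_DIR, strips each file down to the body of the work,
tokenizes it with MeCab, and trains until at least PASSING_PRECISION of
the first 100 documents rank themselves as their own most similar document.
"""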
import os
import sys
import MeCab
import collections
from gensim import models
from gensim.models.doc2vec import TaggedDocument  # LabeledSentence was deprecated and later removed

INPUT_DOC_DIR = './aozora/'     # directory of Aozora Bunko text files (Shift_JIS)
OUTPUT_MODEL = 'doc2vec.model'
PASSING_PRECISION = 93          # stop once >= 93 of 100 sampled docs rank themselves first

# Yield every file path under a directory
def get_all_files(directory):
    for root, dirs, files in os.walk(directory):
        for file in files:
            yield os.path.join(root, file)

# Return the entire text of a file
def read_document(path):
    with open(path, 'r', encoding='sjis', errors='ignore') as f:
        return f.read()

# Extract only the body of the work from an Aozora Bunko file:
# the body starts after the second '-----' horizontal rule and
# ends at the first run of three consecutive blank lines.
def trim_doc(doc):
    lines = doc.splitlines()
    valid_lines = []
    is_valid = False
    horizontal_rule_cnt = 0
    break_cnt = 0
    for line in lines:
        if horizontal_rule_cnt < 2 and '-----' in line:
            horizontal_rule_cnt += 1
            is_valid = horizontal_rule_cnt == 2
            continue
        if not is_valid:
            continue
        if line == '':
            break_cnt += 1
            is_valid = break_cnt != 3
            continue
        break_cnt = 0
        valid_lines.append(line)
    return '\n'.join(valid_lines)  # join with newlines so words on adjacent lines don't merge
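
# For reference, the typical Aozora Bunko file layout that trim_doc assumes
# (a sketch; individual files can deviate):
#
#   title, author
#   -------------------------------------------------
#   notes on ruby and annotation markup
#   -------------------------------------------------
#   ...body of the work...
#   <three consecutive blank lines>
#   bibliographic footer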

# Tokenize the text with MeCab and keep verbs, adjectives, and nouns
# (excluding numerals); returns one TaggedDocument per file.
def split_into_words(doc, name=''):
    mecab = MeCab.Tagger('-Ochasen')
    valid_doc = trim_doc(doc)
    lines = mecab.parse(valid_doc).splitlines()  # parse the trimmed body, not the raw file
    words = []
    for line in lines:
        chunks = line.split('\t')
        if len(chunks) > 3 and (chunks[3].startswith('動詞')
                                or chunks[3].startswith('形容詞')
                                or (chunks[3].startswith('名詞') and not chunks[3].startswith('名詞-数'))):
            words.append(chunks[0])
    return TaggedDocument(words=words, tags=[name])
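
# Sample '-Ochasen' output line that the filter above consumes (illustrative;
# exact columns depend on the installed MeCab dictionary):
#
#   走る	ハシル	走る	動詞-自立	五段・ラ行	基本形
#
# chunks[0] is the surface form and chunks[3] the part-of-speech tag.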

# Read each file and yield its word list as a TaggedDocument
def corpus_to_sentences(corpus):
    docs = [read_document(x) for x in corpus]
    for idx, (doc, name) in enumerate(zip(docs, corpus)):
        sys.stdout.write('\rPreprocessing {} / {}'.format(idx + 1, len(corpus)))
        yield split_into_words(doc, name)

# Train the model, evaluating self-similarity after each pass
def train(sentences):
    model = models.Doc2Vec(vector_size=400, alpha=0.0015, sample=1e-4, min_count=1, workers=4)
    model.build_vocab(sentences)
    for epoch in range(30):
        print(epoch)
        # newer gensim requires explicit total_examples and epochs here
        model.train(sentences, total_examples=model.corpus_count, epochs=1)
        # Check how many of the first 100 documents are their own nearest neighbor
        ranks = []
        for doc_id in range(100):  # assumes the corpus has at least 100 documents
            inferred_vector = model.infer_vector(sentences[doc_id].words)
            # model.docvecs was renamed model.dv in gensim 4.x
            sims = model.docvecs.most_similar([inferred_vector], topn=len(model.docvecs))
            rank = [docid for docid, sim in sims].index(sentences[doc_id].tags[0])
            ranks.append(rank)
        print(collections.Counter(ranks))
        if collections.Counter(ranks)[0] >= PASSING_PRECISION:
            break
    return model

if __name__ == '__main__':
    corpus = list(get_all_files(INPUT_DOC_DIR))
    sentences = list(corpus_to_sentences(corpus))
    print()
    model = train(sentences)
    model.save(OUTPUT_MODEL)
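
# Example use of the saved model (a sketch; 'some_work.txt' is a hypothetical
# file name, and document tags are the input file paths):
#
#   loaded = models.Doc2Vec.load(OUTPUT_MODEL)
#   print(loaded.docvecs.most_similar('./aozora/some_work.txt', topn=5))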