-
Notifications
You must be signed in to change notification settings - Fork 2
/
train.py
154 lines (105 loc) · 3.87 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
import os
import sys
import numpy as np
from numpy import array
from numpy import asarray
from numpy import zeros
from keras.preprocessing.text import Tokenizer
from keras.preprocessing.sequence import pad_sequences
from keras.models import Sequential
from keras.layers import Dense, LSTM, Activation
from keras.layers.convolutional import Conv1D
from keras.layers.convolutional import MaxPooling1D
from keras.layers import Flatten
from keras.layers import Embedding
from keras.optimizers import Adam
from keras.callbacks import LambdaCallback
from keras.callbacks import ModelCheckpoint
from sklearn.model_selection import train_test_split
from keras.utils import to_categorical
# Corpus used to fit the shared tokenizer: one entry per line of data/all.txt.
with open('data/all.txt', 'r') as quotefile:
    quotes = quotefile.readlines()

# Output directory for the per-topic checkpoint files. exist_ok replaces the
# separate exists()/mkdir() check, which could race with another process.
os.makedirs('trained_weights', exist_ok=True)

# Fit one vocabulary over all quotes. filters='' disables the Tokenizer's
# default character filtering, so punctuation stays attached to tokens.
t = Tokenizer(filters='')
t.fit_on_texts(quotes)
# +1 because the Keras Tokenizer never assigns index 0 (it is reserved).
vocab_size = len(t.word_index) + 1
print(vocab_size)

# Pre-computed embedding weights.
# NOTE(review): presumably built from this same tokenizer; the Embedding layer
# that consumes it expects vocab_size rows -- confirm against the producer.
embedding_matrix = np.load('data/embedding_matrix.npy')
print('embedding matrix shape:', embedding_matrix.shape)

# index_word.npy stores a single pickled Python object (unwrapped with .item())
# that is used as an index -> word mapping. NumPy >= 1.16.3 refuses to unpickle
# unless allow_pickle=True is passed. Unpickling can execute arbitrary code, so
# only load this file from a trusted source.
index_word = np.load('data/index_word.npy', allow_pickle=True).item()

topics = ['death', 'family', 'freedom', 'funny', 'life', 'love', 'happiness', 'science', 'success', 'politics']
# ## Train one model per topic
# Length (in tokens) of each input window and the stride between consecutive
# windows. Neither depends on the topic, so they are set once.
maxlen = 100
step = 1

for topic in topics:
    with open('data/%s.txt' % topic, 'r') as topicfile:
        topic_quotes = topicfile.readlines()

    # texts_to_sequences returns one token-id list per line. Concatenate all of
    # them into one stream. The previous code used only encoded_docs[0], which
    # silently ignored every line after the first (the result is identical for
    # a single-line topic file).
    encoded_docs = t.texts_to_sequences(topic_quotes)
    topic_doc = [token for doc in encoded_docs for token in doc]

    # Training needs at least one (window, next-token) pair. The sampling
    # callback below draws its seed start from randint(0, len - maxlen - 1),
    # which needs len - maxlen - 1 >= 1. Skip topics that cannot satisfy both.
    if len(topic_doc) < maxlen + 2:
        print('skipping %s: %d tokens, need at least %d'
              % (topic, len(topic_doc), maxlen + 2))
        continue

    # Sliding windows: the input is topic_doc[i:i+maxlen] and the target is the
    # token that immediately follows it.
    starts = range(0, len(topic_doc) - maxlen, step)
    X = np.asarray([topic_doc[i:i + maxlen] for i in starts])
    next_tokens = np.asarray([topic_doc[i + maxlen] for i in starts])
    print('sequences:', len(X))
    # One-hot targets over the full vocabulary (including reserved index 0).
    y = to_categorical(next_tokens, num_classes=vocab_size)

    # Embedding (initialised from the pre-computed matrix, fine-tuned)
    # -> LSTM -> softmax over the vocabulary.
    model = Sequential()
    model.add(Embedding(vocab_size, 100, weights=[embedding_matrix], trainable=True))
    model.add(LSTM(100))
    model.add(Dense(y.shape[1]))
    model.add(Activation('softmax'))
    # The learning rate is passed positionally on purpose: older Keras names
    # the keyword `lr`, newer Keras names it `learning_rate` and rejects `lr`.
    optimizer = Adam(0.001, beta_1=0.9, beta_2=0.999)
    # 'acc' (not 'accuracy') makes every Keras version log the metric under the
    # key 'acc', which the checkpoint filename template below formats. With
    # 'accuracy', Keras >= 2.3 logs 'accuracy' and that formatting fails.
    model.compile(
        loss='categorical_crossentropy',
        optimizer=optimizer,
        metrics=['acc'])
    # summary() prints by itself and returns None; print(model.summary())
    # emitted a stray "None".
    model.summary()

    def on_epoch_end(epoch, logs):
        """Print a greedy text sample from the model currently being trained.

        Seeds with a random maxlen-token window from this topic's token stream,
        then appends the argmax prediction maxlen times, feeding the growing
        sequence back into the model at each step. `logs` is unused.
        """
        print()
        print('----- Generating text after Epoch: %d' % epoch)
        start_index = np.random.randint(0, len(topic_doc) - maxlen - 1)
        sentence = topic_doc[start_index: start_index + maxlen]
        original_sentence = ''.join([str(index_word[word]) + ' ' for word in sentence])
        print('----- Input seed: %s' % original_sentence)
        print('----- Output: ')
        for _ in range(maxlen):
            x_pred = np.reshape(sentence, (1, -1))
            preds = model.predict(x_pred, verbose=0)[0]
            # Index 0 is the Tokenizer's reserved id. It is never a training
            # target and presumably has no entry in index_word, so exclude it
            # from the argmax rather than risk a KeyError mid-training.
            next_index = int(np.argmax(preds[1:])) + 1
            next_word = index_word[next_index]
            sentence = np.append(sentence, next_index)
            sys.stdout.write(next_word)
            sys.stdout.write(' ')
            sys.stdout.flush()
        sys.stdout.write("\n")
        sys.stdout.write("-----\n")

    print_callback = LambdaCallback(on_epoch_end=on_epoch_end)

    # With save_best_only=True and monitor='loss'/mode='min', a checkpoint is
    # written only for epochs whose training loss improves on the best so far.
    filepath = "trained_weights/QG-%s-{epoch:02d}-{loss:.4f}-{acc:.4f}.hdf5" % topic
    checkpoint = ModelCheckpoint(filepath, monitor='loss', verbose=1, save_best_only=True, mode='min')

    model.fit(X, y, epochs=30, batch_size=24, callbacks=[checkpoint, print_callback])