-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgentxt.py
99 lines (76 loc) · 3.13 KB
/
gentxt.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
#!/usr/bin/env python
"""Example to generate text from a recurrent neural network language model.
This code is ported from following implementation.
https://github.com/longjie/chainer-char-rnn/blob/master/sample.py
"""
import argparse
import sys
import numpy as np
import six
import json
import chainer
from chainer import cuda
import chainer.functions as F
import chainer.links as L
from chainer import serializers
import train_country_bot
def main():
    """Generate text from a trained RNN language model.

    Loads the word->index vocabulary from ``vocab_indexes.json`` and a model
    checkpoint saved by ``train_country_bot``, primes the recurrent state with
    ``--primetext`` (which must be a known vocabulary word), then emits
    ``--length`` words: sampled from the softmax distribution when
    ``--sample > 0``, otherwise chosen greedily by argmax.  ``<eos>`` tokens
    are rendered as '. '.  Exits with status 1 if the prime word is unknown.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', '-m', type=str, required=True,
                        help='model data, saved by train_ptb.py')
    parser.add_argument('--primetext', '-p', type=str, required=True,
                        default='',
                        help='base text data, used for text generation')
    parser.add_argument('--seed', '-s', type=int, default=123,
                        help='random seeds for text generation')
    parser.add_argument('--unit', '-u', type=int, default=650,
                        help='number of units')
    parser.add_argument('--sample', type=int, default=1,
                        help='negative value indicates NOT use random choice')
    parser.add_argument('--length', type=int, default=20,
                        help='length of the generated text')
    parser.add_argument('--gpu', type=int, default=-1,
                        help='GPU ID (negative value indicates CPU)')
    args = parser.parse_args()

    np.random.seed(args.seed)
    xp = cuda.cupy if args.gpu >= 0 else np

    # Load vocabulary (word -> index) and build the inverse mapping.
    # Use a context manager so the file handle is closed deterministically.
    with open("vocab_indexes.json") as f:
        vocab = json.load(f)
    ivocab = {index: word for word, index in vocab.items()}

    # Must match the n_units the model was trained with (see train_ptb.py).
    n_units = args.unit

    lm = train_country_bot.RNNForLM(len(vocab), n_units, train=False)
    model = L.Classifier(lm)
    serializers.load_npz(args.model, model)

    if args.gpu >= 0:
        cuda.get_device(args.gpu).use()
        model.to_gpu()

    model.predictor.reset_state()

    primetext = args.primetext
    if isinstance(primetext, six.binary_type):
        primetext = primetext.decode('utf-8')
    if primetext in vocab:
        prev_word = chainer.Variable(xp.array([vocab[primetext]], xp.int32))
    else:
        # Report on stderr (not stdout, which carries generated text) and
        # exit with a failure status instead of the interactive exit().
        sys.stderr.write('ERROR: Unfortunately ' + primetext
                         + ' is unknown.\n')
        sys.exit(1)

    sys.stdout.write(primetext + ' ')

    # NOTE: the original computed an extra softmax(predictor(prev_word))
    # before this loop and then fed prev_word again on the first iteration,
    # advancing the RNN state twice with the prime word and discarding the
    # first distribution.  The prime word is now fed exactly once, here.
    for _ in six.moves.range(args.length):
        prob = F.softmax(model.predictor(prev_word))
        if args.sample > 0:
            # Sample the next word from the renormalized softmax output
            # (renormalize in float64 so the probabilities sum to 1 for
            # np.random.choice).
            probability = cuda.to_cpu(prob.data)[0].astype(np.float64)
            probability /= np.sum(probability)
            index = np.random.choice(range(len(probability)), p=probability)
        else:
            # Greedy decoding: always emit the most probable word.
            index = np.argmax(cuda.to_cpu(prob.data))

        if ivocab[index] == '<eos>':
            sys.stdout.write('. ')
        else:
            sys.stdout.write(ivocab[index] + ' ')
        prev_word = chainer.Variable(xp.array([index], dtype=xp.int32))

    sys.stdout.write('\n')
# Run text generation only when executed as a script, not on import.
if __name__ == '__main__':
    main()