# -*- coding: utf-8 -*-
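"""Generator network for SeqGAN (forked from ZiJianZhao/SeqGAN-PyTorch).

An embedding + single-layer LSTM language model that scores token sequences
(forward) and autoregressively samples new ones (step, sample).
"""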

import torch
import torch.nn as nn
import torch.nn.functional as F


class Generator(nn.Module):
    """Generator: embedding + LSTM language model over a token vocabulary."""

    def __init__(self, num_emb, emb_dim, hidden_dim, use_cuda):
        super(Generator, self).__init__()
        self.num_emb = num_emb
        self.emb_dim = emb_dim
        self.hidden_dim = hidden_dim
        self.use_cuda = use_cuda
        self.emb = nn.Embedding(num_emb, emb_dim)
        self.lstm = nn.LSTM(emb_dim, hidden_dim, batch_first=True)
        self.lin = nn.Linear(hidden_dim, num_emb)
        # dim=1 normalizes over the vocabulary axis of the (batch*seq, num_emb) logits.
        self.softmax = nn.LogSoftmax(dim=1)
        self.init_params()

    def forward(self, x):
        """
        Args:
            x: (batch_size, seq_len), sequence of tokens generated by generator

        Returns:
            (batch_size * seq_len, num_emb) log-probabilities over the vocabulary.
        """
        emb = self.emb(x)                          # (batch, seq, emb_dim)
        h0, c0 = self.init_hidden(x.size(0))
        output, (h, c) = self.lstm(emb, (h0, c0))  # (batch, seq, hidden_dim)
        pred = self.softmax(self.lin(output.contiguous().view(-1, self.hidden_dim)))
        return pred

    def step(self, x, h, c):
        """
        Args:
            x: (batch_size, 1), next input token for each sequence
            h: (1, batch_size, hidden_dim), lstm hidden state
            c: (1, batch_size, hidden_dim), lstm cell state

        Returns:
            (batch_size, num_emb) next-token probabilities and the updated (h, c).
        """
        emb = self.emb(x)
        output, (h, c) = self.lstm(emb, (h, c))
        pred = F.softmax(self.lin(output.view(-1, self.hidden_dim)), dim=1)
        return pred, h, c

    def init_hidden(self, batch_size):
        h = torch.zeros((1, batch_size, self.hidden_dim))
        c = torch.zeros((1, batch_size, self.hidden_dim))
        if self.use_cuda:
            h, c = h.cuda(), c.cuda()
        return h, c

    def init_params(self):
        for param in self.parameters():
            param.data.uniform_(-0.05, 0.05)

    def sample(self, batch_size, seq_len, x=None):
        """Sample (batch_size, seq_len) token sequences, optionally from a prefix x."""
        from_scratch = x is None
        if from_scratch:
            # Token 0 serves as the start-of-sequence input.
            x = torch.zeros((batch_size, 1)).long()
            if self.use_cuda:
                x = x.cuda()
        h, c = self.init_hidden(batch_size)
        samples = []
        if from_scratch:
            for i in range(seq_len):
                output, h, c = self.step(x, h, c)
                x = output.multinomial(1)
                samples.append(x)
        else:
            # Feed the given prefix through the LSTM, keeping its tokens as-is...
            given_len = x.size(1)
            lis = x.chunk(given_len, dim=1)
            for i in range(given_len):
                output, h, c = self.step(lis[i], h, c)
                samples.append(lis[i])
            x = output.multinomial(1)
            # ...then continue sampling until seq_len tokens in total.
            for i in range(given_len, seq_len):
                samples.append(x)
                output, h, c = self.step(x, h, c)
                x = output.multinomial(1)
        return torch.cat(samples, dim=1)
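

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original file): the vocabulary and
    # dimension sizes below are arbitrary placeholders, chosen only to exercise
    # forward() and sample() on CPU.
    VOCAB_SIZE, EMB_DIM, HIDDEN_DIM = 5000, 32, 32
    BATCH_SIZE, SEQ_LEN = 4, 20

    gen = Generator(VOCAB_SIZE, EMB_DIM, HIDDEN_DIM, use_cuda=False)

    # Unconditional sampling: a (BATCH_SIZE, SEQ_LEN) tensor of token ids.
    seqs = gen.sample(BATCH_SIZE, SEQ_LEN)
    print(seqs.shape)  # torch.Size([4, 20])

    # Scoring sequences: forward() returns per-position log-probabilities
    # flattened to (BATCH_SIZE * SEQ_LEN, VOCAB_SIZE), e.g. for an NLL loss.
    log_probs = gen.forward(seqs)
    print(log_probs.shape)  # torch.Size([80, 5000])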