lstm_classify.py
#
# From: https://github.com/TrWestdoor/pytorch-practice/blob/master/rnn_classify.py
#
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms

# Hyperparameters
INPUT_SIZE = 28          # each 28-pixel image row is one time step's input
BATCH_SIZE = 1
EPOCH = 1
LR = 0.005
DOWNLOAD_MNIST = False   # set to True on the first run to download MNIST
# Training set: MNIST images as float tensors in [0, 1]
train_data = datasets.MNIST(
    root='./MNIST',
    train=True,
    transform=transforms.ToTensor(),
    download=DOWNLOAD_MNIST,
)
train_loader = DataLoader(
    dataset=train_data,
    batch_size=BATCH_SIZE,
    shuffle=True,
)

# Test set: first 2000 images, manually scaled to [0, 1]
test_data = datasets.MNIST(root='./MNIST', train=False, transform=transforms.ToTensor())
test_x = test_data.data.type(torch.FloatTensor)[:2000] / 255   # shape (2000, 28, 28)
test_y = test_data.targets.numpy()[:2000]
class RNN(nn.Module):
    def __init__(self):
        super(RNN, self).__init__()
        # Each image is treated as a sequence of 28 rows, each row a 28-dim input
        self.rnn = nn.LSTM(
            input_size=INPUT_SIZE,
            hidden_size=64,
            num_layers=1,
            batch_first=True,        # input shape: (batch, time_step, input_size)
        )
        self.out = nn.Linear(self.rnn.hidden_size, 10)   # 10 digit classes

    def forward(self, x):
        # r_out: (batch, time_step, hidden_size); (h_n, c_n): final hidden/cell states
        r_out, (h_n, c_n) = self.rnn(x, None)   # None -> zero initial state
        # Classify from the LSTM output at the last time step (last image row)
        out = self.out(r_out[:, -1, :])
        return out
rnn = RNN()
optimizer = torch.optim.Adam(rnn.parameters(), lr=LR)
loss_fun = nn.CrossEntropyLoss()

for epoch in range(EPOCH):
    for step, (b_x, b_y) in enumerate(train_loader):
        # Reshape each batch to (batch, time_step, input_size) = (batch, 28, 28)
        b_x = b_x.view(-1, 28, 28)
        r_out = rnn(b_x)
        loss = loss_fun(r_out, b_y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if (step % 50) == 0:
            # Periodically report training loss and accuracy on the 2000 test images
            test_out = rnn(test_x)
            pred_y = torch.max(test_out, 1)[1].data.numpy()
            accuracy = float((pred_y == test_y).astype(int).sum()) / float(test_y.size)
            print("Epoch: ", epoch,
                  "| train loss: ", loss.data.numpy(),
                  "| test accuracy: %.2f %%" % (accuracy * 100.0))