-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain2.py
122 lines (102 loc) · 3.9 KB
/
main2.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import matplotlib.pyplot as plt
# Define the encoder part of the VAE
class Encoder(nn.Module):
    """Inference network: maps a flattened image to the parameters
    (mean and log-variance) of the approximate posterior q(z|x)."""

    def __init__(self, input_dim, hidden_dim, z_dim):
        super(Encoder, self).__init__()
        # One shared hidden layer feeding two parallel output heads.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2_mu = nn.Linear(hidden_dim, z_dim)
        self.fc2_logvar = nn.Linear(hidden_dim, z_dim)

    def forward(self, x):
        # ReLU-activated hidden representation shared by both heads.
        hidden = torch.relu(self.fc1(x))
        # mu and logvar are independent linear projections of the same features.
        return self.fc2_mu(hidden), self.fc2_logvar(hidden)
# Define the decoder part of the VAE
class Decoder(nn.Module):
    """Generative network: maps a latent code z back to pixel space,
    squashing outputs into [0, 1] with a sigmoid."""

    def __init__(self, z_dim, hidden_dim, output_dim):
        super(Decoder, self).__init__()
        self.fc1 = nn.Linear(z_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, z):
        # Hidden activation, then per-pixel values in (0, 1) via sigmoid.
        return torch.sigmoid(self.fc2(torch.relu(self.fc1(z))))
# Define the VAE model
class VAE(nn.Module):
    """Variational autoencoder tying an Encoder and a Decoder together."""

    def __init__(self, input_dim, hidden_dim, z_dim):
        super(VAE, self).__init__()
        self.encoder = Encoder(input_dim, hidden_dim, z_dim)
        self.decoder = Decoder(z_dim, hidden_dim, input_dim)

    def reparameterize(self, mu, logvar):
        """Draw z ~ N(mu, sigma^2) with the reparameterization trick,
        keeping the sampling step differentiable w.r.t. mu and logvar."""
        sigma = torch.exp(0.5 * logvar)   # log-variance -> standard deviation
        noise = torch.randn_like(sigma)   # eps ~ N(0, I)
        return mu + noise * sigma

    def forward(self, x):
        """Encode, sample a latent, decode; mu/logvar are also returned
        so the loss can compute the KL term."""
        mu, logvar = self.encoder(x)
        latent = self.reparameterize(mu, logvar)
        return self.decoder(latent), mu, logvar
# Loss function
def loss_function(x_recon, x, mu, logvar):
    """Negative ELBO for a Bernoulli-decoder VAE: summed binary
    cross-entropy reconstruction term plus the closed-form KL divergence
    between the diagonal Gaussian q(z|x) and the standard normal prior."""
    reconstruction = nn.functional.binary_cross_entropy(x_recon, x, reduction='sum')
    # Analytic KL(q(z|x) || N(0, I)) for a diagonal Gaussian.
    kl_divergence = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return reconstruction + kl_divergence
# Hyperparameters
input_dim = 28 * 28    # MNIST images flattened to 784-dim vectors
hidden_dim = 400       # width of the single hidden layer in encoder and decoder
z_dim = 20             # dimensionality of the latent space
batch_size = 128
learning_rate = 1e-3
num_epochs = 10
# Data loading: MNIST train split; ToTensor scales pixels to [0, 1],
# which the BCE reconstruction loss requires
transform = transforms.ToTensor()
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
# Model, optimizer
model = VAE(input_dim, hidden_dim, z_dim)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
# Training loop: standard minibatch SGD over the summed ELBO loss
model.train()
for epoch in range(num_epochs):
    train_loss = 0  # running sum of per-batch losses for the epoch report
    for batch_idx, (data, _) in enumerate(train_loader):  # labels are unused
        data = data.view(-1, 28 * 28)  # Flatten the input data
        optimizer.zero_grad()
        x_recon, mu, logvar = model(data)
        loss = loss_function(x_recon, data, mu, logvar)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
    # Loss is summed per-sample inside loss_function, so dividing by the
    # dataset size yields the average per-image loss for the epoch.
    print(f'Epoch {epoch + 1}, Loss: {train_loss / len(train_loader.dataset)}')
# Save the trained model
torch.save(model.state_dict(), 'vae.pth')
# Function to generate new samples from the trained VAE
def generate_samples(model, num_samples=10, z_dim=20):
    """Sample images from the decoder using latent draws from the prior.

    Args:
        model: trained VAE; only ``model.eval()`` and ``model.decoder`` are used.
        num_samples: number of images to generate.
        z_dim: latent dimensionality. Defaults to 20 to match the module-level
            hyperparameter; previously this was read from the global ``z_dim``,
            which broke the function if imported or used standalone.

    Returns:
        Tensor of shape (num_samples, 1, 28, 28) with values in [0, 1].
        NOTE(review): the reshape assumes the decoder emits 784 values per
        sample — confirm if the architecture changes.
    """
    model.eval()  # Set model to evaluation mode
    with torch.no_grad():  # no gradients needed for sampling
        z = torch.randn(num_samples, z_dim)  # Sample random latent vectors from N(0, I)
        samples = model.decoder(z)  # Decode to generate samples
        samples = samples.view(-1, 1, 28, 28)  # Reshape for visualization
        return samples
# Load the trained model (for demonstration purposes)
# NOTE(review): redundant in this script since `model` already holds these
# weights; kept to show how a checkpoint would be restored.
model.load_state_dict(torch.load('vae.pth'))
# Generate new samples
num_samples = 10
generated_samples = generate_samples(model, num_samples)
# Print the shape of the generated samples
print(f'Generated samples shape: {generated_samples.shape}')
# Function to plot the generated samples
def plot_generated_samples(samples):
    """Display each sample as a grayscale image in a single row.

    Args:
        samples: sequence/tensor of images, each squeezable to 2-D
            (e.g. shape (1, 28, 28)).
    """
    # squeeze=False makes plt.subplots always return a 2-D axes array;
    # the original relied on the squeezed return value, which is a bare
    # Axes (not indexable) when there is exactly one sample, so
    # `axes[i]` raised TypeError for len(samples) == 1.
    fig, axes = plt.subplots(1, len(samples), figsize=(15, 2), squeeze=False)
    for i, sample in enumerate(samples):
        axes[0, i].imshow(sample.squeeze(), cmap='gray')
        axes[0, i].axis('off')
    plt.show()
# Plot the generated samples as one row of grayscale images
plot_generated_samples(generated_samples)