Skip to content

Commit

Permalink
update ae
Browse files Browse the repository at this point in the history
  • Loading branch information
jm12138 committed Feb 22, 2022
1 parent 63fa0b4 commit 438b2c2
Showing 1 changed file with 11 additions and 6 deletions.
17 changes: 11 additions & 6 deletions examples/auto_encoder/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ def get_argparser():


def main():
if not os.path.exists('results'):
os.mkdir('results')

opts = get_argparser().parse_args()

# dataset
Expand All @@ -58,7 +61,7 @@ def main():

val_loader = DataLoader(
ImageDataset(root='datasets/CLIC/valid', transform=val_transform),
batch_size=opts.batch_size, shuffle=False, num_workers=0)
batch_size=1, shuffle=False, num_workers=0)

print("Train set: %d, Val set: %d" %
(len(train_loader.dataset), len(val_loader.dataset)))
Expand Down Expand Up @@ -105,7 +108,7 @@ def main():
# ===== Validation =====
print("Val...")
best_score = 0.0
cur_score = test(opts, model, val_loader)
cur_score = test(opts, model, val_loader, cur_epoch)
print("%s = %.6f" % (opts.loss_type, cur_score))
# ===== Save Best Model =====
if cur_score > best_score: # save best model
Expand All @@ -114,7 +117,10 @@ def main():
print("Best model saved as best_model.pt")


def test(opts, model, val_loader):
def test(opts, model, val_loader, epoch):
save_dir = os.path.join('results', 'epoch_%d' % epoch)
if not os.path.exists(save_dir):
os.mkdir(save_dir)
model.eval()
cur_score = 0.0

Expand All @@ -124,11 +130,10 @@ def test(opts, model, val_loader):
for i, (images, ) in enumerate(val_loader):
outputs = model(images)
# save the first reconstructed image
if i == 20:
Image.fromarray((outputs*255).squeeze(0).detach().numpy().astype(
'uint8').transpose(1, 2, 0)).save('recons_%s.png' % (opts.loss_type))
cur_score += metric(outputs, images, data_range=1.0)
Image.fromarray((outputs*255).squeeze(0).detach().numpy().astype('uint8').transpose(1, 2, 0)).save(os.path.join(save_dir, 'recons_%s_%d.png' % (opts.loss_type, i)))
cur_score /= len(val_loader.dataset)

return cur_score


Expand Down

0 comments on commit 438b2c2

Please sign in to comment.