Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

rainlou #69

Open
wants to merge 27 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
data/
checkpoint/
__pycache__/

# for pycharm
.idea/

# for log
*.log

# checkpoints are too big
saved_ckpt/

# tar
*.tar
22 changes: 13 additions & 9 deletions main.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
# -*- coding: utf-8 -*-
'''Train CIFAR10 with PyTorch.'''

from __future__ import print_function

import torch
Expand All @@ -18,7 +20,7 @@


parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training')
parser.add_argument('--lr', default=0.1, type=float, help='learning rate')
parser.add_argument('--lr', default=0.001, type=float, help='learning rate')
parser.add_argument('--resume', '-r', action='store_true', help='resume from checkpoint')
args = parser.parse_args()

Expand Down Expand Up @@ -50,7 +52,7 @@

# Model
print('==> Building model..')
# net = VGG('VGG19')
net = VGG('VGG16')
# net = ResNet18()
# net = PreActResNet18()
# net = GoogLeNet()
Expand All @@ -61,7 +63,7 @@
# net = DPN92()
# net = ShuffleNetG2()
# net = SENet18()
net = ShuffleNetV2(1)
# net = ShuffleNetV2(1)
net = net.to(device)
if device == 'cuda':
net = torch.nn.DataParallel(net)
Expand All @@ -74,10 +76,12 @@
checkpoint = torch.load('./checkpoint/ckpt.t7')
net.load_state_dict(checkpoint['net'])
best_acc = checkpoint['acc']
start_epoch = checkpoint['epoch']
start_epoch = checkpoint['epoch'] + 1

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)
# criterion = nn.CrossEntropyLoss()
criterion = KLDivLoss(reduction='mean');
# optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)
optimizer = optim.Adam(net.parameters(), lr=args.lr, weight_decay=5e-4)

# Training
def train(epoch):
Expand All @@ -104,11 +108,11 @@ def train(epoch):

def test(epoch):
global best_acc
net.eval()
net.eval() # 变为测试模式, 对dropout和batch normalization有影响
test_loss = 0
correct = 0
total = 0
with torch.no_grad():
with torch.no_grad(): # 运算不需要进行求导, 提高性能
for batch_idx, (inputs, targets) in enumerate(testloader):
inputs, targets = inputs.to(device), targets.to(device)
outputs = net(inputs)
Expand Down Expand Up @@ -137,6 +141,6 @@ def test(epoch):
best_acc = acc


for epoch in range(start_epoch, start_epoch+200):
for epoch in range(start_epoch, 61):
train(epoch)
test(epoch)
157 changes: 157 additions & 0 deletions main_L1Loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
# -*- coding: utf-8 -*-
'''Train CIFAR10 with PyTorch.'''

from __future__ import print_function

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn

import torchvision
import torchvision.transforms as transforms

import os
import argparse

from models import *
from utils import progress_bar


parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training')
parser.add_argument('--lr', default=0.001, type=float, help='learning rate')
parser.add_argument('--resume', '-r', action='store_true', help='resume from checkpoint')
args = parser.parse_args()

device = 'cuda' if torch.cuda.is_available() else 'cpu'
best_acc = 0 # best test accuracy
start_epoch = 0 # start from epoch 0 or last checkpoint epoch

# Data
print('==> Preparing data..')
transform_train = transforms.Compose([
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

# Model
print('==> Building model..')
net = VGG('VGG16')
# net = ResNet18()
# net = PreActResNet18()
# net = GoogLeNet()
# net = DenseNet121()
# net = ResNeXt29_2x64d()
# net = MobileNet()
# net = MobileNetV2()
# net = DPN92()
# net = ShuffleNetG2()
# net = SENet18()
# net = ShuffleNetV2(1)
net = net.to(device)
if device == 'cuda':
net = torch.nn.DataParallel(net)
cudnn.benchmark = True

if args.resume:
# Load checkpoint.
print('==> Resuming from checkpoint..')
assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
checkpoint = torch.load('./checkpoint/ckpt.t7')
net.load_state_dict(checkpoint['net'])
best_acc = checkpoint['acc']
start_epoch = checkpoint['epoch'] + 1

# criterion = nn.CrossEntropyLoss()
# criterion = nn.MSELoss(reduction='sum')
criterion = nn.L1Loss(reduction='sum');
# optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)
optimizer = optim.Adam(net.parameters(), lr=args.lr, weight_decay=5e-4)

# Training
def train(epoch):
print('\nEpoch: %d' % epoch)
net.train()
train_loss = 0
correct = 0
total = 0
for batch_idx, (inputs, targets) in enumerate(trainloader):
inputs, targets = inputs.to(device), targets.to(device)
optimizer.zero_grad()
outputs = net(inputs)
targets_mse = torch.zeros(outputs.size()).to(device)
for i in range(0, targets_mse.size()[0]):
targets_mse[i][targets[i]] = 1;
softmax = F.softmax(outputs, dim=0)
# print(targets_mse)
# print(outputs, targets)
loss = criterion(softmax, targets_mse)
loss.backward()
optimizer.step()

train_loss += loss.item()
_, predicted = outputs.max(1)
total += targets.size(0)
correct += predicted.eq(targets).sum().item()

progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
% (train_loss/(batch_idx+1), 100.*correct/total, correct, total))

def test(epoch):
global best_acc
net.eval() # 变为测试模式, 对dropout和batch normalization有影响
test_loss = 0
correct = 0
total = 0
with torch.no_grad(): # 运算不需要进行求导, 提高性能
for batch_idx, (inputs, targets) in enumerate(testloader):
inputs, targets = inputs.to(device), targets.to(device)
outputs = net(inputs)
targets_mse = torch.zeros(outputs.size()).to(device)
for i in range(0, targets_mse.size()[0]):
targets_mse[i][targets[i]] = 1;
softmax = F.softmax(outputs, dim=0)
loss = criterion(softmax, targets_mse)

test_loss += loss.item()
_, predicted = outputs.max(1)
total += targets.size(0)
correct += predicted.eq(targets).sum().item()

progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
% (test_loss/(batch_idx+1), 100.*correct/total, correct, total))

# Save checkpoint.
acc = 100.*correct/total
if acc > best_acc:
print('Saving..')
state = {
'net': net.state_dict(),
'acc': acc,
'epoch': epoch,
}
if not os.path.isdir('checkpoint'):
os.mkdir('checkpoint')
torch.save(state, './checkpoint/ckpt.t7')
best_acc = acc


for epoch in range(start_epoch, 61):
train(epoch)
test(epoch)
Loading