Srawat dev #141

Open · wants to merge 16 commits into master
main.py: 175 changes (137 additions, 38 deletions)
@@ -8,22 +8,35 @@
import torchvision
import torchvision.transforms as transforms

import torch.nn.utils.prune as prune
from prune_params import get_prune_params, print_sparsity
import os
import argparse

from models import *
from utils import progress_bar


parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training')
parser.add_argument('--lr', default=0.1, type=float, help='learning rate')
parser.add_argument('--resume', '-r', action='store_true',
help='resume from checkpoint')

parser.add_argument('--prune_one_shot', '-pos', action='store_true',
help='resume from checkpoint with one shot pruning')

parser.add_argument('--prune_iterative', '-pit', action='store_true',
help='resume from checkpoint with iterative pruning')

parser.add_argument('--prune_amount', '-pa', dest='pa', default=0, type=float,
help='fraction of weights to prune, between 0.0 and 1.0')
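
# Illustrative invocations of the flags above (values are examples only):
#   python main.py                             -> train from scratch
#   python main.py --resume                    -> resume from ./checkpoint/ckpt.pth
#   python main.py --prune_one_shot -pa 0.5    -> prune 50% once, then retrain
#   python main.py --prune_iterative -pa 0.5   -> prune toward 50% over 4 rounds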

args = parser.parse_args()

device = 'cuda' if torch.cuda.is_available() else 'cpu'
best_acc = 0 # best test accuracy
start_epoch = 0 # start from epoch 0 or last checkpoint epoch
pos_best_acc = 0 # best accuracy for one shot pruned model

# Data
print('==> Preparing data..')
@@ -42,33 +55,22 @@
trainset = torchvision.datasets.CIFAR10(
root='./data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(
-trainset, batch_size=128, shuffle=True, num_workers=2)
+trainset, batch_size=256, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(
root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(
-testset, batch_size=100, shuffle=False, num_workers=2)
+testset, batch_size=256, shuffle=False, num_workers=2)

model_save_path = './checkpoint/ckpt.pth'
prune_amount = 0

classes = ('plane', 'car', 'bird', 'cat', 'deer',
'dog', 'frog', 'horse', 'ship', 'truck')

# Model
print('==> Building model..')
# net = VGG('VGG19')
# net = ResNet18()
# net = PreActResNet18()
# net = GoogLeNet()
# net = DenseNet121()
# net = ResNeXt29_2x64d()
# net = MobileNet()
# net = MobileNetV2()
# net = DPN92()
# net = ShuffleNetG2()
# net = SENet18()
# net = ShuffleNetV2(1)
# net = EfficientNetB0()
# net = RegNetX_200MF()
-net = SimpleDLA()
+net = ResNet18()
net = net.to(device)
if device == 'cuda':
net = torch.nn.DataParallel(net)
@@ -81,13 +83,23 @@
checkpoint = torch.load('./checkpoint/ckpt.pth')
net.load_state_dict(checkpoint['net'])
best_acc = checkpoint['acc']
# pos_best_acc = checkpoint['pos_best_acc']
start_epoch = checkpoint['epoch']

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=args.lr,
momentum=0.9, weight_decay=5e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)

if args.prune_one_shot:
print('Perform one shot pruning and retraining')
assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
checkpoint = torch.load('./checkpoint/ckpt.pth')
net.load_state_dict(checkpoint['net'])
best_acc = checkpoint['acc']
# pos_best_acc = checkpoint['pos_best_acc']
start_epoch = checkpoint['epoch']
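
# Note: this loads the dense baseline checkpoint; the pruning re-parametrization
# (weight_orig / weight_mask) is only attached later, in __main__, via
# prune.global_unstructured.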


# Training
def train(epoch):
@@ -110,11 +122,12 @@ def train(epoch):
correct += predicted.eq(targets).sum().item()

progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
-% (train_loss/(batch_idx+1), 100.*correct/total, correct, total))
+% (train_loss / (batch_idx + 1), 100. * correct / total, correct, total))


def test(epoch):
global best_acc
global pos_best_acc
net.eval()
test_loss = 0
correct = 0
@@ -131,24 +144,110 @@ def test(epoch):
correct += predicted.eq(targets).sum().item()

progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
-% (test_loss/(batch_idx+1), 100.*correct/total, correct, total))
+% (test_loss / (batch_idx + 1), 100. * correct / total, correct, total))

-# Save checkpoint.
-acc = 100.*correct/total
-if acc > best_acc:
-print('Saving..')
-state = {
-'net': net.state_dict(),
-'acc': acc,
-'epoch': epoch,
-}
-if not os.path.isdir('checkpoint'):
-os.mkdir('checkpoint')
-torch.save(state, './checkpoint/ckpt.pth')
-best_acc = acc
-
-
-for epoch in range(start_epoch, start_epoch+200):
-train(epoch)
-test(epoch)
-scheduler.step()
if args.prune_iterative:
acc = 100. * correct / total
if acc > pos_best_acc:
# Remove pruning before saving
prune_params = get_prune_params(net)
for prune_param in prune_params:
prune.remove(prune_param[0], 'weight')

print('Saving..')
state = {
'net': net.state_dict(),
'acc': acc,
'epoch': epoch,
'pos_best_acc': acc,
}
if not os.path.isdir('checkpoint'):
os.mkdir('checkpoint')
torch.save(state, './checkpoint/ckpt_prune_iterative_' + str(int(100 * prune_amount)) + '.pth')
pos_best_acc = acc
print_sparsity(net)

# Re-apply the pruning masks before training continues. Pruning again at the
# same amount re-selects the weights that were just zeroed out (they have the
# lowest L1 magnitude), so the mask is unchanged; see the round-trip sketch
# after this diff.
prune.global_unstructured(get_prune_params(net), pruning_method=prune.L1Unstructured,
importance_scores=None, amount=prune_amount)

elif args.prune_one_shot:

acc = 100. * correct / total
if acc > pos_best_acc:
# Remove pruning before saving
prune_params = get_prune_params(net)
for prune_param in prune_params:
prune.remove(prune_param[0], 'weight')

print('Saving..')
state = {
'net': net.state_dict(),
'acc': acc,
'epoch': epoch,
'pos_best_acc': acc,
}
if not os.path.isdir('checkpoint'):
os.mkdir('checkpoint')
torch.save(state, './checkpoint/ckpt_prune_one_shot_' + str(int(100 * prune_amount)) + '.pth')
pos_best_acc = acc
print_sparsity(net)

# Re-apply the pruning masks before training continues (the same round trip
# as in the iterative branch above).
prune.global_unstructured(get_prune_params(net), pruning_method=prune.L1Unstructured,
importance_scores=None, amount=prune_amount)

else:
acc = 100. * correct / total
if acc > best_acc:
print('Saving..')
state = {
'net': net.state_dict(),
'acc': acc,
'epoch': epoch,
'pos_best_acc': pos_best_acc
}
if not os.path.isdir('checkpoint'):
os.mkdir('checkpoint')

torch.save(state, './checkpoint/ckpt.pth')
best_acc = acc


if __name__ == '__main__':
num_epoch_train, num_epoch_one_shot, num_epoch_iterative = (200, 100, 25)
# num_epoch_train, num_epoch_one_shot, num_epoch_iterative = (4, 4, 4)
# Iterative pruning
if args.prune_iterative:
total_prune_amount = args.pa

num_pruning_iter = 4
# increase the pruning amount over num_pruning_iter iterations
# (e.g. for args.pa = 0.5 the schedule is 0.125, 0.25, 0.375, 0.5)
for prune_x in range(num_pruning_iter):
prune_amount = (prune_x + 1) * total_prune_amount / num_pruning_iter
parameters_to_prune = get_prune_params(net)
prune.global_unstructured(parameters_to_prune, pruning_method=prune.L1Unstructured, importance_scores=None,
amount=prune_amount)
for epoch in range(start_epoch, start_epoch + num_epoch_iterative):
train(epoch)
test(epoch)
scheduler.step()

# One shot pruning and retraining
elif args.prune_one_shot:
prune_amount = args.pa
parameters_to_prune = get_prune_params(net)
prune.global_unstructured(parameters_to_prune, pruning_method=prune.L1Unstructured, importance_scores=None,
amount=prune_amount)
for epoch in range(start_epoch, start_epoch + num_epoch_one_shot):
train(epoch)
test(epoch)
scheduler.step()

# No pruning
else:
for epoch in range(start_epoch, start_epoch + num_epoch_train):
train(epoch)
test(epoch)
scheduler.step()
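
The save logic in test() above relies on a prune/remove round trip: prune.remove() folds the mask into .weight so the checkpoint holds plain dense tensors, and re-pruning at the same amount re-selects exactly those zeros (they have the lowest L1 magnitude), leaving the mask unchanged. A minimal standalone sketch of that round trip, assuming only torch (the Linear layer is illustrative, not part of this PR):

import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

layer = nn.Linear(4, 4)

# l1_unstructured adds weight_orig and weight_mask; .weight becomes their product.
prune.l1_unstructured(layer, name='weight', amount=0.5)
assert hasattr(layer, 'weight_mask')

# remove() bakes the zeros into .weight and drops the re-parametrization,
# so state_dict() contains ordinary dense tensors.
prune.remove(layer, 'weight')
mask_before = layer.weight == 0

# Re-pruning at the same amount picks the already-zeroed weights again.
prune.l1_unstructured(layer, name='weight', amount=0.5)
assert torch.equal(layer.weight == 0, mask_before)

One caveat, per the torch.nn.utils.prune docs: when a mask is already attached, a further pruning call computes its amount over the remaining (unpruned) entries, so repeated calls compound and the iterative schedule in __main__ may land above args.pa.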
prune_params.py: 62 changes (62 additions, 0 deletions; new file)
@@ -0,0 +1,62 @@
import torch


def get_prune_params(net):
parameters_to_prune = (
(net.module.conv1, 'weight'),
(net.module.bn1, 'weight'),

(net.module.layer1[0].conv1, 'weight'),
(net.module.layer1[0].bn1, 'weight'),
(net.module.layer1[0].conv2, 'weight'),
(net.module.layer1[0].bn2, 'weight'),
(net.module.layer1[1].conv1, 'weight'),
(net.module.layer1[1].bn1, 'weight'),
(net.module.layer1[1].conv2, 'weight'),
(net.module.layer1[1].bn2, 'weight'),

(net.module.layer2[0].conv1, 'weight'),
(net.module.layer2[0].bn1, 'weight'),
(net.module.layer2[0].conv2, 'weight'),
(net.module.layer2[0].bn2, 'weight'),
(net.module.layer2[0].shortcut[0], 'weight'),
(net.module.layer2[0].shortcut[1], 'weight'),
(net.module.layer2[1].conv1, 'weight'),
(net.module.layer2[1].bn1, 'weight'),
(net.module.layer2[1].conv2, 'weight'),
(net.module.layer2[1].bn2, 'weight'),

(net.module.layer3[0].conv1, 'weight'),
(net.module.layer3[0].bn1, 'weight'),
(net.module.layer3[0].conv2, 'weight'),
(net.module.layer3[0].bn2, 'weight'),
(net.module.layer3[0].shortcut[0], 'weight'),
(net.module.layer3[0].shortcut[1], 'weight'),
(net.module.layer3[1].conv1, 'weight'),
(net.module.layer3[1].bn1, 'weight'),
(net.module.layer3[1].conv2, 'weight'),
(net.module.layer3[1].bn2, 'weight'),

(net.module.layer4[0].conv1, 'weight'),
(net.module.layer4[0].bn1, 'weight'),
(net.module.layer4[0].conv2, 'weight'),
(net.module.layer4[0].bn2, 'weight'),
(net.module.layer4[0].shortcut[0], 'weight'),
(net.module.layer4[0].shortcut[1], 'weight'),
(net.module.layer4[1].conv1, 'weight'),
(net.module.layer4[1].bn1, 'weight'),
(net.module.layer4[1].conv2, 'weight'),
(net.module.layer4[1].bn2, 'weight'),

)
return parameters_to_prune
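

# An equivalent alternative (assumed, not part of this PR): collect the same
# (module, 'weight') pairs by module type. For the ResNet18 used here, the
# hand-written list above covers every Conv2d and BatchNorm2d, including the
# shortcut layers, and excludes the final Linear layer, so filtering by type
# should match it exactly.
import torch.nn as nn


def get_prune_params_by_type(net):
    return tuple(
        (module, 'weight')
        for module in net.modules()
        if isinstance(module, (nn.Conv2d, nn.BatchNorm2d))
    )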


def print_sparsity(model):
params = get_prune_params(model)
zero_weights = 0
total_weights = 0
for param in params:
zero_weights += torch.sum(param[0].weight == 0).item()
total_weights += param[0].weight.nelement()
print("Global sparsity: {:.2f}%".format(100. * zero_weights / total_weights))