From 3711a079ee3ddc90c948081bc1a7aa50539c4483 Mon Sep 17 00:00:00 2001 From: Sumanu Rawat Date: Fri, 19 Nov 2021 02:46:31 -0500 Subject: [PATCH 01/16] use resnet 18 --- main.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/main.py b/main.py index 05ca1eb90..8be2bc64e 100644 --- a/main.py +++ b/main.py @@ -55,7 +55,7 @@ # Model print('==> Building model..') # net = VGG('VGG19') -# net = ResNet18() +net = ResNet18() # net = PreActResNet18() # net = GoogLeNet() # net = DenseNet121() @@ -68,7 +68,7 @@ # net = ShuffleNetV2(1) # net = EfficientNetB0() # net = RegNetX_200MF() -net = SimpleDLA() +# net = SimpleDLA() net = net.to(device) if device == 'cuda': net = torch.nn.DataParallel(net) From 3eebc15fceaf6dc8a652b97b5105b4d90a7ce1c5 Mon Sep 17 00:00:00 2001 From: sumanurawat umass Date: Fri, 19 Nov 2021 15:10:08 -0500 Subject: [PATCH 02/16] pruning one shot --- main.py | 45 ++++++++++++++++++++++++++++++++++----------- prune_params.py | 49 +++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 83 insertions(+), 11 deletions(-) create mode 100644 prune_params.py diff --git a/main.py b/main.py index 8be2bc64e..2c62d3764 100644 --- a/main.py +++ b/main.py @@ -8,17 +8,21 @@ import torchvision import torchvision.transforms as transforms +import torch.nn.utils.prune as prune +from prune_params import get_prune_params import os import argparse from models import * from utils import progress_bar - parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training') parser.add_argument('--lr', default=0.1, type=float, help='learning rate') parser.add_argument('--resume', '-r', action='store_true', help='resume from checkpoint') + +parser.add_argument('--prune_one_shot', '-pos', action='store_true', + help='resume from checkpoint with one shot pruning') args = parser.parse_args() device = 'cuda' if torch.cuda.is_available() else 'cpu' @@ -42,12 +46,12 @@ trainset = torchvision.datasets.CIFAR10( root='./data', train=True, download=True, 
transform=transform_train) trainloader = torch.utils.data.DataLoader( - trainset, batch_size=128, shuffle=True, num_workers=2) + trainset, batch_size=256, shuffle=True, num_workers=2) testset = torchvision.datasets.CIFAR10( root='./data', train=False, download=True, transform=transform_test) testloader = torch.utils.data.DataLoader( - testset, batch_size=100, shuffle=False, num_workers=2) + testset, batch_size=256, shuffle=False, num_workers=2) classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck') @@ -88,6 +92,15 @@ momentum=0.9, weight_decay=5e-4) scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200) +if args.prune_one_shot: + print('Perform one shot pruning and retraining') + assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!' + checkpoint = torch.load('./checkpoint/ckpt_prune_one_shot.pth') + net.load_state_dict(checkpoint['net']) + best_acc = checkpoint['acc'] + start_epoch = checkpoint['epoch'] + + # Training def train(epoch): @@ -110,7 +123,7 @@ def train(epoch): correct += predicted.eq(targets).sum().item() progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)' - % (train_loss/(batch_idx+1), 100.*correct/total, correct, total)) + % (train_loss / (batch_idx + 1), 100. * correct / total, correct, total)) def test(epoch): @@ -131,10 +144,10 @@ def test(epoch): correct += predicted.eq(targets).sum().item() progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)' - % (test_loss/(batch_idx+1), 100.*correct/total, correct, total)) + % (test_loss / (batch_idx + 1), 100. * correct / total, correct, total)) # Save checkpoint. - acc = 100.*correct/total + acc = 100. 
* correct / total if acc > best_acc: print('Saving..') state = { @@ -144,11 +157,21 @@ def test(epoch): } if not os.path.isdir('checkpoint'): os.mkdir('checkpoint') - torch.save(state, './checkpoint/ckpt.pth') + if args.prune_one_shot: + torch.save(state, './checkpoint/ckpt_prune_one_shot.pth') + else: + torch.save(state, './checkpoint/ckpt.pth') best_acc = acc +if __name__ == '__main__': + + if args.prune_one_shot: + print('one shot pruning in main') + parameters_to_prune = get_prune_params(net) + prune.global_unstructured(parameters_to_prune, pruning_method=prune.L1Unstructured, importance_scores=None, + amount=0.9) -for epoch in range(start_epoch, start_epoch+200): - train(epoch) - test(epoch) - scheduler.step() + for epoch in range(start_epoch, start_epoch + 1): + train(epoch) + test(epoch) + scheduler.step() diff --git a/prune_params.py b/prune_params.py new file mode 100644 index 000000000..5ff2fe5c5 --- /dev/null +++ b/prune_params.py @@ -0,0 +1,49 @@ +def get_prune_params(net): + parameters_to_prune = ( + (net.module.conv1, 'weight'), + (net.module.bn1, 'weight'), + + (net.module.layer1[0].conv1, 'weight'), + (net.module.layer1[0].bn1, 'weight'), + (net.module.layer1[0].conv2, 'weight'), + (net.module.layer1[0].bn2, 'weight'), + (net.module.layer1[1].conv1, 'weight'), + (net.module.layer1[1].bn1, 'weight'), + (net.module.layer1[1].conv2, 'weight'), + (net.module.layer1[1].bn2, 'weight'), + + (net.module.layer2[0].conv1, 'weight'), + (net.module.layer2[0].bn1, 'weight'), + (net.module.layer2[0].conv2, 'weight'), + (net.module.layer2[0].bn2, 'weight'), + (net.module.layer2[0].shortcut[0], 'weight'), + (net.module.layer2[0].shortcut[1], 'weight'), + (net.module.layer2[1].conv1, 'weight'), + (net.module.layer2[1].bn1, 'weight'), + (net.module.layer2[1].conv2, 'weight'), + (net.module.layer2[1].bn2, 'weight'), + + (net.module.layer3[0].conv1, 'weight'), + (net.module.layer3[0].bn1, 'weight'), + (net.module.layer3[0].conv2, 'weight'), + (net.module.layer3[0].bn2, 
'weight'), + (net.module.layer3[0].shortcut[0], 'weight'), + (net.module.layer3[0].shortcut[1], 'weight'), + (net.module.layer3[1].conv1, 'weight'), + (net.module.layer3[1].bn1, 'weight'), + (net.module.layer3[1].conv2, 'weight'), + (net.module.layer3[1].bn2, 'weight'), + + (net.module.layer4[0].conv1, 'weight'), + (net.module.layer4[0].bn1, 'weight'), + (net.module.layer4[0].conv2, 'weight'), + (net.module.layer4[0].bn2, 'weight'), + (net.module.layer4[0].shortcut[0], 'weight'), + (net.module.layer4[0].shortcut[1], 'weight'), + (net.module.layer4[1].conv1, 'weight'), + (net.module.layer4[1].bn1, 'weight'), + (net.module.layer4[1].conv2, 'weight'), + (net.module.layer4[1].bn2, 'weight'), + + ) + return parameters_to_prune \ No newline at end of file From 7a53df3e368b4932e11514bb983a060253ad7639 Mon Sep 17 00:00:00 2001 From: sumanurawat umass Date: Fri, 19 Nov 2021 15:19:48 -0500 Subject: [PATCH 03/16] model path --- main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/main.py b/main.py index 2c62d3764..59e66dd39 100644 --- a/main.py +++ b/main.py @@ -95,7 +95,7 @@ if args.prune_one_shot: print('Perform one shot pruning and retraining') assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!' 
- checkpoint = torch.load('./checkpoint/ckpt_prune_one_shot.pth') + checkpoint = torch.load('./checkpoint/ckpt.pth') net.load_state_dict(checkpoint['net']) best_acc = checkpoint['acc'] start_epoch = checkpoint['epoch'] From eb4bd2bae8459210f2f850d766f3383a8c609562 Mon Sep 17 00:00:00 2001 From: sumanurawat umass Date: Fri, 19 Nov 2021 15:28:02 -0500 Subject: [PATCH 04/16] epochs --- main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/main.py b/main.py index 59e66dd39..4333ba921 100644 --- a/main.py +++ b/main.py @@ -171,7 +171,7 @@ def test(epoch): prune.global_unstructured(parameters_to_prune, pruning_method=prune.L1Unstructured, importance_scores=None, amount=0.9) - for epoch in range(start_epoch, start_epoch + 1): + for epoch in range(start_epoch, start_epoch + 200): train(epoch) test(epoch) scheduler.step() From afa8caf44c7a6b9724c3d0f73c51dc83c52b44b6 Mon Sep 17 00:00:00 2001 From: sumanurawat umass Date: Fri, 19 Nov 2021 15:39:11 -0500 Subject: [PATCH 05/16] saving one shot pruned model --- main.py | 47 +++++++++++++++++++++++++++++++++-------------- 1 file changed, 33 insertions(+), 14 deletions(-) diff --git a/main.py b/main.py index 4333ba921..029fcbdde 100644 --- a/main.py +++ b/main.py @@ -28,6 +28,7 @@ device = 'cuda' if torch.cuda.is_available() else 'cpu' best_acc = 0 # best test accuracy start_epoch = 0 # start from epoch 0 or last checkpoint epoch +pos_best_acc = 0 # best accuracy for one shot pruned model # Data print('==> Preparing data..') @@ -85,6 +86,7 @@ checkpoint = torch.load('./checkpoint/ckpt.pth') net.load_state_dict(checkpoint['net']) best_acc = checkpoint['acc'] + pos_best_acc = checkpoint['pos_best_acc'] start_epoch = checkpoint['epoch'] criterion = nn.CrossEntropyLoss() @@ -98,10 +100,10 @@ checkpoint = torch.load('./checkpoint/ckpt.pth') net.load_state_dict(checkpoint['net']) best_acc = checkpoint['acc'] + pos_best_acc = checkpoint['pos_best_acc'] start_epoch = checkpoint['epoch'] - # Training def train(epoch): 
print('\nEpoch: %d' % epoch) @@ -128,6 +130,7 @@ def train(epoch): def test(epoch): global best_acc + global pos_best_acc net.eval() test_loss = 0 correct = 0 @@ -147,21 +150,37 @@ def test(epoch): % (test_loss / (batch_idx + 1), 100. * correct / total, correct, total)) # Save checkpoint. - acc = 100. * correct / total - if acc > best_acc: - print('Saving..') - state = { - 'net': net.state_dict(), - 'acc': acc, - 'epoch': epoch, - } - if not os.path.isdir('checkpoint'): - os.mkdir('checkpoint') - if args.prune_one_shot: + if args.prune_one_shot: + acc = 100. * correct / total + if acc > pos_best_acc: + print('Saving..') + state = { + 'net': net.state_dict(), + 'acc': acc, + 'epoch': epoch, + 'pos_best_acc': pos_best_acc + } + if not os.path.isdir('checkpoint'): + os.mkdir('checkpoint') torch.save(state, './checkpoint/ckpt_prune_one_shot.pth') - else: + pos_best_acc = acc + + else: + acc = 100. * correct / total + if acc > best_acc: + print('Saving..') + state = { + 'net': net.state_dict(), + 'acc': acc, + 'epoch': epoch, + 'pos_best_acc': pos_best_acc + } + if not os.path.isdir('checkpoint'): + os.mkdir('checkpoint') + torch.save(state, './checkpoint/ckpt.pth') - best_acc = acc + best_acc = acc + if __name__ == '__main__': From 44a56611d95de1f0c4c9f3eb38b02742256f04d0 Mon Sep 17 00:00:00 2001 From: sumanurawat umass Date: Fri, 19 Nov 2021 15:42:42 -0500 Subject: [PATCH 06/16] correction --- main.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/main.py b/main.py index 029fcbdde..fa8fd0d4d 100644 --- a/main.py +++ b/main.py @@ -86,7 +86,7 @@ checkpoint = torch.load('./checkpoint/ckpt.pth') net.load_state_dict(checkpoint['net']) best_acc = checkpoint['acc'] - pos_best_acc = checkpoint['pos_best_acc'] + #pos_best_acc = checkpoint['pos_best_acc'] start_epoch = checkpoint['epoch'] criterion = nn.CrossEntropyLoss() @@ -100,7 +100,7 @@ checkpoint = torch.load('./checkpoint/ckpt.pth') net.load_state_dict(checkpoint['net']) best_acc = 
checkpoint['acc'] - pos_best_acc = checkpoint['pos_best_acc'] + #pos_best_acc = checkpoint['pos_best_acc'] start_epoch = checkpoint['epoch'] @@ -164,7 +164,7 @@ def test(epoch): os.mkdir('checkpoint') torch.save(state, './checkpoint/ckpt_prune_one_shot.pth') pos_best_acc = acc - + else: acc = 100. * correct / total if acc > best_acc: From a7043058d61ba964b25a6e6f9b01342d85bb068c Mon Sep 17 00:00:00 2001 From: sumanurawat umass Date: Fri, 19 Nov 2021 17:10:00 -0500 Subject: [PATCH 07/16] prune remove --- main.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/main.py b/main.py index fa8fd0d4d..b1cb9e707 100644 --- a/main.py +++ b/main.py @@ -86,7 +86,7 @@ checkpoint = torch.load('./checkpoint/ckpt.pth') net.load_state_dict(checkpoint['net']) best_acc = checkpoint['acc'] - #pos_best_acc = checkpoint['pos_best_acc'] + # pos_best_acc = checkpoint['pos_best_acc'] start_epoch = checkpoint['epoch'] criterion = nn.CrossEntropyLoss() @@ -100,7 +100,7 @@ checkpoint = torch.load('./checkpoint/ckpt.pth') net.load_state_dict(checkpoint['net']) best_acc = checkpoint['acc'] - #pos_best_acc = checkpoint['pos_best_acc'] + # pos_best_acc = checkpoint['pos_best_acc'] start_epoch = checkpoint['epoch'] @@ -150,7 +150,11 @@ def test(epoch): % (test_loss / (batch_idx + 1), 100. * correct / total, correct, total)) # Save checkpoint. - if args.prune_one_shot: + # if args.prune_one_shot: + if epoch == 5: + prune_params = get_prune_params(net) + for prune_param in prune_params: + prune.remove(prune_param, 'weight') acc = 100. 
* correct / total if acc > pos_best_acc: print('Saving..') @@ -158,7 +162,7 @@ def test(epoch): 'net': net.state_dict(), 'acc': acc, 'epoch': epoch, - 'pos_best_acc': pos_best_acc + 'pos_best_acc': pos_best_acc, } if not os.path.isdir('checkpoint'): os.mkdir('checkpoint') From fb573a82c27d63086e46e31df3fa07dd92b0a14e Mon Sep 17 00:00:00 2001 From: sumanurawat umass Date: Fri, 19 Nov 2021 17:48:48 -0500 Subject: [PATCH 08/16] permanent pruning --- main.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/main.py b/main.py index b1cb9e707..e2e7a13d8 100644 --- a/main.py +++ b/main.py @@ -152,9 +152,12 @@ def test(epoch): # Save checkpoint. # if args.prune_one_shot: if epoch == 5: + + # make pruning permanent prune_params = get_prune_params(net) for prune_param in prune_params: prune.remove(prune_param, 'weight') + acc = 100. * correct / total if acc > pos_best_acc: print('Saving..') From 3d2cac823c8375af52758e06d9c38e4dc613f56c Mon Sep 17 00:00:00 2001 From: sumanurawat umass Date: Fri, 19 Nov 2021 17:58:40 -0500 Subject: [PATCH 09/16] epoch --- main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/main.py b/main.py index e2e7a13d8..6e8cf54b0 100644 --- a/main.py +++ b/main.py @@ -151,7 +151,7 @@ def test(epoch): # Save checkpoint. # if args.prune_one_shot: - if epoch == 5: + if epoch == 202: # make pruning permanent prune_params = get_prune_params(net) From 7fbc8bd43aef5b1f4236f25715d1f09811c2c514 Mon Sep 17 00:00:00 2001 From: sumanurawat umass Date: Fri, 19 Nov 2021 18:07:55 -0500 Subject: [PATCH 10/16] 0th param --- main.py | 4 ++-- prune_params.py | 22 +++++++++++++++++++++- 2 files changed, 23 insertions(+), 3 deletions(-) diff --git a/main.py b/main.py index 6e8cf54b0..8c785360c 100644 --- a/main.py +++ b/main.py @@ -151,12 +151,12 @@ def test(epoch): # Save checkpoint. 
# if args.prune_one_shot: - if epoch == 202: + if epoch == 201: # make pruning permanent prune_params = get_prune_params(net) for prune_param in prune_params: - prune.remove(prune_param, 'weight') + prune.remove(prune_param[0], 'weight') acc = 100. * correct / total if acc > pos_best_acc: diff --git a/prune_params.py b/prune_params.py index 5ff2fe5c5..2831b1548 100644 --- a/prune_params.py +++ b/prune_params.py @@ -46,4 +46,24 @@ def get_prune_params(net): (net.module.layer4[1].bn2, 'weight'), ) - return parameters_to_prune \ No newline at end of file + return parameters_to_prune + +def print_sparsity(model): + print( + "Global sparsity: {:.2f}%".format( + 100. * float( + torch.sum(model.conv1.weight == 0) + + torch.sum(model.conv2.weight == 0) + + torch.sum(model.fc1.weight == 0) + + torch.sum(model.fc2.weight == 0) + + torch.sum(model.fc3.weight == 0) + ) + / float( + model.conv1.weight.nelement() + + model.conv2.weight.nelement() + + model.fc1.weight.nelement() + + model.fc2.weight.nelement() + + model.fc3.weight.nelement() + ) + ) + ) \ No newline at end of file From 606a05855d289549eb5bb04154275471a7293273 Mon Sep 17 00:00:00 2001 From: sumanurawat umass Date: Fri, 19 Nov 2021 18:13:29 -0500 Subject: [PATCH 11/16] sparsity function --- main.py | 2 +- prune_params.py | 29 +++++++++++------------------ 2 files changed, 12 insertions(+), 19 deletions(-) diff --git a/main.py b/main.py index 8c785360c..5dd500ae7 100644 --- a/main.py +++ b/main.py @@ -151,7 +151,7 @@ def test(epoch): # Save checkpoint. 
# if args.prune_one_shot: - if epoch == 201: + if epoch == 199: # make pruning permanent prune_params = get_prune_params(net) diff --git a/prune_params.py b/prune_params.py index 2831b1548..94df643f4 100644 --- a/prune_params.py +++ b/prune_params.py @@ -1,3 +1,6 @@ +import torch + + def get_prune_params(net): parameters_to_prune = ( (net.module.conv1, 'weight'), @@ -48,22 +51,12 @@ def get_prune_params(net): ) return parameters_to_prune + def print_sparsity(model): - print( - "Global sparsity: {:.2f}%".format( - 100. * float( - torch.sum(model.conv1.weight == 0) - + torch.sum(model.conv2.weight == 0) - + torch.sum(model.fc1.weight == 0) - + torch.sum(model.fc2.weight == 0) - + torch.sum(model.fc3.weight == 0) - ) - / float( - model.conv1.weight.nelement() - + model.conv2.weight.nelement() - + model.fc1.weight.nelement() - + model.fc2.weight.nelement() - + model.fc3.weight.nelement() - ) - ) - ) \ No newline at end of file + params = get_prune_params(model) + zero_weights = 0 + total_weigts = 0 + for param in params: + zero_weights += torch.sum(param[0].weight == 0) + total_weigts += param[0].weight.nelement() + print("Global sparsity: {:.2f}%".format(100. * zero_weights / total_weigts)) From 02117b336042f6536f864ea0ade2e3361e6c8ec6 Mon Sep 17 00:00:00 2001 From: sumanurawat umass Date: Fri, 19 Nov 2021 18:20:07 -0500 Subject: [PATCH 12/16] print sparsity --- main.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/main.py b/main.py index 5dd500ae7..aa8713295 100644 --- a/main.py +++ b/main.py @@ -9,7 +9,7 @@ import torchvision.transforms as transforms import torch.nn.utils.prune as prune -from prune_params import get_prune_params +from prune_params import get_prune_params, print_sparsity import os import argparse @@ -150,9 +150,7 @@ def test(epoch): % (test_loss / (batch_idx + 1), 100. * correct / total, correct, total)) # Save checkpoint. 
- # if args.prune_one_shot: - if epoch == 199: - + if args.prune_one_shot: # make pruning permanent prune_params = get_prune_params(net) for prune_param in prune_params: @@ -171,6 +169,7 @@ def test(epoch): os.mkdir('checkpoint') torch.save(state, './checkpoint/ckpt_prune_one_shot.pth') pos_best_acc = acc + print_sparsity(net) else: acc = 100. * correct / total @@ -192,6 +191,8 @@ def test(epoch): if __name__ == '__main__': if args.prune_one_shot: + print('spartity at the start') + print_sparsity(net) print('one shot pruning in main') parameters_to_prune = get_prune_params(net) prune.global_unstructured(parameters_to_prune, pruning_method=prune.L1Unstructured, importance_scores=None, From 4ae293d52b49c22c495b75a0395423dc8d0c1fdd Mon Sep 17 00:00:00 2001 From: sumanurawat umass Date: Fri, 19 Nov 2021 18:37:41 -0500 Subject: [PATCH 13/16] save one shot pruned model --- main.py | 43 +++++++++++++++++++++++-------------------- 1 file changed, 23 insertions(+), 20 deletions(-) diff --git a/main.py b/main.py index aa8713295..a13a02519 100644 --- a/main.py +++ b/main.py @@ -23,6 +23,11 @@ parser.add_argument('--prune_one_shot', '-pos', action='store_true', help='resume from checkpoint with one shot pruning') + +parser.add_argument('--prune_amount', '-pr', action='store_true', + help='resume from checkpoint with one shot pruning') +parser.add_argument('-pa', default=0, type=float, help='pruning amount') + args = parser.parse_args() device = 'cuda' if torch.cuda.is_available() else 'cpu' @@ -54,26 +59,15 @@ testloader = torch.utils.data.DataLoader( testset, batch_size=256, shuffle=False, num_workers=2) +model_save_path = './checkpoint/ckpt.pth' +prune_amount = 0 + classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck') # Model print('==> Building model..') -# net = VGG('VGG19') net = ResNet18() -# net = PreActResNet18() -# net = GoogLeNet() -# net = DenseNet121() -# net = ResNeXt29_2x64d() -# net = MobileNet() -# net = MobileNetV2() -# net 
= DPN92() -# net = ShuffleNetG2() -# net = SENet18() -# net = ShuffleNetV2(1) -# net = EfficientNetB0() -# net = RegNetX_200MF() -# net = SimpleDLA() net = net.to(device) if device == 'cuda': net = torch.nn.DataParallel(net) @@ -151,13 +145,14 @@ def test(epoch): # Save checkpoint. if args.prune_one_shot: - # make pruning permanent - prune_params = get_prune_params(net) - for prune_param in prune_params: - prune.remove(prune_param[0], 'weight') acc = 100. * correct / total if acc > pos_best_acc: + # Remove pruning before saving + prune_params = get_prune_params(net) + for prune_param in prune_params: + prune.remove(prune_param[0], 'weight') + print('Saving..') state = { 'net': net.state_dict(), @@ -167,10 +162,15 @@ def test(epoch): } if not os.path.isdir('checkpoint'): os.mkdir('checkpoint') - torch.save(state, './checkpoint/ckpt_prune_one_shot.pth') + torch.save(state, './checkpoint/ckpt_prune_one_shot_' + str(int(100 * prune_amount)) + '.pth') pos_best_acc = acc print_sparsity(net) + # apply pruning masks back before continuing (this will be the same since model is already pruned) + prune.global_unstructured(get_prune_params(net), pruning_method=prune.L1Unstructured, + importance_scores=None, amount=prune_amount) + print_sparsity(net) + else: acc = 100. 
* correct / total if acc > best_acc: @@ -193,10 +193,13 @@ def test(epoch): if args.prune_one_shot: print('spartity at the start') print_sparsity(net) + + prune_amount = args.pa + print('one shot pruning in main') parameters_to_prune = get_prune_params(net) prune.global_unstructured(parameters_to_prune, pruning_method=prune.L1Unstructured, importance_scores=None, - amount=0.9) + amount=prune_amount) for epoch in range(start_epoch, start_epoch + 200): train(epoch) From b68b140eda5185536c64ca00a8b58fa5b202c230 Mon Sep 17 00:00:00 2001 From: sumanurawat umass Date: Fri, 19 Nov 2021 19:10:46 -0500 Subject: [PATCH 14/16] iterative pruning --- main.py | 72 +++++++++++++++++++++++++++++++++++++++++++++++---------- 1 file changed, 60 insertions(+), 12 deletions(-) diff --git a/main.py b/main.py index a13a02519..6233216f8 100644 --- a/main.py +++ b/main.py @@ -24,6 +24,9 @@ parser.add_argument('--prune_one_shot', '-pos', action='store_true', help='resume from checkpoint with one shot pruning') +parser.add_argument('--prune_iterative', '-pit', action='store_true', + help='resume from checkpoint with iterative pruning') + parser.add_argument('--prune_amount', '-pr', action='store_true', help='resume from checkpoint with one shot pruning') parser.add_argument('-pa', default=0, type=float, help='pruning amount') @@ -144,7 +147,33 @@ def test(epoch): % (test_loss / (batch_idx + 1), 100. * correct / total, correct, total)) # Save checkpoint. - if args.prune_one_shot: + if args.prune_iterative: + acc = 100. 
* correct / total + if acc > pos_best_acc: + # Remove pruning before saving + prune_params = get_prune_params(net) + for prune_param in prune_params: + prune.remove(prune_param[0], 'weight') + + print('Saving..') + state = { + 'net': net.state_dict(), + 'acc': acc, + 'epoch': epoch, + 'pos_best_acc': pos_best_acc, + } + if not os.path.isdir('checkpoint'): + os.mkdir('checkpoint') + torch.save(state, './checkpoint/ckpt_prune_iterative_' + str(int(100 * prune_amount)) + '.pth') + pos_best_acc = acc + print_sparsity(net) + + # apply pruning masks back before continuing (this will be the same since model is already pruned) + prune.global_unstructured(get_prune_params(net), pruning_method=prune.L1Unstructured, + importance_scores=None, amount=prune_amount) + print_sparsity(net) + + elif args.prune_one_shot: acc = 100. * correct / total if acc > pos_best_acc: @@ -189,19 +218,38 @@ def test(epoch): if __name__ == '__main__': - - if args.prune_one_shot: - print('spartity at the start') - print_sparsity(net) - + # num_epoch_train, num_epoch_one_shot, num_epoch_iterative = (200, 100, 25) + num_epoch_train, num_epoch_one_shot, num_epoch_iterative = (4, 4, 4) + # Iterative pruning + if args.prune_iterative: + total_prune_amount = args.pa + + num_pruning_iter = 4 + # increase the pruning amount over num_pruning_iter iterations + for prune_x in range(num_pruning_iter): + prune_amount = (prune_x + 1) * total_prune_amount / num_pruning_iter + parameters_to_prune = get_prune_params(net) + prune.global_unstructured(parameters_to_prune, pruning_method=prune.L1Unstructured, importance_scores=None, + amount=prune_amount) + for epoch in range(start_epoch, start_epoch + num_epoch_iterative): + train(epoch) + test(epoch) + scheduler.step() + + # One shot pruning and retraining + elif args.prune_one_shot: prune_amount = args.pa - - print('one shot pruning in main') parameters_to_prune = get_prune_params(net) prune.global_unstructured(parameters_to_prune, 
pruning_method=prune.L1Unstructured, importance_scores=None, amount=prune_amount) + for epoch in range(start_epoch, start_epoch + num_epoch_one_shot): + train(epoch) + test(epoch) + scheduler.step() - for epoch in range(start_epoch, start_epoch + 200): - train(epoch) - test(epoch) - scheduler.step() + # No pruning + else: + for epoch in range(start_epoch, start_epoch + num_epoch_train): + train(epoch) + test(epoch) + scheduler.step() From 6ade5c700e58b286e82bcfb65403f9f1b1414408 Mon Sep 17 00:00:00 2001 From: sumanurawat umass Date: Fri, 19 Nov 2021 19:37:19 -0500 Subject: [PATCH 15/16] cleaning --- main.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/main.py b/main.py index 6233216f8..96fb9c0fb 100644 --- a/main.py +++ b/main.py @@ -171,7 +171,6 @@ def test(epoch): # apply pruning masks back before continuing (this will be the same since model is already pruned) prune.global_unstructured(get_prune_params(net), pruning_method=prune.L1Unstructured, importance_scores=None, amount=prune_amount) - print_sparsity(net) elif args.prune_one_shot: @@ -198,7 +197,6 @@ def test(epoch): # apply pruning masks back before continuing (this will be the same since model is already pruned) prune.global_unstructured(get_prune_params(net), pruning_method=prune.L1Unstructured, importance_scores=None, amount=prune_amount) - print_sparsity(net) else: acc = 100. 
* correct / total From da803beec7a042ff8300137fb96130b80c3142bf Mon Sep 17 00:00:00 2001 From: sumanurawat umass Date: Fri, 19 Nov 2021 22:27:33 -0500 Subject: [PATCH 16/16] FINAL EPOCHS --- main.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/main.py b/main.py index 96fb9c0fb..e23a4562d 100644 --- a/main.py +++ b/main.py @@ -216,8 +216,8 @@ def test(epoch): if __name__ == '__main__': - # num_epoch_train, num_epoch_one_shot, num_epoch_iterative = (200, 100, 25) - num_epoch_train, num_epoch_one_shot, num_epoch_iterative = (4, 4, 4) + num_epoch_train, num_epoch_one_shot, num_epoch_iterative = (200, 100, 25) + # num_epoch_train, num_epoch_one_shot, num_epoch_iterative = (4, 4, 4) # Iterative pruning if args.prune_iterative: total_prune_amount = args.pa