diff --git a/main.py b/main.py
index 05ca1eb90..9cf82284e 100644
--- a/main.py
+++ b/main.py
@@ -8,17 +8,25 @@
 import torchvision
 import torchvision.transforms as transforms
 
+import numpy as np
+
 import os
 import argparse
 
 from models import *
-from utils import progress_bar
+from utils import progress_bar, rand_bbox
 
 
 parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training')
 parser.add_argument('--lr', default=0.1, type=float, help='learning rate')
 parser.add_argument('--resume', '-r', action='store_true',
                     help='resume from checkpoint')
+parser.add_argument('--cosine', action='store_true',
+                    help='use cosine annealing for lr')
+parser.add_argument('--beta', default=1.0, type=float,
+                    help='hyperparameter beta for the CutMix Beta distribution')
+parser.add_argument('--cutmix_prob', default=0.5, type=float,
+                    help='probability of applying CutMix to a batch')
 args = parser.parse_args()
 
 device = 'cuda' if torch.cuda.is_available() else 'cpu'
@@ -86,8 +94,10 @@
 criterion = nn.CrossEntropyLoss()
 optimizer = optim.SGD(net.parameters(), lr=args.lr,
                       momentum=0.9, weight_decay=5e-4)
-scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)
-
+if args.cosine:
+    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)
+else:
+    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[50, 100])
 
 # Training
 def train(epoch):
@@ -98,9 +108,28 @@ def train(epoch):
     total = 0
     for batch_idx, (inputs, targets) in enumerate(trainloader):
         inputs, targets = inputs.to(device), targets.to(device)
+        r = np.random.rand()
         optimizer.zero_grad()
-        outputs = net(inputs)
-        loss = criterion(outputs, targets)
+
+        if args.beta > 0 and r < args.cutmix_prob:
+            # sample lambda from the Beta distribution
+            lam = np.random.beta(args.beta, args.beta)
+            # get index of image to mix with current image
+            rand_index = torch.randperm(inputs.size()[0]).to(device)
+            target_a = targets
+            target_b = targets[rand_index]
+            # sample bounding box coordinates of binary mask
+            bbx1, bby1, bbx2, bby2 = rand_bbox(inputs.size(), lam)
+            inputs[:, :, bbx1:bbx2, bby1:bby2] = inputs[rand_index, :, bbx1:bbx2, bby1:bby2]
+            # adjust lambda to exactly match the pixel ratio of the (clipped) box
+            lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (inputs.size()[-1] * inputs.size()[-2]))
+            # compute output and mix the losses of the two label sets
+            outputs = net(inputs)
+            loss = criterion(outputs, target_a) * lam + criterion(outputs, target_b) * (1. - lam)
+        else:
+            outputs = net(inputs)
+            loss = criterion(outputs, targets)
+
         loss.backward()
         optimizer.step()
 
diff --git a/utils.py b/utils.py
index 4c9b3f90c..cb317ded0 100644
--- a/utils.py
+++ b/utils.py
@@ -7,6 +7,7 @@
 import sys
 import time
 import math
+import numpy as np
 
 import torch.nn as nn
 import torch.nn.init as init
@@ -122,3 +123,22 @@ def format_time(seconds):
     if f == '':
         f = '0ms'
     return f
+
+def rand_bbox(size, lam):
+    '''Sample bounding box coordinates of the binary mask for CutMix.'''
+    W = size[2]
+    H = size[3]
+    cut_rat = np.sqrt(1. - lam)
+    cut_w = int(W * cut_rat)
+    cut_h = int(H * cut_rat)
+
+    # uniformly sampled box center, clipped to the image borders
+    cx = np.random.randint(W)
+    cy = np.random.randint(H)
+
+    bbx1 = np.clip(cx - cut_w // 2, 0, W)
+    bby1 = np.clip(cy - cut_h // 2, 0, H)
+    bbx2 = np.clip(cx + cut_w // 2, 0, W)
+    bby2 = np.clip(cy + cut_h // 2, 0, H)
+
+    return bbx1, bby1, bbx2, bby2
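
Note (not part of the patch): the CutMix step introduced in main.py can be exercised in isolation. Below is a minimal, self-contained sketch that duplicates the rand_bbox helper added to utils.py and applies the same mixing and lambda adjustment to a dummy batch. The file name, batch shape, class count, and beta value are illustrative assumptions, not values taken from the training script.

# cutmix_sanity_check.py -- standalone sketch (assumed file name, not part of the patch)
import numpy as np
import torch

def rand_bbox(size, lam):
    '''Sample bounding box coordinates of the binary mask (mirrors utils.rand_bbox).'''
    W = size[2]
    H = size[3]
    cut_rat = np.sqrt(1. - lam)   # box side ratio, chosen so the box area is ~ (1 - lam)
    cut_w = int(W * cut_rat)
    cut_h = int(H * cut_rat)

    # uniformly sampled box center, box clipped to the image borders
    cx = np.random.randint(W)
    cy = np.random.randint(H)

    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)

    return bbx1, bby1, bbx2, bby2

if __name__ == '__main__':
    inputs = torch.randn(8, 3, 32, 32)        # dummy CIFAR-sized batch (assumed shape)
    targets = torch.randint(0, 10, (8,))      # dummy labels (assumed 10 classes)

    lam = np.random.beta(1.0, 1.0)            # beta = 1.0, the flag's default
    rand_index = torch.randperm(inputs.size(0))
    target_a, target_b = targets, targets[rand_index]

    bbx1, bby1, bbx2, bby2 = rand_bbox(inputs.size(), lam)
    # paste the box from the shuffled batch into the current batch
    inputs[:, :, bbx1:bbx2, bby1:bby2] = inputs[rand_index, :, bbx1:bbx2, bby1:bby2]
    # recompute lambda from the actual (clipped) box area
    lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (inputs.size(-1) * inputs.size(-2)))

    print('adjusted lambda:', float(lam))
    print('label pairs (a, b):', list(zip(target_a.tolist(), target_b.tolist())))

With the flags added above, a run that enables both CutMix and cosine annealing would look something like: python main.py --cosine --beta 1.0 --cutmix_prob 0.5 (beta and cutmix_prob shown at their defaults).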