From ac7f7c348a2c5eec3f473da4b22619a79efb963d Mon Sep 17 00:00:00 2001 From: Khaled Saab Date: Tue, 30 Oct 2018 12:04:27 -0700 Subject: [PATCH] added CIFAR tutorials --- tutorials/CIFAR_test.py | 135 ++++++++++++++++++++++++++++++++++++++++ tutorials/resnet.py | 117 ++++++++++++++++++++++++++++++++++ 2 files changed, 252 insertions(+) create mode 100644 tutorials/CIFAR_test.py create mode 100644 tutorials/resnet.py diff --git a/tutorials/CIFAR_test.py b/tutorials/CIFAR_test.py new file mode 100644 index 00000000..0156e58e --- /dev/null +++ b/tutorials/CIFAR_test.py @@ -0,0 +1,135 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.data.dataloader as dataloader +import torch.optim as optim + +from torch.autograd import Variable +from torchvision import transforms +from torchvision.datasets import CIFAR10 + +from resnet import ResNet18 +from metal import EndModel +from metal.utils import convert_labels +from torch.utils.data import Dataset + +import argparse + + +SEED = 1 + +# CUDA? +cuda = torch.cuda.is_available() + +# For reproducibility +torch.manual_seed(SEED) + +if cuda: + torch.cuda.manual_seed(SEED) + + +parser = argparse.ArgumentParser(description='PyTorch Training') + +parser.add_argument('--epochs', default=10, type=int, + help='number of total epochs to run') + +parser.add_argument('--start-epoch', default=0, type=int, + help='manual epoch number (useful on restarts)') +parser.add_argument('-b', '--batch-size', default=10, type=int, + help='mini-batch size (default: 1)') + +parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, + help='initial learning rate') + +parser.add_argument('--momentum', default=0.9, type=float, help='momentum') +parser.add_argument('--weight-decay', '--wd', default=0, type=float, + help='weight decay (default: 1e-4)') +parser.add_argument('--print-freq', '-p', default=10, type=int, + help='print frequency (default: 10)') + + +# The following identity module is to essentially replace the last FC layer +# in the resnet model by the FC in MeTal + +class IdentityModule(nn.Module): + """A default identity input module that simply passes the input through.""" + + def __init__(self): + super().__init__() + + def reset_parameters(self): + pass + + def forward(self, x): + return x + +# Here we create a dataloader that transforms CIFAR labels from 0-9, to 1-10, +# We do this because MeTal treats a 0 label as abstain +class MetalDataset(Dataset): + """A dataset that group each item in X with it label from Y + + Args: + X: an n-dim iterable of items + Y: a torch.Tensor of labels + This may be hard labels [n] or soft labels [n, k] + """ + + def __init__(self, dataset): + self.dataset = dataset + #Y = convert_labels(Y,'onezero','categorical') + #self.Y = Y + #assert len(X) == len(Y) + + def __getitem__(self, index): + x,y = self.dataset[index] + # convert to metal form + y += 1 + return tuple([x,y]) + + def __len__(self): + return len(self.dataset) + + + +def train_model(): + + global args + args = parser.parse_args() + + transform_train = transforms.Compose([ + transforms.RandomCrop(32, padding=4), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), + ]) + + transform_test = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), + ]) + + trainset = CIFAR10(root='./data', train=True, download=True, transform=transform_train) + train_loader = dataloader.DataLoader(MetalDataset(trainset), batch_size=128, shuffle=True, num_workers=2) + + + testset = CIFAR10(root='./data', train=False, download=True, transform=transform_test) + test_loader = dataloader.DataLoader(MetalDataset(testset), batch_size=100, shuffle=False, num_workers=2) + + classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck') + + + model = ResNet18() + model.linear = IdentityModule() + + end_model = EndModel([512,10], input_module=model, seed=123, use_cuda=True, relu=False) + + end_model.train_model(train_data=train_loader, dev_data=test_loader, l2=args.weight_decay, lr=args.lr, n_epochs=args.epochs, print_every=1, validation_metric='accuracy') + + end_model.score(test_loader, metric=['accuracy', 'precision', 'recall', 'f1']) + + + +if __name__ == "__main__": + train_model() + \ No newline at end of file diff --git a/tutorials/resnet.py b/tutorials/resnet.py new file mode 100644 index 00000000..22980121 --- /dev/null +++ b/tutorials/resnet.py @@ -0,0 +1,117 @@ +'''ResNet in PyTorch. +For Pre-activation ResNet, see 'preact_resnet.py'. +Reference: +[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun + Deep Residual Learning for Image Recognition. arXiv:1512.03385 +''' +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, in_planes, planes, stride=1): + super(BasicBlock, self).__init__() + self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + + self.shortcut = nn.Sequential() + if stride != 1 or in_planes != self.expansion*planes: + self.shortcut = nn.Sequential( + nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(self.expansion*planes) + ) + + def forward(self, x): + out = F.relu(self.bn1(self.conv1(x))) + out = self.bn2(self.conv2(out)) + out += self.shortcut(x) + out = F.relu(out) + return out + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, in_planes, planes, stride=1): + super(Bottleneck, self).__init__() + self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(self.expansion*planes) + + self.shortcut = nn.Sequential() + if stride != 1 or in_planes != self.expansion*planes: + self.shortcut = nn.Sequential( + nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(self.expansion*planes) + ) + + def forward(self, x): + out = F.relu(self.bn1(self.conv1(x))) + out = F.relu(self.bn2(self.conv2(out))) + out = self.bn3(self.conv3(out)) + out += self.shortcut(x) + out = F.relu(out) + return out + + +class ResNet(nn.Module): + def __init__(self, block, num_blocks, num_classes=10): + super(ResNet, self).__init__() + self.in_planes = 64 + + self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(64) + self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) + self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) + self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) + self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) + self.linear = nn.Linear(512*block.expansion, num_classes) + + def _make_layer(self, block, planes, num_blocks, stride): + strides = [stride] + [1]*(num_blocks-1) + layers = [] + for stride in strides: + layers.append(block(self.in_planes, planes, stride)) + self.in_planes = planes * block.expansion + return nn.Sequential(*layers) + + def forward(self, x): + out = F.relu(self.bn1(self.conv1(x))) + out = self.layer1(out) + out = self.layer2(out) + out = self.layer3(out) + out = self.layer4(out) + out = F.avg_pool2d(out, 4) + out = out.view(out.size(0), -1) + out = self.linear(out) + return out + + +def ResNet18(): + return ResNet(BasicBlock, [2,2,2,2]) + +def ResNet34(): + return ResNet(BasicBlock, [3,4,6,3]) + +def ResNet50(): + return ResNet(Bottleneck, [3,4,6,3]) + +def ResNet101(): + return ResNet(Bottleneck, [3,4,23,3]) + +def ResNet152(): + return ResNet(Bottleneck, [3,8,36,3]) + + +def test(): + net = ResNet18() + y = net(torch.randn(1,3,32,32)) + print(y.size()) \ No newline at end of file