diff --git a/README.md b/README.md
index 3e75cff..8ef5d7b 100644
--- a/README.md
+++ b/README.md
@@ -44,8 +44,9 @@ Bayesian-Torch is designed to be flexible and enables seamless extension of dete
 
 **Key features:**
-* [dnn_to_bnn()](https://github.com/IntelLabs/bayesian-torch/blob/main/bayesian_torch/models/dnn_to_bnn.py#L127): An API to convert deterministic deep neural network (dnn) model of any architecture to Bayesian deep neural network (bnn) model, simplifying the model definition i.e. drop-in replacements of Convolutional, Linear and LSTM layers to corresponding Bayesian layers. This will enable seamless conversion of existing topology of larger models to Bayesian deep neural network models for extending towards uncertainty-aware applications.
+* [dnn_to_bnn()](https://github.com/IntelLabs/bayesian-torch/blob/main/bayesian_torch/models/dnn_to_bnn.py#L127): Seamless conversion of a model to be uncertainty-aware. An API to convert deterministic deep neural network (dnn) model of any architecture to Bayesian deep neural network (bnn) model, simplifying the model definition i.e. drop-in replacements of Convolutional, Linear and LSTM layers to corresponding Bayesian layers. This will enable seamless conversion of existing topology of larger models to Bayesian deep neural network models for extending towards uncertainty-aware applications.
 * [MOPED](https://github.com/IntelLabs/bayesian-torch/blob/main/bayesian_torch/utils/util.py#L72): Specifying weight priors and variational posteriors in Bayesian neural networks with Empirical Bayes [[Krishnan et al. 2020](https://ojs.aaai.org/index.php/AAAI/article/view/5875)]
+* [Quantization](https://github.com/IntelLabs/bayesian-torch/tree/main/bayesian_torch/ao): Post Training Quantization of Bayesian deep neural network models with simple APIs [enable_prepare()](https://github.com/IntelLabs/bayesian-torch/blob/main/bayesian_torch/ao/quantization/quantize.py#L134) and [convert()](https://github.com/IntelLabs/bayesian-torch/blob/main/bayesian_torch/ao/quantization/quantize.py#L160)
 * [AvUC](https://github.com/IntelLabs/bayesian-torch/blob/main/bayesian_torch/utils/avuc_loss.py): Accuracy versus Uncertainty Calibration loss [[Krishnan and Tickoo 2020](https://proceedings.neurips.cc/paper/2020/file/d3d9446802a44259755d38e6d163e820-Paper.pdf)]
 
 
 ## Installing Bayesian-Torch
@@ -198,6 +199,13 @@ To evaluate deterministic ResNet on CIFAR10, run this command:
 sh scripts/test_deterministic_cifar.sh
 ```
 
+### Post Training Quantization (PTQ)
+
+To quantize Bayesian ResNet (convert to INT8) and evaluate on CIFAR10, run this command:
+```test
+sh scripts/quantize_bayesian_cifar.sh
+```
+
 ## Citing
 
 If you use this code, please cite as:
diff --git a/bayesian_torch/ao/quantization/quantize.py b/bayesian_torch/ao/quantization/quantize.py
index 06fa99f..e7f2b64 100644
--- a/bayesian_torch/ao/quantization/quantize.py
+++ b/bayesian_torch/ao/quantization/quantize.py
@@ -152,7 +152,7 @@ def prepare(model):
     qmodel.load_state_dict(model.state_dict())
     qmodel.eval()
     enable_prepare(qmodel)
-    qmodel.qconfig = torch.quantization.get_default_qconfig("fbgemm")
+    qmodel.qconfig = torch.quantization.get_default_qconfig("onednn")
     qmodel = torch.quantization.prepare(qmodel)
     return qmodel
 
@@ -160,4 +160,4 @@ def prepare(model):
 def convert(model):
     qmodel = torch.quantization.convert(model) # torch layers
     bnn_to_qbnn(qmodel) # bayesian layers
-    return qmodel
\ No newline at end of file
+    return qmodel
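For orientation, the two APIs touched above are meant to be used together: `prepare()` returns a model with the "onednn" qconfig set and observers inserted, a few calibration batches are run through it, and `convert()` then swaps in quantized torch layers followed by quantized Bayesian layers. A minimal sketch of that flow, assuming a Bayesian model and a calibration loader; the `quantize_bnn` helper, `bnn_model`, and `calib_loader` names are illustrative, not part of the repository:

```python
import torch
from bayesian_torch.ao.quantization.quantize import prepare, convert

def quantize_bnn(bnn_model, calib_loader):
    # Illustrative post-training quantization flow for a Bayesian model.
    bnn_model.eval()
    bnn_model.cpu()
    prepared = prepare(bnn_model)       # sets the "onednn" qconfig and inserts observers (see hunk above)
    with torch.no_grad():               # calibration: run a handful of representative batches
        for data, _ in calib_loader:
            prepared(data.cpu())
    return convert(prepared)            # quantizes torch layers, then the Bayesian layers
```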
diff --git a/bayesian_torch/examples/main_bayesian_cifar_dnn2bnn.py b/bayesian_torch/examples/main_bayesian_cifar_dnn2bnn.py
index 8305844..884b28d 100644
--- a/bayesian_torch/examples/main_bayesian_cifar_dnn2bnn.py
+++ b/bayesian_torch/examples/main_bayesian_cifar_dnn2bnn.py
@@ -2,7 +2,7 @@
 import os
 import shutil
 import time
-
+import random
 import torch
 import torch.nn as nn
 import torch.nn.parallel
@@ -10,6 +10,7 @@
 import torch.optim
 import torch.utils.data
 from torch.utils.tensorboard import SummaryWriter
+from torch.utils.data.sampler import SubsetRandomSampler
 import torchvision.transforms as transforms
 import torchvision.datasets as datasets
 
@@ -17,6 +18,9 @@
 import numpy as np
 
 from bayesian_torch.models.dnn_to_bnn import dnn_to_bnn, get_kl_loss
+from bayesian_torch.ao.quantization.quantize import enable_prepare, convert
+from bayesian_torch.models.bnn_to_qbnn import bnn_to_qbnn
+
 
 model_names = sorted(
     name for name in resnet.__dict__
@@ -59,6 +63,13 @@
     default="./checkpoint/bayesian",
     type=str,
 )
+parser.add_argument(
+    "--model-checkpoint",
+    dest="model_checkpoint",
+    help="Saved checkpoint for evaluating model",
+    default="",
+    type=str,
+)
 parser.add_argument(
     "--moped-init-model",
     dest="moped_init_model",
@@ -97,7 +108,7 @@
     type=int,
     default=10,
 )
-parser.add_argument("--mode", type=str, required=True, help="train | test")
+parser.add_argument("--mode", type=str, required=True, help="train | test | ptq | test_ptq")
 
 parser.add_argument(
     "--num_monte_carlo",
@@ -221,6 +232,25 @@
         pin_memory=True,
     )
 
+    calib_loader = torch.utils.data.DataLoader(
+        datasets.CIFAR10(
+            root="./data",
+            train=True,
+            transform=transforms.Compose(
+                [
+                    transforms.ToTensor(),
+                    normalize,
+                ]
+            ),
+            download=True,
+        ),
+        batch_size=args.batch_size,
+        sampler=SubsetRandomSampler(random.sample(range(1, 50000), 100)),
+        num_workers=args.workers,
+        pin_memory=True,
+    )
+
+
     if not os.path.exists(args.save_dir):
         os.makedirs(args.save_dir)
 
@@ -286,6 +316,57 @@
         model.load_state_dict(checkpoint["state_dict"])
         evaluate(args, model, val_loader)
 
+    elif args.mode == "ptq":
+        if len(args.model_checkpoint) > 0:
+            checkpoint_file = args.model_checkpoint
+        else:
+            print("please provide valid model-checkpoint")
+            return
+        checkpoint = torch.load(checkpoint_file, map_location=torch.device("cpu"))
+
+        '''
+        state_dict = checkpoint['state_dict']
+        new_state_dict = OrderedDict()
+        for k, v in state_dict.items():
+            name = k[7:]  # remove `module.`
+            new_state_dict[name] = v
+        print('load checkpoint...')
+        '''
+        model.load_state_dict(checkpoint['state_dict'])
+
+
+        # post-training quantization
+        model_int8 = quantize(model, calib_loader, args)
+        model_int8.eval()
+        model_int8.cpu()
+
+        for i, (data, target) in enumerate(calib_loader):
+            data = data.cpu()
+
+        with torch.no_grad():
+            traced_model = torch.jit.trace(model_int8, data)
+            traced_model = torch.jit.freeze(traced_model)
+
+        save_path = os.path.join(
+            args.save_dir,
+            'quantized_bayesian_{}_cifar.pth'.format(args.arch))
+        traced_model.save(save_path)
+        print('INT8 model checkpoint saved at ', save_path)
+        print('Evaluating quantized INT8 model....')
+        evaluate(args, traced_model, val_loader)
+
+    elif args.mode == 'test_ptq':
+        print('load model...')
+        if len(args.model_checkpoint) > 0:
+            checkpoint_file = args.model_checkpoint
+        else:
+            print("please provide valid quantized model checkpoint")
+            return
+        model_int8 = torch.jit.load(checkpoint_file)
+        model_int8.eval()
+        model_int8.cpu()
+        model_int8 = torch.jit.freeze(model_int8)
+        print('Evaluating the INT8 model....')
+        evaluate(args, model_int8, val_loader)
+
 
 def train(args, train_loader, model, criterion, optimizer, epoch, tb_writer=None):
     batch_time = AverageMeter()
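The `test_ptq` branch above amounts to reloading the traced INT8 artifact written by the `ptq` branch and evaluating it on CPU. A minimal standalone sketch of that reload path; the checkpoint path and input shape are illustrative, and predictive uncertainty would come from averaging multiple forward passes as in the float evaluation path:

```python
import torch

# Reload the TorchScript INT8 checkpoint written by the "ptq" mode above
# (illustrative path, matching the default save_dir and resnet20 arch).
model_int8 = torch.jit.load("./checkpoint/bayesian/quantized_bayesian_resnet20_cifar.pth")
model_int8.eval()
model_int8.cpu()
model_int8 = torch.jit.freeze(model_int8)

# One forward pass on a CIFAR-10 sized dummy batch.
dummy = torch.randn(1, 3, 32, 32)
with torch.no_grad():
    logits = model_int8(dummy)
print(logits.shape)
```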
@@ -482,6 +563,21 @@ def save_checkpoint(state, is_best, filename="checkpoint.pth.tar"):
     """
     torch.save(state, filename)
 
+def quantize(model, calib_loader, args, **kwargs):
+    model.eval()
+    model.cpu()
+    model.qconfig = torch.quantization.get_default_qconfig("onednn")
+    print('Preparing model for quantization....')
+    enable_prepare(model)
+    prepared_model = torch.quantization.prepare(model)
+    print('Calibrating...')
+    with torch.no_grad():
+        for batch_idx, (data, target) in enumerate(calib_loader):
+            data = data.cpu()
+            _ = prepared_model(data)
+    print('Calibration complete....')
+    quantized_model = convert(prepared_model)
+    return quantized_model
 
 class AverageMeter(object):
     """Computes and stores the average and current value"""
diff --git a/bayesian_torch/models/deterministic/resnet.py b/bayesian_torch/models/deterministic/resnet.py
index 62f89f5..868846f 100644
--- a/bayesian_torch/models/deterministic/resnet.py
+++ b/bayesian_torch/models/deterministic/resnet.py
@@ -43,6 +43,7 @@ def __init__(self, in_planes, planes, stride=1, option='A'):
                                padding=1,
                                bias=False)
         self.bn1 = nn.BatchNorm2d(planes)
+        self.relu1 = nn.ReLU(inplace=True)
         self.conv2 = nn.Conv2d(planes,
                                planes,
                                kernel_size=3,
@@ -50,12 +51,14 @@ def __init__(self, in_planes, planes, stride=1, option='A'):
                                padding=1,
                                bias=False)
         self.bn2 = nn.BatchNorm2d(planes)
-
+        self.skip_add = nn.quantized.FloatFunctional()
         self.shortcut = nn.Sequential()
+        self.relu2 = nn.ReLU(inplace=True)
+
         if stride != 1 or in_planes != planes:
             if option == 'A':
                 self.shortcut = LambdaLayer(lambda x: F.pad(
-                    x[:, :, ::2, ::2],
+                    x[:, :, ::2, ::2].contiguous(),
                     (0, 0, 0, 0, planes // 4, planes // 4), "constant", 0))
             elif option == 'B':
                 self.shortcut = nn.Sequential(
@@ -67,12 +70,18 @@ def __init__(self, in_planes, planes, stride=1, option='A'):
                     nn.BatchNorm2d(self.expansion * planes))
 
     def forward(self, x):
-        out = F.relu(self.bn1(self.conv1(x)))
-        out = self.bn2(self.conv2(out))
-        out += self.shortcut(x)
-        out = F.relu(out)
+        identity = self.shortcut(x)
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu1(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+        out = self.skip_add.add(out, identity)
+        #out += self.shortcut(x)
+        out = self.relu2(out)
         return out
-
+
 
 
 class ResNet(nn.Module):
@@ -86,6 +95,7 @@ def __init__(self, block, num_blocks, num_classes=10):
                               padding=1,
                               bias=False)
         self.bn1 = nn.BatchNorm2d(16)
+        self.relu1 = nn.ReLU(inplace=True)
         self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1)
         self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2)
         self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2)
@@ -103,7 +113,9 @@ def _make_layer(self, block, planes, num_blocks, stride):
         return nn.Sequential(*layers)
 
     def forward(self, x):
-        out = F.relu(self.bn1(self.conv1(x)))
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu1(out)
         out = self.layer1(out)
         out = self.layer2(out)
         out = self.layer3(out)
@@ -112,7 +124,6 @@ def forward(self, x):
         out = self.linear(out)
         return out
 
 
-
 def resnet20():
     return ResNet(BasicBlock, [3, 3, 3])
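The resnet.py changes above follow the standard recipe for making an eager-mode model quantizable: functional `F.relu` calls become `nn.ReLU` modules, and the in-place residual addition is routed through `nn.quantized.FloatFunctional` so that observers can be attached to the skip connection and it can be replaced by a quantized add after conversion (the added `.contiguous()` on the strided shortcut slice presumably serves a similar purpose, keeping the traced graph on contiguous tensors). A toy block showing just that pattern; `QuantFriendlyBlock` is illustrative, not the repository's `BasicBlock`:

```python
import torch
import torch.nn as nn

class QuantFriendlyBlock(nn.Module):
    """Minimal residual block using the quantization-friendly pattern from the diff."""

    def __init__(self, channels):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(channels)
        self.relu = nn.ReLU(inplace=True)                # a module, so observers/fusion can target it
        self.skip_add = nn.quantized.FloatFunctional()   # quantizable stand-in for `out += x`

    def forward(self, x):
        out = self.relu(self.bn(self.conv(x)))
        # In float mode this is just torch.add; after torch.quantization.convert()
        # it becomes a quantized add with its own scale and zero point.
        return self.skip_add.add(out, x)

block = QuantFriendlyBlock(8)
print(block(torch.randn(1, 8, 16, 16)).shape)  # torch.Size([1, 8, 16, 16])
```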