Skip to content
This repository has been archived by the owner on Dec 3, 2024. It is now read-only.

Commit

Permalink
include example for performing post training quantization of Bayesian
Browse files Browse the repository at this point in the history
neural network models using Bayesian-Torch Quantization framework.

Signed-off-by: Ranganath Krishnan <[email protected]>
  • Loading branch information
ranganathkrishnan committed Jul 27, 2023
1 parent 52cea4f commit c3e9a0f
Show file tree
Hide file tree
Showing 4 changed files with 129 additions and 14 deletions.
10 changes: 9 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,9 @@ Bayesian-Torch is designed to be flexible and enables seamless extension of dete


**Key features:**
* [dnn_to_bnn()](https://github.com/IntelLabs/bayesian-torch/blob/main/bayesian_torch/models/dnn_to_bnn.py#L127): An API to convert deterministic deep neural network (dnn) model of any architecture to Bayesian deep neural network (bnn) model, simplifying the model definition i.e. drop-in replacements of Convolutional, Linear and LSTM layers to corresponding Bayesian layers. This will enable seamless conversion of existing topology of larger models to Bayesian deep neural network models for extending towards uncertainty-aware applications.
* [dnn_to_bnn()](https://github.com/IntelLabs/bayesian-torch/blob/main/bayesian_torch/models/dnn_to_bnn.py#L127): Seamless conversion of model to be Uncertainty-aware. An API to convert deterministic deep neural network (dnn) model of any architecture to Bayesian deep neural network (bnn) model, simplifying the model definition i.e. drop-in replacements of Convolutional, Linear and LSTM layers to corresponding Bayesian layers. This will enable seamless conversion of existing topology of larger models to Bayesian deep neural network models for extending towards uncertainty-aware applications.
* [MOPED](https://github.com/IntelLabs/bayesian-torch/blob/main/bayesian_torch/utils/util.py#L72): Specifying weight priors and variational posteriors in Bayesian neural networks with Empirical Bayes [[Krishnan et al. 2020](https://ojs.aaai.org/index.php/AAAI/article/view/5875)]
* [Quantization](https://github.com/IntelLabs/bayesian-torch/tree/main/bayesian_torch/ao): Post Training Quantization of Bayesian deep neural network models with simple API's [enable_prepare()](https://github.com/IntelLabs/bayesian-torch/blob/main/bayesian_torch/ao/quantization/quantize.py#L134) and [convert()](https://github.com/IntelLabs/bayesian-torch/blob/main/bayesian_torch/ao/quantization/quantize.py#L160)
* [AvUC](https://github.com/IntelLabs/bayesian-torch/blob/main/bayesian_torch/utils/avuc_loss.py): Accuracy versus Uncertainty Calibration loss [[Krishnan and Tickoo 2020](https://proceedings.neurips.cc/paper/2020/file/d3d9446802a44259755d38e6d163e820-Paper.pdf)]

## Installing Bayesian-Torch
Expand Down Expand Up @@ -198,6 +199,13 @@ To evaluate deterministic ResNet on CIFAR10, run this command:
sh scripts/test_deterministic_cifar.sh
```

### Post Training Quantization (PTQ)

To quantize Bayesian ResNet (convert to INT8) and evaluate on CIFAR10, run this command:
```test
sh scripts/quantize_bayesian_cifar.sh
```

## Citing

If you use this code, please cite as:
Expand Down
4 changes: 2 additions & 2 deletions bayesian_torch/ao/quantization/quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,12 +152,12 @@ def prepare(model):
qmodel.load_state_dict(model.state_dict())
qmodel.eval()
enable_prepare(qmodel)
qmodel.qconfig = torch.quantization.get_default_qconfig("fbgemm")
qmodel.qconfig = torch.quantization.get_default_qconfig("onednn")
qmodel = torch.quantization.prepare(qmodel)

return qmodel

def convert(model):
qmodel = torch.quantization.convert(model) # torch layers
bnn_to_qbnn(qmodel) # bayesian layers
return qmodel
return qmodel
100 changes: 98 additions & 2 deletions bayesian_torch/examples/main_bayesian_cifar_dnn2bnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,25 @@
import os
import shutil
import time

import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data.sampler import SubsetRandomSampler
import torchvision.transforms as transforms
import torchvision.datasets as datasets

import bayesian_torch.models.deterministic.resnet as resnet
import numpy as np
from bayesian_torch.models.dnn_to_bnn import dnn_to_bnn, get_kl_loss

from bayesian_torch.ao.quantization.quantize import enable_prepare, convert
from bayesian_torch.models.bnn_to_qbnn import bnn_to_qbnn

model_names = sorted(
name
for name in resnet.__dict__
Expand Down Expand Up @@ -59,6 +63,13 @@
default="./checkpoint/bayesian",
type=str,
)
parser.add_argument(
"--model-checkpoint",
dest="model_checkpoint",
help="Saved checkpoint for evaluating model",
default="",
type=str,
)
parser.add_argument(
"--moped-init-model",
dest="moped_init_model",
Expand Down Expand Up @@ -97,7 +108,7 @@
type=int,
default=10,
)
parser.add_argument("--mode", type=str, required=True, help="train | test")
parser.add_argument("--mode", type=str, required=True, help="train | test | ptq | test_ptq")

parser.add_argument(
"--num_monte_carlo",
Expand Down Expand Up @@ -221,6 +232,25 @@ def main():
pin_memory=True,
)

calib_loader = torch.utils.data.DataLoader(
datasets.CIFAR10(
root="./data",
train=True,
transform=transforms.Compose(
[
transforms.ToTensor(),
normalize,
]
),
download=True,
),
batch_size=args.batch_size,
sampler=SubsetRandomSampler(random.sample(range(1, 50000), 100)),
num_workers=args.workers,
pin_memory=True,
)


if not os.path.exists(args.save_dir):
os.makedirs(args.save_dir)

Expand Down Expand Up @@ -286,6 +316,57 @@ def main():
model.load_state_dict(checkpoint["state_dict"])
evaluate(args, model, val_loader)

elif args.mode == "ptq":
if len(args.model_checkpoint) > 0:
checkpoint_file = args.model_checkpoint
else:
print("please provide valid model-checkpoint")
checkpoint = torch.load(checkpoint_file, map_location=torch.device("cpu"))

'''
state_dict = checkpoint['state_dict']
new_state_dict = OrderedDict()
for k, v in state_dict.items():
name = k[7:] # remove `module.`
new_state_dict[name] = v
print('load checkpoint...')
'''
model.load_state_dict(checkpoint['state_dict'])


# post-training quantization
model_int8 = quantize(model, calib_loader, args)
model_int8.eval()
model_int8.cpu()

for i, (data, target) in enumerate(calib_loader):
data = data.cpu()

with torch.no_grad():
traced_model = torch.jit.trace(model_int8, data)
traced_model = torch.jit.freeze(traced_model)

save_path = os.path.join(
args.save_dir,
'quantized_bayesian_{}_cifar.pth'.format(args.arch))
traced_model.save(save_path)
print('INT8 model checkpoint saved at ', save_path)
print('Evaluating quantized INT8 model....')
evaluate(args, traced_model, val_loader)

elif args.mode =='test_ptq':
print('load model...')
if len(args.model_checkpoint) > 0:
checkpoint_file = args.model_checkpoint
else:
print("please provide valid quantized model checkpoint")
model_int8 = torch.jit.load(checkpoint_file)
model_int8.eval()
model_int8.cpu()
model_int8 = torch.jit.freeze(model_int8)
print('Evaluating the INT8 model....')
evaluate(args, model_int8, val_loader)


def train(args, train_loader, model, criterion, optimizer, epoch, tb_writer=None):
batch_time = AverageMeter()
Expand Down Expand Up @@ -482,6 +563,21 @@ def save_checkpoint(state, is_best, filename="checkpoint.pth.tar"):
"""
torch.save(state, filename)

def quantize(model, calib_loader, args, **kwargs):
model.eval()
model.cpu()
model.qconfig = torch.quantization.get_default_qconfig("onednn")
print('Preparing model for quantization....')
enable_prepare(model)
prepared_model = torch.quantization.prepare(model)
print('Calibrating...')
with torch.no_grad():
for batch_idx, (data, target) in enumerate(calib_loader):
data = data.cpu()
_ = prepared_model(data)
print('Calibration complete....')
quantized_model = convert(prepared_model)
return quantized_model

class AverageMeter(object):
"""Computes and stores the average and current value"""
Expand Down
29 changes: 20 additions & 9 deletions bayesian_torch/models/deterministic/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,19 +43,22 @@ def __init__(self, in_planes, planes, stride=1, option='A'):
padding=1,
bias=False)
self.bn1 = nn.BatchNorm2d(planes)
self.relu1 = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d(planes,
planes,
kernel_size=3,
stride=1,
padding=1,
bias=False)
self.bn2 = nn.BatchNorm2d(planes)

self.skip_add = nn.quantized.FloatFunctional()
self.shortcut = nn.Sequential()
self.relu2 = nn.ReLU(inplace=True)

if stride != 1 or in_planes != planes:
if option == 'A':
self.shortcut = LambdaLayer(lambda x: F.pad(
x[:, :, ::2, ::2],
x[:, :, ::2, ::2].contiguous(),
(0, 0, 0, 0, planes // 4, planes // 4), "constant", 0))
elif option == 'B':
self.shortcut = nn.Sequential(
Expand All @@ -67,12 +70,18 @@ def __init__(self, in_planes, planes, stride=1, option='A'):
nn.BatchNorm2d(self.expansion * planes))

def forward(self, x):
out = F.relu(self.bn1(self.conv1(x)))
out = self.bn2(self.conv2(out))
out += self.shortcut(x)
out = F.relu(out)
identity = self.shortcut(x)
out = self.conv1(x)
out = self.bn1(out)
out = self.relu1(out)

out = self.conv2(out)
out = self.bn2(out)
out = self.skip_add.add(out, identity)
#out += self.shortcut(x)
out = self.relu2(out)
return out


class ResNet(nn.Module):
def __init__(self, block, num_blocks, num_classes=10):
Expand All @@ -86,6 +95,7 @@ def __init__(self, block, num_blocks, num_classes=10):
padding=1,
bias=False)
self.bn1 = nn.BatchNorm2d(16)
self.relu1 = nn.ReLU(inplace=True)
self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1)
self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2)
self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2)
Expand All @@ -103,7 +113,9 @@ def _make_layer(self, block, planes, num_blocks, stride):
return nn.Sequential(*layers)

def forward(self, x):
out = F.relu(self.bn1(self.conv1(x)))
out = self.conv1(x)
out = self.bn1(out)
out = self.relu1(out)
out = self.layer1(out)
out = self.layer2(out)
out = self.layer3(out)
Expand All @@ -112,7 +124,6 @@ def forward(self, x):
out = self.linear(out)
return out


def resnet20():
return ResNet(BasicBlock, [3, 3, 3])

Expand Down

0 comments on commit c3e9a0f

Please sign in to comment.