A library for Bayesian neural network layers and uncertainty estimation in Deep Learning extending the core of PyTorch

junliang-lin/bayesian-torch

 
 


A library for Bayesian neural network layers and uncertainty estimation in Deep Learning



Bayesian-Torch is a library of neural network layers and utilities that extends the core of PyTorch to enable Bayesian inference in deep learning models and to produce principled uncertainty estimates for model predictions.

Overview

Bayesian-Torch is designed to be flexible: a deterministic deep neural network model can be extended to its Bayesian form simply by replacing the deterministic layers with Bayesian layers. This lets users perform stochastic variational inference in deep neural networks.

Bayesian layers:

Key features:

  • dnn_to_bnn(): An API to convert a deterministic deep neural network (dnn) model of any architecture to a Bayesian deep neural network (bnn) model. It simplifies model definition by providing drop-in replacements of Convolutional, Linear and LSTM layers with the corresponding Bayesian layers, which enables seamless conversion of the existing topology of larger models to Bayesian deep neural network models for uncertainty-aware applications.
  • MOPED: Specifying weight priors and variational posteriors in Bayesian neural networks with Empirical Bayes [Krishnan et al. 2020]
  • AvUC: Accuracy versus Uncertainty Calibration loss [Krishnan and Tickoo 2020]

Installing Bayesian-Torch

To install the core library using pip:

pip install bayesian-torch

To install the latest development version from source:

git clone https://github.com/IntelLabs/bayesian-torch
cd bayesian-torch
pip install .

Usage

There are two ways to build Bayesian deep neural networks using Bayesian-Torch:

  1. Convert an existing deterministic deep neural network (dnn) model to a Bayesian deep neural network (bnn) model with the dnn_to_bnn() API
  2. Define your custom model using the Bayesian layers (Reparameterization or Flipout)

(1) For instance, building a Bayesian ResNet18 from the deterministic torchvision ResNet18 model is as simple as:

import torch
import torchvision
from bayesian_torch.models.dnn_to_bnn import dnn_to_bnn, get_kl_loss

const_bnn_prior_parameters = {
        "prior_mu": 0.0,
        "prior_sigma": 1.0,
        "posterior_mu_init": 0.0,
        "posterior_rho_init": -3.0,
        "type": "Reparameterization",  # Flipout or Reparameterization
        "moped_enable": False,  # True to initialize mu/sigma from the pretrained dnn weights
        "moped_delta": 0.5,
}
    
model = torchvision.models.resnet18()
dnn_to_bnn(model, const_bnn_prior_parameters)
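
To make the `posterior_mu_init` and `posterior_rho_init` entries more concrete, here is a minimal plain-Python sketch of how a mean-field Gaussian weight is typically sampled with the reparameterization trick. It assumes the common softplus mapping sigma = log(1 + exp(rho)); the helper names `softplus` and `sample_weight` are illustrative and are not part of the Bayesian-Torch API.

```python
import math
import random


def softplus(rho):
    """Map an unconstrained rho to a positive standard deviation."""
    return math.log1p(math.exp(rho))


def sample_weight(mu, rho, rng):
    """Draw one weight as mu + sigma * eps, with eps ~ N(0, 1)."""
    sigma = softplus(rho)
    eps = rng.gauss(0.0, 1.0)
    return mu + sigma * eps


rng = random.Random(0)
sigma_init = softplus(-3.0)  # initial std implied by posterior_rho_init = -3.0
samples = [sample_weight(0.0, -3.0, rng) for _ in range(5)]
print(f"initial sigma: {sigma_init:.4f}")
print("sampled weights:", [round(w, 4) for w in samples])
```

Under this assumption, `posterior_rho_init = -3.0` corresponds to an initial standard deviation of roughly 0.05, so the sampled weights start close to `posterior_mu_init`.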

To use the MOPED method, i.e., setting the prior and initializing the variational parameters from a pretrained deterministic model (which helps training convergence for larger models):

const_bnn_prior_parameters = {
        "prior_mu": 0.0,
        "prior_sigma": 1.0,
        "posterior_mu_init": 0.0,
        "posterior_rho_init": -3.0,
        "type": "Reparameterization",  # Flipout or Reparameterization
        "moped_enable": True,  # True to initialize mu/sigma from the pretrained dnn weights
        "moped_delta": 0.5,
}
    
model = torchvision.models.resnet18(pretrained=True)
dnn_to_bnn(model, const_bnn_prior_parameters)
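
As a rough illustration of what `moped_delta` controls, the sketch below initializes the variational parameters of a single weight from a pretrained value. It assumes the rule described in the MOPED paper (posterior mean set to the pretrained weight, standard deviation set to delta times its magnitude) together with a softplus parameterization of sigma; `moped_init` is a hypothetical helper written for this example, not a function exported by the library.

```python
import math


def moped_init(pretrained_weight, delta):
    """Return (mu, rho) for one weight, assuming sigma = delta * |w|."""
    mu = pretrained_weight
    sigma = delta * abs(pretrained_weight)
    # Inverse softplus: rho such that log(1 + exp(rho)) == sigma.
    # The small constant keeps log() finite when the weight is exactly zero.
    rho = math.log(math.expm1(sigma) + 1e-20)
    return mu, rho


mu, rho = moped_init(pretrained_weight=0.8, delta=0.5)
sigma_back = math.log1p(math.exp(rho))  # recover sigma from rho
print(f"mu={mu:.2f}, rho={rho:.4f}, sigma={sigma_back:.4f}")
```

With `delta = 0.5`, a pretrained weight of 0.8 gets an initial standard deviation of 0.4, so larger weights start with proportionally wider posteriors.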

Training snippet:

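criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)

optimizer.zero_grad()  # clear gradients left over from the previous step
output = model(x_train)
kl = get_kl_loss(model)  # KL divergence term collected from the Bayesian layers
ce_loss = criterion(output, y_train)
loss = ce_loss + kl / args.batch_size  # scale the KL term by the batch size

loss.backward()
optimizer.step()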
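
The loss above adds a KL divergence term, scaled by the batch size, to the cross-entropy. As a sketch of where such a term comes from, the plain-Python example below evaluates the closed-form KL divergence between a scalar Gaussian posterior and a scalar Gaussian prior. The helper `gaussian_kl` and all numbers are illustrative assumptions; in the library the value is obtained through `get_kl_loss(model)`.

```python
import math


def gaussian_kl(mu_q, sigma_q, mu_p, sigma_p):
    """KL(N(mu_q, sigma_q^2) || N(mu_p, sigma_p^2)) for scalar Gaussians."""
    return (
        math.log(sigma_p / sigma_q)
        + (sigma_q ** 2 + (mu_q - mu_p) ** 2) / (2.0 * sigma_p ** 2)
        - 0.5
    )


# One weight with a narrow posterior against a standard normal prior.
kl_one_weight = gaussian_kl(mu_q=0.1, sigma_q=0.05, mu_p=0.0, sigma_p=1.0)

# Hypothetical numbers for one training step.
ce_loss = 2.30
batch_size = 128
num_weights = 1000  # pretend every weight contributes the same KL
loss = ce_loss + num_weights * kl_one_weight / batch_size
print(f"KL per weight: {kl_one_weight:.4f}, total loss: {loss:.4f}")
```

Because the KL term grows with the number of weights, dividing it by the batch size keeps it from dominating the data-fit term.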

Testing snippet:

model.eval()
with torch.no_grad():
    output_mc = []
    for mc_run in range(args.num_monte_carlo):
        logits = model(x_test)  # each forward pass samples new weights
        probs = torch.nn.functional.softmax(logits, dim=-1)
        output_mc.append(probs)
    output = torch.stack(output_mc)  # shape: (num_monte_carlo, batch, classes)
    pred_mean = output.mean(dim=0)
    y_pred = torch.argmax(pred_mean, dim=-1)
    test_acc = (y_pred.cpu() == y_test.cpu()).float().mean().item()

Uncertainty Quantification:

from bayesian_torch.utils.util import predictive_entropy, mutual_information

predictive_uncertainty = predictive_entropy(output.cpu().numpy())
model_uncertainty = mutual_information(output.cpu().numpy())
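
For intuition about the two quantities above, here is a plain-Python sketch for a single input. It assumes the standard definitions: predictive entropy is the entropy of the Monte Carlo averaged class probabilities, and mutual information is that entropy minus the average entropy of the individual samples. The functions are illustrative re-implementations with hypothetical names, not the library's `predictive_entropy` and `mutual_information`.

```python
import math


def entropy(probs):
    """Shannon entropy (in nats) of one categorical distribution."""
    return -sum(p * math.log(p) for p in probs if p > 0.0)


def mean_distribution(mc_probs):
    """Average class probabilities over Monte Carlo samples."""
    num_samples = len(mc_probs)
    return [sum(col) / num_samples for col in zip(*mc_probs)]


def predictive_entropy_sketch(mc_probs):
    """Entropy of the averaged prediction (total uncertainty)."""
    return entropy(mean_distribution(mc_probs))


def mutual_information_sketch(mc_probs):
    """Total uncertainty minus the average per-sample entropy."""
    expected_entropy = sum(entropy(p) for p in mc_probs) / len(mc_probs)
    return predictive_entropy_sketch(mc_probs) - expected_entropy


# Each inner list is one Monte Carlo sample of class probabilities.
agree = [[0.9, 0.1], [0.9, 0.1], [0.9, 0.1]]
disagree = [[0.9, 0.1], [0.1, 0.9]]
print(mutual_information_sketch(agree))
print(mutual_information_sketch(disagree))
```

When the samples agree, the mutual information is close to zero and any remaining entropy reflects ambiguity in the data; when the samples disagree, the mutual information is positive, which signals uncertainty in the model itself.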

(2) For building custom models, we have provided example model implementations using the Bayesian layers.

Example usage (training and evaluation of models)

We have provided example scripts to train and evaluate the models. The instructions for the CIFAR10 examples are provided below; similar scripts for ImageNet and MNIST are available.

cd bayesian_torch

Training

To train Bayesian ResNet on CIFAR10, run this command:

Mean-field variational inference (Reparameterized Monte Carlo estimator)

sh scripts/train_bayesian_cifar.sh

Mean-field variational inference (Flipout Monte Carlo estimator)

sh scripts/train_bayesian_flipout_cifar.sh

To train deterministic ResNet on CIFAR10, run this command:

Vanilla

sh scripts/train_deterministic_cifar.sh

Evaluation

To evaluate Bayesian ResNet on CIFAR10, run this command:

Mean-field variational inference (Reparameterized Monte Carlo estimator)

sh scripts/test_bayesian_cifar.sh

Mean-field variational inference (Flipout Monte Carlo estimator)

sh scripts/test_bayesian_flipout_cifar.sh

To evaluate deterministic ResNet on CIFAR10, run this command:

Vanilla

sh scripts/test_deterministic_cifar.sh

Citing

If you use this code, please cite as:

@software{krishnan2022bayesiantorch,
  author       = {Ranganath Krishnan and Pi Esposito and Mahesh Subedar},
  title        = {Bayesian-Torch: Bayesian neural network layers for uncertainty estimation},
  month        = jan,
  year         = 2022,
  doi          = {10.5281/zenodo.5908307},
  url          = {https://doi.org/10.5281/zenodo.5908307},
  howpublished = {\url{https://github.com/IntelLabs/bayesian-torch}}
}

Accuracy versus Uncertainty Calibration (AvUC) loss

@inproceedings{NEURIPS2020_d3d94468,
 title = {Improving model calibration with accuracy versus uncertainty optimization},
 author = {Krishnan, Ranganath and Tickoo, Omesh},
 booktitle = {Advances in Neural Information Processing Systems},
 volume = {33},
 pages = {18237--18248},
 year = {2020},
 url = {https://proceedings.neurips.cc/paper/2020/file/d3d9446802a44259755d38e6d163e820-Paper.pdf}
}

MOdel Priors with Empirical Bayes using DNN (MOPED)

@inproceedings{krishnan2020specifying,
  title={Specifying weight priors in bayesian deep neural networks with empirical bayes},
  author={Krishnan, Ranganath and Subedar, Mahesh and Tickoo, Omesh},
  booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},
  volume={34},
  number={04},
  pages={4477--4484},
  year={2020},
  url = {https://ojs.aaai.org/index.php/AAAI/article/view/5875}
}

This library and code are intended for researchers and developers who want to quantify principled uncertainty estimates from deep learning model predictions using stochastic variational inference in Bayesian neural networks. Feedback, issues and contributions are welcome. Email [email protected] with any questions.
