diff --git a/bayesian_torch/__init__.py b/bayesian_torch/__init__.py index e69de29..da64647 100644 --- a/bayesian_torch/__init__.py +++ b/bayesian_torch/__init__.py @@ -0,0 +1 @@ +from bayesian_torch import quantization as quantization \ No newline at end of file diff --git a/bayesian_torch/ao/__init__.py b/bayesian_torch/ao/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/bayesian_torch/ao/nn/__init__.py b/bayesian_torch/ao/nn/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/bayesian_torch/ao/nn/quantized/__init__.py b/bayesian_torch/ao/nn/quantized/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/bayesian_torch/ao/nn/quantized/modules/quantize_conv_variational.py b/bayesian_torch/ao/nn/quantized/modules/quantize_conv_variational.py new file mode 100644 index 0000000..a8b25dc --- /dev/null +++ b/bayesian_torch/ao/nn/quantized/modules/quantize_conv_variational.py @@ -0,0 +1,1428 @@ +# Copyright (C) 2021 Intel Labs +# +# BSD-3-Clause License +# +# Redistribution and use in source and binary forms, with or without modification, +# are permitted provided that the following conditions are met: +# 1. Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# 3. Neither the name of the copyright holder nor the names of its contributors +# may be used to endorse or promote products derived from this software +# without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, +# THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS +# BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, +# OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT +# OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; +# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE +# OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, +# EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
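#
# Quantized (int8) counterparts of the reparameterization convolution layers:
# mu and sigma = log1p(exp(rho)) are stored as qint8 tensors, biases stay in
# fp32, an optional conv+BN fusion is folded in at quantize() time, and the KL
# term is returned as 0 since these layers are intended for inference only.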
+# +# @authors: Jun-Liang Lin +# +# ====================================================================================== + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import Parameter +from ..base_variational_layer import BaseVariationalLayer_ +from .conv_variational import * +import math + +__all__ = [ + 'QuantizedConv1dReparameterization', + 'QuantizedConv2dReparameterization', + 'QuantizedConv3dReparameterization', + 'QuantizedConvTranspose1dReparameterization', + 'QuantizedConvTranspose2dReparameterization', + 'QuantizedConvTranspose3dReparameterization', +] + + +class QuantizedConv1dReparameterization(Conv1dReparameterization): + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=False): + """ + """ + + super(QuantizedConv1dReparameterization, self).__init__( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + groups, + bias=bias + ) + + ## redundant ## + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + self.dilation = dilation + self.groups = groups + ## redundant ## + + # for conv bn fusion + self.bn_weight = None + self.bn_bias = None + self.bn_running_mean = None + self.bn_running_var = None + self.bn_eps = None + + self.is_dequant = False + + def get_scale_and_zero_point(self, x, upper_bound=100, target_range=255): + """ An implementation for symmetric quantization + + Parameters + ---------- + x: tensor + Input tensor. + upper_bound: int, optional + Restrict the maximum value of the original tensor (select 100 empirically). + target_range: int, optional + The range of target data type (255 for int8) + + Returns + ---------- + scale: float + + zero_point: int + + """ + # + scale = torch.zeros(1).to(x.device) # initialize + zero_point = torch.zeros(1).to(x.device) # zero point is zero since we only consider symmetric quantization + xmax = torch.clamp(x.abs().max(), 0, upper_bound) # determine and restrict the maximum value (minimum value should be 0 since the absolute value is always non-negative) + scale = xmax*2/target_range # original range divided by target range + return scale, zero_point + + def get_quantized_tensor(self, x, default_scale=0.1): + """ Quantize tensors + + Parameters + ---------- + x: tensors + Input tensor. + + default_scale: float, optional + Default scale for the case that the computed scale is zero. 
+ + + Returns + ---------- + quantized_x: tensors + + + """ + scale, zero_point = self.get_scale_and_zero_point(x) + if scale == 0: + scale = torch.tensor([default_scale]) # avoid zero scale + quantized_x = torch.quantize_per_tensor(x, scale, zero_point, torch.qint8) + + return quantized_x + + def get_dequantized_tensor(self, x): + dequantized_x = x.dequantize() + + return dequantized_x + + + def quantize(self): + if self.bn_weight is None: # has batchnorm layer, no bn fusion + self.quantized_mu_weight = Parameter(self.get_quantized_tensor(self.mu_kernel), requires_grad=False).cpu() + self.quantized_sigma_weight = Parameter(self.get_quantized_tensor(torch.log1p(torch.exp(self.rho_kernel))), requires_grad=False).cpu() + else: # fuse conv and bn + bn_coef = self.bn_weight/torch.sqrt(self.bn_running_var+self.bn_eps) + self.quantized_mu_weight = Parameter(self.get_quantized_tensor(self.mu_kernel*(bn_coef.view(-1,1,1,1).expand(self.mu_kernel.shape))), requires_grad=False).cpu() + self.quantized_sigma_weight = Parameter(self.get_quantized_tensor(torch.log1p(torch.exp(self.rho_kernel))*(bn_coef.view(-1,1,1,1).expand(self.rho_kernel.shape))), requires_grad=False).cpu() + delattr(self, "mu_kernel") + delattr(self, "rho_kernel") + + + ## DO NOT QUANTIZE BIAS!!!! Bias should be in fp32 format + ## Variable names may be confusing. We don't quantize them. + ## TODO: rename variables + if self.bias: # if has bias + if self.bn_weight is None: # if no bn fusion + self.quantized_mu_bias = Parameter(self.mu_bias, requires_grad=False).cpu() + self.quantized_sigma_bias = Parameter(torch.log1p(torch.exp(self.rho_bias)), requires_grad=False).cpu() + else: # if apply bn fusion + bn_coef = self.bn_weight/torch.sqrt(self.bn_running_var+self.bn_eps) + self.quantized_mu_bias = Parameter((self.mu_bias-self.bn_running_mean)*bn_coef+self.bn_bias, requires_grad=False).cpu() + self.quantized_sigma_bias = Parameter(torch.log1p(torch.exp(self.rho_bias))*bn_coef, requires_grad=False).cpu() + delattr(self, "mu_bias") + delattr(self, "rho_bias") + else: + if self.bn_weight is not None: # if no bias but apply bn fusion + self.bias = True + bn_coef = self.bn_weight/torch.sqrt(self.bn_running_var+self.bn_eps)*(-self.bn_running_mean)+self.bn_bias + self.quantized_mu_bias = Parameter(bn_coef, requires_grad=False).cpu() + self.quantized_sigma_bias = None + + delattr(self, "bn_weight") + delattr(self, "bn_bias") + delattr(self, "bn_running_mean") + delattr(self, "bn_running_var") + delattr(self, "bn_eps") + + def dequantize(self): # Deprecated. Only for forward mode #1. + self.mu_kernel = self.get_dequantized_tensor(self.quantized_mu_weight) + self.sigma_weight = self.get_dequantized_tensor(self.quantized_sigma_weight) + + if self.bias: + self.mu_bias = self.get_dequantized_tensor(self.quantized_mu_bias) + self.sigma_bias = self.get_dequantized_tensor(self.quantized_sigma_bias) + + return + + def forward(self, input, enable_int8_compute=True, normal_scale=6/255, default_scale=0.1, default_zero_point=128, return_kl=True): + """ Forward pass + + Parameters + ---------- + input: tensors + Input tensor. + + enable_int8_compute: bool, optional + Whether to enable int8 computation. + + normal_scale: float, optional + Scale for quantized tensor sampled from normal distribution. + since 99.7% values will lie within 3 standard deviations, the original range is set as 6. + + default_scale: float, optional + Default scale for quantized input tensor and quantized output tensor. + Set to 0.1 by grid search. 
+ + default_zero_point: int, optional + Default zero point for quantized input tensor and quantized output tensor. + Set to 128 for quint8 tensor. + + + + Returns + ---------- + out: tensors + Output tensor. + KL: float + set to 0 since we diable KL divergence computation in quantized layers. + + + """ + + if self.dnn_to_bnn_flag: + return_kl = False + + if not enable_int8_compute: # Deprecated. Use this method for reducing model size only. + if not self.is_dequant: + self.dequantize() + self.is_dequant = True + + weight = self.mu_kernel + (self.sigma_weight * self.eps_kernel.data.normal_()) + bias = None + + if self.bias: + bias = self.mu_bias + (self.sigma_bias * self.eps_bias.data.normal_()) + + out = F.conv1d(input, weight, bias, self.stride, self.padding, + self.dilation, self.groups) + + else: + eps_kernel = torch.quantize_per_tensor(self.eps_kernel.data.normal_(), normal_scale, 0, torch.qint8) # Quantize a tensor from normal distribution. + new_scale = (self.quantized_sigma_weight.q_scale())*(eps_kernel.q_scale()) # Calculate the new scale after multiplying two quantized tensors. + weight = torch.ops.quantized.mul(self.quantized_sigma_weight, eps_kernel, new_scale, 0) + new_scale = max(new_scale, self.quantized_mu_weight.q_scale()) # Calculate the new scale after adding two quantized tensors. + weight = torch.ops.quantized.add(weight, self.quantized_mu_weight, new_scale, 0) + bias = None + + + ## DO NOT QUANTIZE BIAS!!! + if self.bias: + if self.quantized_sigma_bias is None: # the case that bias comes from bn fusion + bias = self.quantized_mu_bias + else: # original case + bias = self.quantized_mu_bias + (self.quantized_sigma_bias * self.eps_bias.data.normal_()) + + if input.dtype!=torch.quint8: # check if input has been quantized + input = torch.quantize_per_tensor(input, default_scale, default_zero_point, torch.quint8) # scale=0.1 by grid search; zero_point=128 for uint8 format + + out = torch.nn.quantized.functional.conv1d(input, weight, bias, self.stride, self.padding, + self.dilation, self.groups, scale=default_scale, zero_point=default_zero_point) # input: quint8, weight: qint8, bias: fp32 + + if return_kl: + return out, 0 # disable kl divergence computing + + return out + + + +class QuantizedConv2dReparameterization(Conv2dReparameterization): + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=False): + """ + + """ + + super(QuantizedConv2dReparameterization, self).__init__( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + groups, + bias=bias + ) + + ## redundant ## + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + self.dilation = dilation + self.groups = groups + ## redundant ## + + # for conv bn fusion + self.bn_weight = None + self.bn_bias = None + self.bn_running_mean = None + self.bn_running_var = None + self.bn_eps = None + + self.is_dequant = False + + def get_scale_and_zero_point(self, x, upper_bound=100, target_range=255): + """ An implementation for symmetric quantization + + Parameters + ---------- + x: tensor + Input tensor. + upper_bound: int, optional + Restrict the maximum value of the original tensor (select 100 empirically). 
+ target_range: int, optional + The range of target data type (255 for int8) + + Returns + ---------- + scale: float + + zero_point: int + + """ + # + scale = torch.zeros(1).to(x.device) # initialize + zero_point = torch.zeros(1).to(x.device) # zero point is zero since we only consider symmetric quantization + xmax = torch.clamp(x.abs().max(), 0, upper_bound) # determine and restrict the maximum value (minimum value should be 0 since the absolute value is always non-negative) + scale = xmax*2/target_range # original range divided by target range + return scale, zero_point + + def get_quantized_tensor(self, x, default_scale=0.1): + """ Quantize tensors + + Parameters + ---------- + x: tensors + Input tensor. + + default_scale: float, optional + Default scale for the case that the computed scale is zero. + + + Returns + ---------- + quantized_x: tensors + + + """ + scale, zero_point = self.get_scale_and_zero_point(x) + if scale == 0: + scale = torch.tensor([default_scale]) # avoid zero scale + quantized_x = torch.quantize_per_tensor(x, scale, zero_point, torch.qint8) + + return quantized_x + + def get_dequantized_tensor(self, x): + dequantized_x = x.dequantize() + + return dequantized_x + + + def quantize(self): + if self.bn_weight is None: # has batchnorm layer, no bn fusion + self.quantized_mu_weight = Parameter(self.get_quantized_tensor(self.mu_kernel), requires_grad=False).cpu() + self.quantized_sigma_weight = Parameter(self.get_quantized_tensor(torch.log1p(torch.exp(self.rho_kernel))), requires_grad=False).cpu() + else: # fuse conv and bn + bn_coef = self.bn_weight/torch.sqrt(self.bn_running_var+self.bn_eps) + self.quantized_mu_weight = Parameter(self.get_quantized_tensor(self.mu_kernel*(bn_coef.view(-1,1,1,1).expand(self.mu_kernel.shape))), requires_grad=False).cpu() + self.quantized_sigma_weight = Parameter(self.get_quantized_tensor(torch.log1p(torch.exp(self.rho_kernel))*(bn_coef.view(-1,1,1,1).expand(self.rho_kernel.shape))), requires_grad=False).cpu() + delattr(self, "mu_kernel") + delattr(self, "rho_kernel") + + + ## DO NOT QUANTIZE BIAS!!!! Bias should be in fp32 format + ## Variable names may be confusing. We don't quantize them. + ## TODO: rename variables + if self.bias: + if self.bn_weight is None: # if no bn fusion + self.quantized_mu_bias = Parameter(self.mu_bias, requires_grad=False).cpu() + self.quantized_sigma_bias = Parameter(torch.log1p(torch.exp(self.rho_bias)), requires_grad=False).cpu() + else: # if apply bn fusion + bn_coef = self.bn_weight/torch.sqrt(self.bn_running_var+self.bn_eps) + self.quantized_mu_bias = Parameter((self.mu_bias-self.bn_running_mean)*bn_coef+self.bn_bias, requires_grad=False).cpu() + self.quantized_sigma_bias = Parameter(torch.log1p(torch.exp(self.rho_bias))*bn_coef, requires_grad=False).cpu() + delattr(self, "mu_bias") + delattr(self, "rho_bias") + else: + if self.bn_weight is not None: # if no bias but apply bn fusion + self.bias = True + bn_coef = self.bn_weight/torch.sqrt(self.bn_running_var+self.bn_eps)*(-self.bn_running_mean)+self.bn_bias + self.quantized_mu_bias = Parameter(bn_coef, requires_grad=False).cpu() + self.quantized_sigma_bias = None + + delattr(self, "bn_weight") + delattr(self, "bn_bias") + delattr(self, "bn_running_mean") + delattr(self, "bn_running_var") + delattr(self, "bn_eps") + + def dequantize(self): # Deprecated. Only for forward mode #1. 
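        # Illustrative note: this deprecated path rebuilds fp32 mu/sigma from the
        # stored qint8 copies so that forward(enable_int8_compute=False) can fall
        # back to plain F.conv2d sampling. With the symmetric scheme above, a
        # tensor round-trips as x ~= q_scale * int_repr (zero_point is 0); for
        # example, xmax = 5.0 gives scale = 2 * 5.0 / 255 ~= 0.039.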
+ self.mu_kernel = self.get_dequantized_tensor(self.quantized_mu_weight) + self.sigma_weight = self.get_dequantized_tensor(self.quantized_sigma_weight) + + if self.bias: + self.mu_bias = self.get_dequantized_tensor(self.quantized_mu_bias) + self.sigma_bias = self.get_dequantized_tensor(self.quantized_sigma_bias) + + return + + def forward(self, input, enable_int8_compute=True, normal_scale=6/255, default_scale=0.1, default_zero_point=128, return_kl=True): + """ Forward pass + + Parameters + ---------- + input: tensors + Input tensor. + + enable_int8_compute: bool, optional + Whether to enable int8 computation. + + normal_scale: float, optional + Scale for quantized tensor sampled from normal distribution. + since 99.7% values will lie within 3 standard deviations, the original range is set as 6. + + default_scale: float, optional + Default scale for quantized input tensor and quantized output tensor. + Set to 0.1 by grid search. + + default_zero_point: int, optional + Default zero point for quantized input tensor and quantized output tensor. + Set to 128 for quint8 tensor. + + + + Returns + ---------- + out: tensors + Output tensor. + KL: float + set to 0 since we diable KL divergence computation in quantized layers. + + + """ + if self.dnn_to_bnn_flag: + return_kl = False + + if not enable_int8_compute: # Deprecated. Use this method for reducing model size only. + if not self.is_dequant: + self.dequantize() + self.is_dequant = True + + weight = self.mu_kernel + (self.sigma_weight * self.eps_kernel.data.normal_()) + bias = None + + if self.bias: + bias = self.mu_bias + (self.sigma_bias * self.eps_bias.data.normal_()) + + out = F.conv2d(input, weight, bias, self.stride, self.padding, + self.dilation, self.groups) + + else: + eps_kernel = torch.quantize_per_tensor(self.eps_kernel.data.normal_(), normal_scale, 0, torch.qint8) # Quantize a tensor from normal distribution. 99.7% values will lie within 3 standard deviations, so the original range is set as 6. + new_scale = (self.quantized_sigma_weight.q_scale())*(eps_kernel.q_scale()) # Calculate the new scale after multiplying two quantized tensors. + weight = torch.ops.quantized.mul(self.quantized_sigma_weight, eps_kernel, new_scale, 0) + new_scale = max(new_scale, self.quantized_mu_weight.q_scale()) # Calculate the new scale after adding two quantized tensors. + weight = torch.ops.quantized.add(weight, self.quantized_mu_weight, new_scale, 0) + bias = None + + + ## DO NOT QUANTIZE BIAS!!! 
+ if self.bias: + if self.quantized_sigma_bias is None: # the case that bias comes from bn fusion + bias = self.quantized_mu_bias + else: # original case + bias = self.quantized_mu_bias + (self.quantized_sigma_bias * self.eps_bias.data.normal_()) + + if input.dtype!=torch.quint8: # check if input has been quantized + input = torch.quantize_per_tensor(input, default_scale, default_zero_point, torch.quint8) # scale=0.1 by grid search; zero_point=128 for uint8 format + + out = torch.nn.quantized.functional.conv2d(input, weight, bias, self.stride, self.padding, + self.dilation, self.groups, scale=default_scale, zero_point=default_zero_point) # input: quint8, weight: qint8, bias: fp32 + + if return_kl: + return out, 0 # disable kl divergence computing + + return out + + +class QuantizedConv3dReparameterization(Conv3dReparameterization): + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=False): + """ + """ + + super(QuantizedConv3dReparameterization, self).__init__( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + groups, + bias=bias + ) + + ## redundant ## + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + self.dilation = dilation + self.groups = groups + ## redundant ## + + # for conv bn fusion + self.bn_weight = None + self.bn_bias = None + self.bn_running_mean = None + self.bn_running_var = None + self.bn_eps = None + + self.is_dequant = False + + def get_scale_and_zero_point(self, x, upper_bound=100, target_range=255): + """ An implementation for symmetric quantization + + Parameters + ---------- + x: tensor + Input tensor. + upper_bound: int, optional + Restrict the maximum value of the original tensor (select 100 empirically). + target_range: int, optional + The range of target data type (255 for int8) + + Returns + ---------- + scale: float + + zero_point: int + + """ + # + scale = torch.zeros(1).to(x.device) # initialize + zero_point = torch.zeros(1).to(x.device) # zero point is zero since we only consider symmetric quantization + xmax = torch.clamp(x.abs().max(), 0, upper_bound) # determine and restrict the maximum value (minimum value should be 0 since the absolute value is always non-negative) + scale = xmax*2/target_range # original range divided by target range + return scale, zero_point + + def get_quantized_tensor(self, x, default_scale=0.1): + """ Quantize tensors + + Parameters + ---------- + x: tensors + Input tensor. + + default_scale: float, optional + Default scale for the case that the computed scale is zero. 
+ + + Returns + ---------- + quantized_x: tensors + + + """ + scale, zero_point = self.get_scale_and_zero_point(x) + if scale == 0: + scale = torch.tensor([default_scale]) # avoid zero scale + quantized_x = torch.quantize_per_tensor(x, scale, zero_point, torch.qint8) + + return quantized_x + + def get_dequantized_tensor(self, x): + dequantized_x = x.dequantize() + + return dequantized_x + + + def quantize(self): + if self.bn_weight is None: # has batchnorm layer, no bn fusion + self.quantized_mu_weight = Parameter(self.get_quantized_tensor(self.mu_kernel), requires_grad=False).cpu() + self.quantized_sigma_weight = Parameter(self.get_quantized_tensor(torch.log1p(torch.exp(self.rho_kernel))), requires_grad=False).cpu() + else: # fuse conv and bn + bn_coef = self.bn_weight/torch.sqrt(self.bn_running_var+self.bn_eps) + self.quantized_mu_weight = Parameter(self.get_quantized_tensor(self.mu_kernel*(bn_coef.view(-1,1,1,1).expand(self.mu_kernel.shape))), requires_grad=False).cpu() + self.quantized_sigma_weight = Parameter(self.get_quantized_tensor(torch.log1p(torch.exp(self.rho_kernel))*(bn_coef.view(-1,1,1,1).expand(self.rho_kernel.shape))), requires_grad=False).cpu() + delattr(self, "mu_kernel") + delattr(self, "rho_kernel") + + + ## DO NOT QUANTIZE BIAS!!!! Bias should be in fp32 format + ## Variable names may be confusing. We don't quantize them. + ## TODO: rename variables + if self.bias: # if has bias + if self.bn_weight is None: # if no bn fusion + self.quantized_mu_bias = Parameter(self.mu_bias, requires_grad=False).cpu() + self.quantized_sigma_bias = Parameter(torch.log1p(torch.exp(self.rho_bias)), requires_grad=False).cpu() + else: # if apply bn fusion + bn_coef = self.bn_weight/torch.sqrt(self.bn_running_var+self.bn_eps) + self.quantized_mu_bias = Parameter((self.mu_bias-self.bn_running_mean)*bn_coef+self.bn_bias, requires_grad=False).cpu() + self.quantized_sigma_bias = Parameter(torch.log1p(torch.exp(self.rho_bias))*bn_coef, requires_grad=False).cpu() + delattr(self, "mu_bias") + delattr(self, "rho_bias") + else: + if self.bn_weight is not None: # if no bias but apply bn fusion + self.bias = True + bn_coef = self.bn_weight/torch.sqrt(self.bn_running_var+self.bn_eps)*(-self.bn_running_mean)+self.bn_bias + self.quantized_mu_bias = Parameter(bn_coef, requires_grad=False).cpu() + self.quantized_sigma_bias = None + + delattr(self, "bn_weight") + delattr(self, "bn_bias") + delattr(self, "bn_running_mean") + delattr(self, "bn_running_var") + delattr(self, "bn_eps") + + def dequantize(self): # Deprecated. Only for forward mode #1. + self.mu_kernel = self.get_dequantized_tensor(self.quantized_mu_weight) + self.sigma_weight = self.get_dequantized_tensor(self.quantized_sigma_weight) + + if self.bias: + self.mu_bias = self.get_dequantized_tensor(self.quantized_mu_bias) + self.sigma_bias = self.get_dequantized_tensor(self.quantized_sigma_bias) + + return + + def forward(self, input, enable_int8_compute=True, normal_scale=6/255, default_scale=0.1, default_zero_point=128, return_kl=True): + """ Forward pass + + Parameters + ---------- + input: tensors + Input tensor. + + enable_int8_compute: bool, optional + Whether to enable int8 computation. + + normal_scale: float, optional + Scale for quantized tensor sampled from normal distribution. + since 99.7% values will lie within 3 standard deviations, the original range is set as 6. + + default_scale: float, optional + Default scale for quantized input tensor and quantized output tensor. + Set to 0.1 by grid search. 
+ + default_zero_point: int, optional + Default zero point for quantized input tensor and quantized output tensor. + Set to 128 for quint8 tensor. + + + + Returns + ---------- + out: tensors + Output tensor. + KL: float + set to 0 since we diable KL divergence computation in quantized layers. + + + """ + if self.dnn_to_bnn_flag: + return_kl = False + + if not enable_int8_compute: # Deprecated. Use this method for reducing model size only. + if not self.is_dequant: + self.dequantize() + self.is_dequant = True + + weight = self.mu_kernel + (self.sigma_weight * self.eps_kernel.data.normal_()) + bias = None + + if self.bias: + bias = self.mu_bias + (self.sigma_bias * self.eps_bias.data.normal_()) + + out = F.conv3d(input, weight, bias, self.stride, self.padding, + self.dilation, self.groups) + + else: + eps_kernel = torch.quantize_per_tensor(self.eps_kernel.data.normal_(), normal_scale, 0, torch.qint8) # Quantize a tensor from normal distribution. 99.7% values will lie within 3 standard deviations, so the original range is set as 6. + new_scale = (self.quantized_sigma_weight.q_scale())*(eps_kernel.q_scale()) # Calculate the new scale after multiplying two quantized tensors. + weight = torch.ops.quantized.mul(self.quantized_sigma_weight, eps_kernel, new_scale, 0) + new_scale = max(new_scale, self.quantized_mu_weight.q_scale()) # Calculate the new scale after adding two quantized tensors. + weight = torch.ops.quantized.add(weight, self.quantized_mu_weight, new_scale, 0) + bias = None + + + ## DO NOT QUANTIZE BIAS!!! + if self.bias: + if self.quantized_sigma_bias is None: # the case that bias comes from bn fusion + bias = self.quantized_mu_bias + else: # original case + bias = self.quantized_mu_bias + (self.quantized_sigma_bias * self.eps_bias.data.normal_()) + + if input.dtype!=torch.quint8: # check if input has been quantized + input = torch.quantize_per_tensor(input, default_scale, default_zero_point, torch.quint8) # scale=0.1 by grid search; zero_point=128 for uint8 format + + out = torch.nn.quantized.functional.conv3d(input, weight, bias, self.stride, self.padding, + self.dilation, self.groups, scale=default_scale, zero_point=default_zero_point) # input: quint8, weight: qint8, bias: fp32 + + if return_kl: + return out, 0 # disable kl divergence computing + + return out + +class QuantizedConvTranspose1dReparameterization(ConvTranspose1dReparameterization): + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=False): + """ + """ + + super(ConvTranspose1dReparameterization, self).__init__( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + groups, + bias=bias + ) + + ## redundant ## + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + self.dilation = dilation + self.groups = groups + ## redundant ## + + # for conv bn fusion + self.bn_weight = None + self.bn_bias = None + self.bn_running_mean = None + self.bn_running_var = None + self.bn_eps = None + + self.is_dequant = False + + def get_scale_and_zero_point(self, x, upper_bound=100, target_range=255): + """ An implementation for symmetric quantization + + Parameters + ---------- + x: tensor + Input tensor. + upper_bound: int, optional + Restrict the maximum value of the original tensor (select 100 empirically). 
+ target_range: int, optional + The range of target data type (255 for int8) + + Returns + ---------- + scale: float + + zero_point: int + + """ + # + scale = torch.zeros(1).to(x.device) # initialize + zero_point = torch.zeros(1).to(x.device) # zero point is zero since we only consider symmetric quantization + xmax = torch.clamp(x.abs().max(), 0, upper_bound) # determine and restrict the maximum value (minimum value should be 0 since the absolute value is always non-negative) + scale = xmax*2/target_range # original range divided by target range + return scale, zero_point + + def get_quantized_tensor(self, x, default_scale=0.1): + """ Quantize tensors + + Parameters + ---------- + x: tensors + Input tensor. + + default_scale: float, optional + Default scale for the case that the computed scale is zero. + + + Returns + ---------- + quantized_x: tensors + + + """ + scale, zero_point = self.get_scale_and_zero_point(x) + if scale == 0: + scale = torch.tensor([default_scale]) # avoid zero scale + quantized_x = torch.quantize_per_tensor(x, scale, zero_point, torch.qint8) + + return quantized_x + + def get_dequantized_tensor(self, x): + dequantized_x = x.dequantize() + + return dequantized_x + + + def quantize(self): + if self.bn_weight is None: # has batchnorm layer, no bn fusion + self.quantized_mu_weight = Parameter(self.get_quantized_tensor(self.mu_kernel), requires_grad=False).cpu() + self.quantized_sigma_weight = Parameter(self.get_quantized_tensor(torch.log1p(torch.exp(self.rho_kernel))), requires_grad=False).cpu() + else: # fuse conv and bn + bn_coef = self.bn_weight/torch.sqrt(self.bn_running_var+self.bn_eps) + self.quantized_mu_weight = Parameter(self.get_quantized_tensor(self.mu_kernel*(bn_coef.view(-1,1,1,1).expand(self.mu_kernel.shape))), requires_grad=False).cpu() + self.quantized_sigma_weight = Parameter(self.get_quantized_tensor(torch.log1p(torch.exp(self.rho_kernel))*(bn_coef.view(-1,1,1,1).expand(self.rho_kernel.shape))), requires_grad=False).cpu() + delattr(self, "mu_kernel") + delattr(self, "rho_kernel") + + + ## DO NOT QUANTIZE BIAS!!!! Bias should be in fp32 format + ## Variable names may be confusing. We don't quantize them. + ## TODO: rename variables + if self.bias: # if has bias + if self.bn_weight is None: # if no bn fusion + self.quantized_mu_bias = Parameter(self.mu_bias, requires_grad=False).cpu() + self.quantized_sigma_bias = Parameter(torch.log1p(torch.exp(self.rho_bias)), requires_grad=False).cpu() + else: # if apply bn fusion + bn_coef = self.bn_weight/torch.sqrt(self.bn_running_var+self.bn_eps) + self.quantized_mu_bias = Parameter((self.mu_bias-self.bn_running_mean)*bn_coef+self.bn_bias, requires_grad=False).cpu() + self.quantized_sigma_bias = Parameter(torch.log1p(torch.exp(self.rho_bias))*bn_coef, requires_grad=False).cpu() + delattr(self, "mu_bias") + delattr(self, "rho_bias") + else: + if self.bn_weight is not None: # if no bias but apply bn fusion + self.bias = True + bn_coef = self.bn_weight/torch.sqrt(self.bn_running_var+self.bn_eps)*(-self.bn_running_mean)+self.bn_bias + self.quantized_mu_bias = Parameter(bn_coef, requires_grad=False).cpu() + self.quantized_sigma_bias = None + + delattr(self, "bn_weight") + delattr(self, "bn_bias") + delattr(self, "bn_running_mean") + delattr(self, "bn_running_var") + delattr(self, "bn_eps") + + def dequantize(self): # Deprecated. Only for forward mode #1. 
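        # Note: what quantize() stored is sigma = log1p(exp(rho)) (the softplus
        # of rho), not rho itself, so this restores sigma_weight / sigma_bias
        # directly and the fp32 fallback in forward() does not need to recompute
        # the softplus.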
+ self.mu_kernel = self.get_dequantized_tensor(self.quantized_mu_weight) + self.sigma_weight = self.get_dequantized_tensor(self.quantized_sigma_weight) + + if self.bias: + self.mu_bias = self.get_dequantized_tensor(self.quantized_mu_bias) + self.sigma_bias = self.get_dequantized_tensor(self.quantized_sigma_bias) + + return + + def forward(self, input, enable_int8_compute=True, normal_scale=6/255, default_scale=0.1, default_zero_point=128, return_kl=True): + """ Forward pass + + Parameters + ---------- + input: tensors + Input tensor. + + enable_int8_compute: bool, optional + Whether to enable int8 computation. + + normal_scale: float, optional + Scale for quantized tensor sampled from normal distribution. + since 99.7% values will lie within 3 standard deviations, the original range is set as 6. + + default_scale: float, optional + Default scale for quantized input tensor and quantized output tensor. + Set to 0.1 by grid search. + + default_zero_point: int, optional + Default zero point for quantized input tensor and quantized output tensor. + Set to 128 for quint8 tensor. + + + + Returns + ---------- + out: tensors + Output tensor. + KL: float + set to 0 since we diable KL divergence computation in quantized layers. + + + """ + if self.dnn_to_bnn_flag: + return_kl = False + + if not enable_int8_compute: # Deprecated. Use this method for reducing model size only. + if not self.is_dequant: + self.dequantize() + self.is_dequant = True + + weight = self.mu_kernel + (self.sigma_weight * self.eps_kernel.data.normal_()) + bias = None + + if self.bias: + bias = self.mu_bias + (self.sigma_bias * self.eps_bias.data.normal_()) + + out = F.conv_transpose1d(input, weight, bias, self.stride, + self.padding, self.output_padding, + self.dilation, self.groups) + + else: + eps_kernel = torch.quantize_per_tensor(self.eps_kernel.data.normal_(), normal_scale, 0, torch.qint8) # Quantize a tensor from normal distribution. 99.7% values will lie within 3 standard deviations, so the original range is set as 6. + new_scale = (self.quantized_sigma_weight.q_scale())*(eps_kernel.q_scale()) # Calculate the new scale after multiplying two quantized tensors. + weight = torch.ops.quantized.mul(self.quantized_sigma_weight, eps_kernel, new_scale, 0) + new_scale = max(new_scale, self.quantized_mu_weight.q_scale()) # Calculate the new scale after adding two quantized tensors. + weight = torch.ops.quantized.add(weight, self.quantized_mu_weight, new_scale, 0) + bias = None + + + ## DO NOT QUANTIZE BIAS!!! 
+ if self.bias: + if self.quantized_sigma_bias is None: # the case that bias comes from bn fusion + bias = self.quantized_mu_bias + else: # original case + bias = self.quantized_mu_bias + (self.quantized_sigma_bias * self.eps_bias.data.normal_()) + + if input.dtype!=torch.quint8: # check if input has been quantized + input = torch.quantize_per_tensor(input, default_scale, default_zero_point, torch.quint8) # scale=0.1 by grid search; zero_point=128 for uint8 format + + self._packed_params = torch.ops.quantized.conv_transpose1d_prepack(weight, bias, self.stride, + self.padding, self.output_padding, + self.dilation, self.groups) + + out = torch.ops.quantized.conv_transpose1d(input, self._packed_params, scale=default_scale, zero_point=default_zero_point) + + + if return_kl: + return out, 0 # disable kl divergence computing + + return out + +class QuantizedConvTranspose2dReparameterization(ConvTranspose2dReparameterization): + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=False): + """ + """ + + super(ConvTranspose2dReparameterization, self).__init__( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + groups, + bias=bias + ) + + ## redundant ## + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + self.dilation = dilation + self.groups = groups + ## redundant ## + + # for conv bn fusion + self.bn_weight = None + self.bn_bias = None + self.bn_running_mean = None + self.bn_running_var = None + self.bn_eps = None + + self.is_dequant = False + + def get_scale_and_zero_point(self, x, upper_bound=100, target_range=255): + """ An implementation for symmetric quantization + + Parameters + ---------- + x: tensor + Input tensor. + upper_bound: int, optional + Restrict the maximum value of the original tensor (select 100 empirically). + target_range: int, optional + The range of target data type (255 for int8) + + Returns + ---------- + scale: float + + zero_point: int + + """ + # + scale = torch.zeros(1).to(x.device) # initialize + zero_point = torch.zeros(1).to(x.device) # zero point is zero since we only consider symmetric quantization + xmax = torch.clamp(x.abs().max(), 0, upper_bound) # determine and restrict the maximum value (minimum value should be 0 since the absolute value is always non-negative) + scale = xmax*2/target_range # original range divided by target range + return scale, zero_point + + def get_quantized_tensor(self, x, default_scale=0.1): + """ Quantize tensors + + Parameters + ---------- + x: tensors + Input tensor. + + default_scale: float, optional + Default scale for the case that the computed scale is zero. 
+ + + Returns + ---------- + quantized_x: tensors + + + """ + scale, zero_point = self.get_scale_and_zero_point(x) + if scale == 0: + scale = torch.tensor([default_scale]) # avoid zero scale + quantized_x = torch.quantize_per_tensor(x, scale, zero_point, torch.qint8) + + return quantized_x + + def get_dequantized_tensor(self, x): + dequantized_x = x.dequantize() + + return dequantized_x + + + def quantize(self): + if self.bn_weight is None: # has batchnorm layer, no bn fusion + self.quantized_mu_weight = Parameter(self.get_quantized_tensor(self.mu_kernel), requires_grad=False).cpu() + self.quantized_sigma_weight = Parameter(self.get_quantized_tensor(torch.log1p(torch.exp(self.rho_kernel))), requires_grad=False).cpu() + else: # fuse conv and bn + bn_coef = self.bn_weight/torch.sqrt(self.bn_running_var+self.bn_eps) + self.quantized_mu_weight = Parameter(self.get_quantized_tensor(self.mu_kernel*(bn_coef.view(-1,1,1,1).expand(self.mu_kernel.shape))), requires_grad=False).cpu() + self.quantized_sigma_weight = Parameter(self.get_quantized_tensor(torch.log1p(torch.exp(self.rho_kernel))*(bn_coef.view(-1,1,1,1).expand(self.rho_kernel.shape))), requires_grad=False).cpu() + delattr(self, "mu_kernel") + delattr(self, "rho_kernel") + + + ## DO NOT QUANTIZE BIAS!!!! Bias should be in fp32 format + ## Variable names may be confusing. We don't quantize them. + ## TODO: rename variables + if self.bias: # if has bias + if self.bn_weight is None: # if no bn fusion + self.quantized_mu_bias = Parameter(self.mu_bias, requires_grad=False).cpu() + self.quantized_sigma_bias = Parameter(torch.log1p(torch.exp(self.rho_bias)), requires_grad=False).cpu() + else: # if apply bn fusion + bn_coef = self.bn_weight/torch.sqrt(self.bn_running_var+self.bn_eps) + self.quantized_mu_bias = Parameter((self.mu_bias-self.bn_running_mean)*bn_coef+self.bn_bias, requires_grad=False).cpu() + self.quantized_sigma_bias = Parameter(torch.log1p(torch.exp(self.rho_bias))*bn_coef, requires_grad=False).cpu() + delattr(self, "mu_bias") + delattr(self, "rho_bias") + else: + if self.bn_weight is not None: # if no bias but apply bn fusion + self.bias = True + bn_coef = self.bn_weight/torch.sqrt(self.bn_running_var+self.bn_eps)*(-self.bn_running_mean)+self.bn_bias + self.quantized_mu_bias = Parameter(bn_coef, requires_grad=False).cpu() + self.quantized_sigma_bias = None + + delattr(self, "bn_weight") + delattr(self, "bn_bias") + delattr(self, "bn_running_mean") + delattr(self, "bn_running_var") + delattr(self, "bn_eps") + + def dequantize(self): # Deprecated. Only for forward mode #1. + self.mu_kernel = self.get_dequantized_tensor(self.quantized_mu_weight) + self.sigma_weight = self.get_dequantized_tensor(self.quantized_sigma_weight) + + if self.bias: + self.mu_bias = self.get_dequantized_tensor(self.quantized_mu_bias) + self.sigma_bias = self.get_dequantized_tensor(self.quantized_sigma_bias) + + return + + def forward(self, input, enable_int8_compute=True, normal_scale=6/255, default_scale=0.1, default_zero_point=128, return_kl=True): + """ Forward pass + + Parameters + ---------- + input: tensors + Input tensor. + + enable_int8_compute: bool, optional + Whether to enable int8 computation. + + normal_scale: float, optional + Scale for quantized tensor sampled from normal distribution. + since 99.7% values will lie within 3 standard deviations, the original range is set as 6. + + default_scale: float, optional + Default scale for quantized input tensor and quantized output tensor. + Set to 0.1 by grid search. 
+ + default_zero_point: int, optional + Default zero point for quantized input tensor and quantized output tensor. + Set to 128 for quint8 tensor. + + + + Returns + ---------- + out: tensors + Output tensor. + KL: float + set to 0 since we diable KL divergence computation in quantized layers. + + + """ + if self.dnn_to_bnn_flag: + return_kl = False + + if not enable_int8_compute: # Deprecated. Use this method for reducing model size only. + if not self.is_dequant: + self.dequantize() + self.is_dequant = True + + weight = self.mu_kernel + (self.sigma_weight * self.eps_kernel.data.normal_()) + bias = None + + if self.bias: + bias = self.mu_bias + (self.sigma_bias * self.eps_bias.data.normal_()) + + out = F.conv_transpose2d(input, weight, bias, self.stride, + self.padding, self.output_padding, + self.dilation, self.groups) + + else: + eps_kernel = torch.quantize_per_tensor(self.eps_kernel.data.normal_(), normal_scale, 0, torch.qint8) # Quantize a tensor from normal distribution. 99.7% values will lie within 3 standard deviations, so the original range is set as 6. + new_scale = (self.quantized_sigma_weight.q_scale())*(eps_kernel.q_scale()) # Calculate the new scale after multiplying two quantized tensors. + weight = torch.ops.quantized.mul(self.quantized_sigma_weight, eps_kernel, new_scale, 0) + new_scale = max(new_scale, self.quantized_mu_weight.q_scale()) # Calculate the new scale after adding two quantized tensors. + weight = torch.ops.quantized.add(weight, self.quantized_mu_weight, new_scale, 0) + bias = None + + + ## DO NOT QUANTIZE BIAS!!! + if self.bias: + if self.quantized_sigma_bias is None: # the case that bias comes from bn fusion + bias = self.quantized_mu_bias + else: # original case + bias = self.quantized_mu_bias + (self.quantized_sigma_bias * self.eps_bias.data.normal_()) + + if input.dtype!=torch.quint8: # check if input has been quantized + input = torch.quantize_per_tensor(input, default_scale, default_zero_point, torch.quint8) # scale=0.1 by grid search; zero_point=128 for uint8 format + + self._packed_params = torch.ops.quantized.conv_transpose2d_prepack(weight, bias, self.stride, + self.padding, self.output_padding, + self.dilation, self.groups) + + out = torch.ops.quantized.conv_transpose2d(input, self._packed_params, scale=default_scale, zero_point=default_zero_point) + + + if return_kl: + return out, 0 # disable kl divergence computing + + return out + +class QuantizedConvTranspose3dReparameterization(ConvTranspose3dReparameterization): + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=False): + """ + """ + + super(ConvTranspose3dReparameterization, self).__init__( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + groups, + bias=bias + ) + + ## redundant ## + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + self.dilation = dilation + self.groups = groups + ## redundant ## + + # for conv bn fusion + self.bn_weight = None + self.bn_bias = None + self.bn_running_mean = None + self.bn_running_var = None + self.bn_eps = None + + self.is_dequant = False + + def get_scale_and_zero_point(self, x, upper_bound=100, target_range=255): + """ An implementation for symmetric quantization + + Parameters + ---------- + x: tensor + Input tensor. + upper_bound: int, optional + Restrict the maximum value of the original tensor (select 100 empirically). 
+ target_range: int, optional + The range of target data type (255 for int8) + + Returns + ---------- + scale: float + + zero_point: int + + """ + # + scale = torch.zeros(1).to(x.device) # initialize + zero_point = torch.zeros(1).to(x.device) # zero point is zero since we only consider symmetric quantization + xmax = torch.clamp(x.abs().max(), 0, upper_bound) # determine and restrict the maximum value (minimum value should be 0 since the absolute value is always non-negative) + scale = xmax*2/target_range # original range divided by target range + return scale, zero_point + + def get_quantized_tensor(self, x, default_scale=0.1): + """ Quantize tensors + + Parameters + ---------- + x: tensors + Input tensor. + + default_scale: float, optional + Default scale for the case that the computed scale is zero. + + + Returns + ---------- + quantized_x: tensors + + + """ + scale, zero_point = self.get_scale_and_zero_point(x) + if scale == 0: + scale = torch.tensor([default_scale]) # avoid zero scale + quantized_x = torch.quantize_per_tensor(x, scale, zero_point, torch.qint8) + + return quantized_x + + def get_dequantized_tensor(self, x): + dequantized_x = x.dequantize() + + return dequantized_x + + + def quantize(self): + if self.bn_weight is None: # has batchnorm layer, no bn fusion + self.quantized_mu_weight = Parameter(self.get_quantized_tensor(self.mu_kernel), requires_grad=False).cpu() + self.quantized_sigma_weight = Parameter(self.get_quantized_tensor(torch.log1p(torch.exp(self.rho_kernel))), requires_grad=False).cpu() + else: # fuse conv and bn + bn_coef = self.bn_weight/torch.sqrt(self.bn_running_var+self.bn_eps) + self.quantized_mu_weight = Parameter(self.get_quantized_tensor(self.mu_kernel*(bn_coef.view(-1,1,1,1).expand(self.mu_kernel.shape))), requires_grad=False).cpu() + self.quantized_sigma_weight = Parameter(self.get_quantized_tensor(torch.log1p(torch.exp(self.rho_kernel))*(bn_coef.view(-1,1,1,1).expand(self.rho_kernel.shape))), requires_grad=False).cpu() + delattr(self, "mu_kernel") + delattr(self, "rho_kernel") + + + ## DO NOT QUANTIZE BIAS!!!! Bias should be in fp32 format + ## Variable names may be confusing. We don't quantize them. + ## TODO: rename variables + if self.bias: # if has bias + if self.bn_weight is None: # if no bn fusion + self.quantized_mu_bias = Parameter(self.mu_bias, requires_grad=False).cpu() + self.quantized_sigma_bias = Parameter(torch.log1p(torch.exp(self.rho_bias)), requires_grad=False).cpu() + else: # if apply bn fusion + bn_coef = self.bn_weight/torch.sqrt(self.bn_running_var+self.bn_eps) + self.quantized_mu_bias = Parameter((self.mu_bias-self.bn_running_mean)*bn_coef+self.bn_bias, requires_grad=False).cpu() + self.quantized_sigma_bias = Parameter(torch.log1p(torch.exp(self.rho_bias))*bn_coef, requires_grad=False).cpu() + delattr(self, "mu_bias") + delattr(self, "rho_bias") + else: + if self.bn_weight is not None: # if no bias but apply bn fusion + self.bias = True + bn_coef = self.bn_weight/torch.sqrt(self.bn_running_var+self.bn_eps)*(-self.bn_running_mean)+self.bn_bias + self.quantized_mu_bias = Parameter(bn_coef, requires_grad=False).cpu() + self.quantized_sigma_bias = None + + delattr(self, "bn_weight") + delattr(self, "bn_bias") + delattr(self, "bn_running_mean") + delattr(self, "bn_running_var") + delattr(self, "bn_eps") + + def dequantize(self): # Deprecated. Only for forward mode #1. 
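        # Note on the BN fusion in quantize() above: with
        # coef = bn_weight / sqrt(bn_running_var + bn_eps), the kernel mean and
        # sigma are scaled by coef and the bias becomes
        # (mu_bias - bn_running_mean) * coef + bn_bias, i.e. the usual folding
        # of a BatchNorm layer into the preceding convolution.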
+ self.mu_kernel = self.get_dequantized_tensor(self.quantized_mu_weight) + self.sigma_weight = self.get_dequantized_tensor(self.quantized_sigma_weight) + + if self.bias: + self.mu_bias = self.get_dequantized_tensor(self.quantized_mu_bias) + self.sigma_bias = self.get_dequantized_tensor(self.quantized_sigma_bias) + + return + + def forward(self, input, enable_int8_compute=True, normal_scale=6/255, default_scale=0.1, default_zero_point=128, return_kl=True): + """ Forward pass + + Parameters + ---------- + input: tensors + Input tensor. + + enable_int8_compute: bool, optional + Whether to enable int8 computation. + + normal_scale: float, optional + Scale for quantized tensor sampled from normal distribution. + since 99.7% values will lie within 3 standard deviations, the original range is set as 6. + + default_scale: float, optional + Default scale for quantized input tensor and quantized output tensor. + Set to 0.1 by grid search. + + default_zero_point: int, optional + Default zero point for quantized input tensor and quantized output tensor. + Set to 128 for quint8 tensor. + + + + Returns + ---------- + out: tensors + Output tensor. + KL: float + set to 0 since we diable KL divergence computation in quantized layers. + + + """ + if self.dnn_to_bnn_flag: + return_kl = False + + if not enable_int8_compute: # Deprecated. Use this method for reducing model size only. + if not self.is_dequant: + self.dequantize() + self.is_dequant = True + + weight = self.mu_kernel + (self.sigma_weight * self.eps_kernel.data.normal_()) + bias = None + + if self.bias: + bias = self.mu_bias + (self.sigma_bias * self.eps_bias.data.normal_()) + + out = F.conv_transpose3d(input, weight, bias, self.stride, + self.padding, self.output_padding, + self.dilation, self.groups) + + else: + eps_kernel = torch.quantize_per_tensor(self.eps_kernel.data.normal_(), normal_scale, 0, torch.qint8) # Quantize a tensor from normal distribution. 99.7% values will lie within 3 standard deviations, so the original range is set as 6. + new_scale = (self.quantized_sigma_weight.q_scale())*(eps_kernel.q_scale()) # Calculate the new scale after multiplying two quantized tensors. + weight = torch.ops.quantized.mul(self.quantized_sigma_weight, eps_kernel, new_scale, 0) + new_scale = max(new_scale, self.quantized_mu_weight.q_scale()) # Calculate the new scale after adding two quantized tensors. + weight = torch.ops.quantized.add(weight, self.quantized_mu_weight, new_scale, 0) + bias = None + + + ## DO NOT QUANTIZE BIAS!!! 
+ if self.bias: + if self.quantized_sigma_bias is None: # the case that bias comes from bn fusion + bias = self.quantized_mu_bias + else: # original case + bias = self.quantized_mu_bias + (self.quantized_sigma_bias * self.eps_bias.data.normal_()) + + if input.dtype!=torch.quint8: # check if input has been quantized + input = torch.quantize_per_tensor(input, default_scale, default_zero_point, torch.quint8) # scale=0.1 by grid search; zero_point=128 for uint8 format + + self._packed_params = torch.ops.quantized.conv_transpose3d_prepack(weight, bias, self.stride, + self.padding, self.output_padding, + self.dilation, self.groups) + + out = torch.ops.quantized.conv_transpose3d(input, self._packed_params, scale=default_scale, zero_point=default_zero_point) + + + if return_kl: + return out, 0 # disable kl divergence computing + + return out \ No newline at end of file diff --git a/bayesian_torch/ao/nn/quantized/modules/quantize_linear_variational.py b/bayesian_torch/ao/nn/quantized/modules/quantize_linear_variational.py new file mode 100644 index 0000000..e666f9b --- /dev/null +++ b/bayesian_torch/ao/nn/quantized/modules/quantize_linear_variational.py @@ -0,0 +1,204 @@ +# Copyright (C) 2021 Intel Labs +# +# BSD-3-Clause License +# +# Redistribution and use in source and binary forms, with or without modification, +# are permitted provided that the following conditions are met: +# 1. Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# 3. Neither the name of the copyright holder nor the names of its contributors +# may be used to endorse or promote products derived from this software +# without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, +# THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS +# BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, +# OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT +# OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; +# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE +# OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, +# EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
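#
# Rough usage sketch (illustrative only; assumes a trained LinearReparameterization
# `trained` with matching shapes and an fp32 input `x`):
#
#     qlayer = QuantizedLinearReparameterization(in_features, out_features)
#     qlayer.load_state_dict(trained.state_dict(), strict=False)
#     qlayer.quantize()   # store qint8 mu/sigma, drop the fp32 parameters
#     out, _ = qlayer(x)  # int8 compute; the KL term is returned as 0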
+# +# ====================================================================================== + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import Module, Parameter +from ..base_variational_layer import BaseVariationalLayer_ +import math +from .linear_variational import LinearReparameterization + + + +class QuantizedLinearReparameterization(LinearReparameterization): + def __init__(self, + in_features, + out_features): + """ + + """ + super(QuantizedLinearReparameterization, self).__init__( + in_features, + out_features) + + self.is_dequant = False + + def get_scale_and_zero_point(self, x, upper_bound=100, target_range=255): + """ An implementation for symmetric quantization + + Parameters + ---------- + x: tensor + Input tensor. + upper_bound: int, optional + Restrict the maximum value of the original tensor (select 100 empirically). + target_range: int, optional + The range of target data type (255 for int8) + + Returns + ---------- + scale: float + + zero_point: int + + """ + # + scale = torch.zeros(1).to(x.device) # initialize + zero_point = torch.zeros(1).to(x.device) # zero point is zero since we only consider symmetric quantization + xmax = torch.clamp(x.abs().max(), 0, upper_bound) # determine and restrict the maximum value (minimum value should be 0 since the absolute value is always non-negative) + scale = xmax*2/target_range # original range divided by target range + return scale, zero_point + + def get_quantized_tensor(self, x, default_scale=0.1): + """ Quantize tensors + + Parameters + ---------- + x: tensors + Input tensor. + + default_scale: float, optional + Default scale for the case that the computed scale is zero. + + + Returns + ---------- + quantized_x: tensors + + + """ + scale, zero_point = self.get_scale_and_zero_point(x) + if scale == 0: + scale = torch.tensor([default_scale]) # avoid zero scale + quantized_x = torch.quantize_per_tensor(x, scale, zero_point, torch.qint8) + + return quantized_x + + def get_dequantized_tensor(self, x): + dequantized_x = x.dequantize() + + return dequantized_x + + + def quantize(self): + self.quantized_mu_weight = Parameter(self.get_quantized_tensor(self.mu_weight), requires_grad=False) + self.quantized_sigma_weight = Parameter(self.get_quantized_tensor(torch.log1p(torch.exp(self.rho_weight))), requires_grad=False) + delattr(self, "mu_weight") + delattr(self, "rho_weight") + + self.quantized_mu_bias = Parameter(self.get_quantized_tensor(self.mu_bias), requires_grad=False) + self.quantized_sigma_bias = Parameter(self.get_quantized_tensor(torch.log1p(torch.exp(self.rho_bias))), requires_grad=False) + delattr(self, "mu_bias") + delattr(self, "rho_bias") + + def dequantize(self): # Deprecated + self.mu_weight = self.get_dequantized_tensor(self.quantized_mu_weight) + self.sigma_weight = self.get_dequantized_tensor(self.quantized_sigma_weight) + + self.mu_bias = self.get_dequantized_tensor(self.quantized_mu_bias) + self.sigma_bias = self.get_dequantized_tensor(self.quantized_sigma_bias) + return + + def forward(self, input, enable_int8_compute=True, normal_scale=6/255, default_scale=0.1, default_zero_point=128, return_kl=True): + """ Forward pass + + Parameters + ---------- + input: tensors + Input tensor. + + enable_int8_compute: bool, optional + Whether to enable int8 computation. + + normal_scale: float, optional + Scale for quantized tensor sampled from normal distribution. 
+ since 99.7% values will lie within 3 standard deviations, the original range is set as 6. + + default_scale: float, optional + Default scale for quantized input tensor and quantized output tensor. + Set to 0.1 by grid search. + + default_zero_point: int, optional + Default zero point for quantized input tensor and quantized output tensor. + Set to 128 for quint8 tensor. + + + + Returns + ---------- + out: tensors + Output tensor. + + KL: float + set to 0 since we diable KL divergence computation in quantized layers. + + + """ + if self.dnn_to_bnn_flag: + return_kl = False + + if not enable_int8_compute: # Deprecated. Use this method for reducing model size only. + if not self.is_dequant: + self.dequantize() + self.is_dequant = True + weight = self.mu_weight + (self.sigma_weight * self.eps_weight.data.normal_()) + bias = None + if self.sigma_bias is not None: + bias = self.mu_bias + (self.sigma_bias * self.eps_bias.data.normal_()) + + out = F.linear(input, weight, bias) + + else: + eps_weight = torch.quantize_per_tensor(self.eps_weight.data.normal_(), normal_scale, 0, torch.qint8) + new_scale = (self.quantized_sigma_weight.q_scale())*(eps_weight.q_scale()) + weight = torch.ops.quantized.mul(self.quantized_sigma_weight, eps_weight, new_scale, 0) + new_scale = max(new_scale, self.quantized_mu_weight.q_scale()) + weight = torch.ops.quantized.add(weight, self.quantized_mu_weight, new_scale, 0) + bias = None + + if self.quantized_sigma_bias is not None: + if not self.is_dequant: + self.dequantize() + self.is_dequant = True + bias = self.mu_bias + (self.sigma_bias * self.eps_bias.data.normal_()) + if input.dtype!=torch.quint8: + input = torch.quantize_per_tensor(input, default_scale, default_zero_point, torch.quint8) + + out = torch.nn.quantized.functional.linear(input, weight, bias, scale=default_scale, zero_point=default_zero_point) # input: quint8, weight: qint8, bias: fp32 + out = out.dequantize() + + if return_kl: + return out, 0 # disable kl divergence computing + + return out diff --git a/bayesian_torch/ao/nn/quantized/modules/quantized_conv_flipout.py b/bayesian_torch/ao/nn/quantized/modules/quantized_conv_flipout.py new file mode 100644 index 0000000..cf771c7 --- /dev/null +++ b/bayesian_torch/ao/nn/quantized/modules/quantized_conv_flipout.py @@ -0,0 +1,1303 @@ +# Copyright (C) 2021 Intel Labs +# +# BSD-3-Clause License +# +# Redistribution and use in source and binary forms, with or without modification, +# are permitted provided that the following conditions are met: +# 1. Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# 3. Neither the name of the copyright holder nor the names of its contributors +# may be used to endorse or promote products derived from this software +# without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, +# THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS +# BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, +# OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT +# OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; +# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE +# OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, +# EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +# +# Convolutional layers with flipout Monte Carlo weight estimator to perform +# variational inference in Bayesian neural networks. Variational layers +# enables Monte Carlo approximation of the distribution over the kernel +# +# +# ====================================================================================== +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import Parameter +from ..base_variational_layer import BaseVariationalLayer_ +from .conv_flipout import * + +from torch.distributions.normal import Normal +from torch.distributions.uniform import Uniform + +__all__ = [ + 'QuantizedConv1dFlipout', + 'QuantizedConv2dFlipout', + 'QuantizedConv3dFlipout', + 'QuantizedConvTranspose1dFlipout', + 'QuantizedConvTranspose2dFlipout', + 'QuantizedConvTranspose3dFlipout', +] + + +class QuantizedConv1dFlipout(Conv1dFlipout): + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=False): + """ + """ + super(QuantizedConv1dFlipout).__init__( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + groups, + bias=bias) + + # for conv bn fusion + self.bn_weight = None + self.bn_bias = None + self.bn_running_mean = None + self.bn_running_var = None + self.bn_eps = None + + self.is_dequant = False + + def get_scale_and_zero_point(self, x, upper_bound=100, target_range=255): + """ An implementation for symmetric quantization + + Parameters + ---------- + x: tensor + Input tensor. + upper_bound: int, optional + Restrict the maximum value of the original tensor (select 100 empirically). + target_range: int, optional + The range of target data type (255 for int8) + + Returns + ---------- + scale: float + + zero_point: int + + """ + scale = torch.zeros(1).to(x.device) # initialize + zero_point = torch.zeros(1).to(x.device) # zero point is zero since we only consider symmetric quantization + xmax = torch.clamp(x.abs().max(), 0, upper_bound) # determine and restrict the maximum value (minimum value should be 0 since the absolute value is always non-negative) + scale = xmax*2/target_range # original range divided by target range + return scale, zero_point + + def get_quantized_tensor(self, x, default_scale=0.1): + """ Quantize tensors + + Parameters + ---------- + x: tensors + Input tensor. + + default_scale: float, optional + Default scale for the case that the computed scale is zero. 
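As a concrete illustration of the symmetric scheme described above, here is a minimal standalone sketch (assuming an arbitrary float32 weight tensor `w` and the default upper_bound=100, target_range=255; the tensor itself is made up):

import torch

w = torch.randn(16, 3, 3)                      # stand-in for a weight tensor such as mu_kernel
xmax = torch.clamp(w.abs().max(), 0, 100)      # restrict the dynamic range (upper_bound = 100)
scale = float(xmax * 2 / 255)                  # symmetric range [-xmax, xmax] mapped onto 255 int8 steps
if scale == 0:
    scale = 0.1                                # default_scale fallback for an all-zero tensor
qw = torch.quantize_per_tensor(w, scale, 0, torch.qint8)   # zero_point = 0 for symmetric quantization
print(qw.q_scale(), qw.int_repr().dtype)       # recovered scale and the underlying int8 storage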
+ + + Returns + ---------- + quantized_x: tensors + + + """ + scale, zero_point = self.get_scale_and_zero_point(x) + if scale == 0: + scale = torch.tensor([default_scale]) # avoid zero scale + quantized_x = torch.quantize_per_tensor(x, scale, zero_point, torch.qint8) + + return quantized_x + + def get_dequantized_tensor(self, x): + dequantized_x = x.dequantize() + + return dequantized_x + + + def quantize(self): + if self.bn_weight is None: # has batchnorm layer, no bn fusion + self.quantized_mu_weight = Parameter(self.get_quantized_tensor(self.mu_kernel), requires_grad=False).cpu() + self.quantized_sigma_weight = Parameter(self.get_quantized_tensor(torch.log1p(torch.exp(self.rho_kernel))), requires_grad=False).cpu() + else: # fuse conv and bn + bn_coef = self.bn_weight/torch.sqrt(self.bn_running_var+self.bn_eps) + self.quantized_mu_weight = Parameter(self.get_quantized_tensor(self.mu_kernel*(bn_coef.view(-1,1,1,1).expand(self.mu_kernel.shape))), requires_grad=False).cpu() + self.quantized_sigma_weight = Parameter(self.get_quantized_tensor(torch.log1p(torch.exp(self.rho_kernel))*(bn_coef.view(-1,1,1,1).expand(self.rho_kernel.shape))), requires_grad=False).cpu() + delattr(self, "mu_kernel") + delattr(self, "rho_kernel") + + + ## DO NOT QUANTIZE BIAS!!!! Bias should be in fp32 format. + ## Variable names may be confusing. We don't quantize them. + ## TODO: rename variables + if self.bias: # if has bias + if self.bn_weight is None: # if no bn fusion + self.quantized_mu_bias = Parameter(self.mu_bias, requires_grad=False).cpu() + self.quantized_sigma_bias = Parameter(torch.log1p(torch.exp(self.rho_bias)), requires_grad=False).cpu() + else: # if apply bn fusion + bn_coef = self.bn_weight/torch.sqrt(self.bn_running_var+self.bn_eps) + self.quantized_mu_bias = Parameter((self.mu_bias-self.bn_running_mean)*bn_coef+self.bn_bias, requires_grad=False).cpu() + self.quantized_sigma_bias = Parameter(torch.log1p(torch.exp(self.rho_bias))*bn_coef, requires_grad=False).cpu() + delattr(self, "mu_bias") + delattr(self, "rho_bias") + else: + if self.bn_weight is not None: # if no bias but apply bn fusion + self.bias = True + bn_coef = self.bn_weight/torch.sqrt(self.bn_running_var+self.bn_eps)*(-self.bn_running_mean)+self.bn_bias + self.quantized_mu_bias = Parameter(bn_coef, requires_grad=False).cpu() + self.quantized_sigma_bias = None + + delattr(self, "bn_weight") + delattr(self, "bn_bias") + delattr(self, "bn_running_mean") + delattr(self, "bn_running_var") + delattr(self, "bn_eps") + + def forward(self, x, normal_scale=6/255, default_scale=0.1, default_zero_point=128, return_kl=True): + """ Forward pass + + Parameters + ---------- + x: tensors + Input tensor. + + normal_scale: float, optional + Scale for quantized tensor sampled from normal distribution. + since 99.7% values will lie within 3 standard deviations, the original range is set as 6. + + default_scale: float, optional + Default scale for quantized input tensor and quantized output tensor. + Set to 0.1 by grid search. + + default_zero_point: int, optional + Default zero point for quantized input tensor and quantized output tensor. + Set to 128 for quint8 tensor. + + + + Returns + ---------- + out: tensors + Output tensor. + KL: float + set to 0 since we diable KL divergence computation in quantized layers. 
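For orientation, the following is a float32 reference of the flipout estimator that the int8 ops in the body below approximate; function and variable names here are illustrative, not taken verbatim from the layer:

import torch
import torch.nn.functional as F

def flipout_conv1d_reference(x, mu, sigma, stride=1, padding=0, dilation=1, groups=1):
    # out = conv(x, mu) + sign_out * conv(x * sign_in, sigma * eps)
    eps = torch.randn_like(sigma)                           # weight noise
    sign_in = torch.empty_like(x).uniform_(-1, 1).sign()    # random +/-1 signs on the input
    base = F.conv1d(x, mu, None, stride, padding, dilation, groups)
    sign_out = torch.empty_like(base).uniform_(-1, 1).sign()
    perturbed = F.conv1d(x * sign_in, sigma * eps, None, stride, padding, dilation, groups)
    return base + sign_out * perturbed

x = torch.randn(2, 3, 32)
mu = torch.randn(8, 3, 5)
sigma = torch.rand(8, 3, 5) * 0.1
print(flipout_conv1d_reference(x, mu, sigma, padding=2).shape)   # torch.Size([2, 8, 32])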
+ + + """ + + if self.dnn_to_bnn_flag: + return_kl = False + + if x.dtype!=torch.quint8: + x = torch.quantize_per_tensor(x, default_scale, default_zero_point, torch.quint8) + + bias = None + if self.bias: + bias = self.quantized_mu_bias + + outputs = torch.nn.quantized.functional.conv1d(x, self.quantized_mu_weight, bias, self.stride, self.padding, + self.dilation, self.groups, scale=default_scale, zero_point=default_zero_point) # input: quint8, weight: qint8, bias: fp32 + + # sampling perturbation signs + sign_input = torch.zeros(x.shape).uniform_(-1, 1).sign() + sign_output = torch.zeros(outputs.shape).uniform_(-1, 1).sign() + sign_input = torch.quantize_per_tensor(sign_input, default_scale, default_zero_point, torch.quint8) + sign_output = torch.quantize_per_tensor(sign_output, default_scale, default_zero_point, torch.quint8) + + # getting perturbation weights + eps_kernel = torch.quantize_per_tensor(self.eps_kernel.data.normal_(), normal_scale, 0, torch.qint8) + new_scale = (self.quantized_sigma_weight.q_scale())*(eps_kernel.q_scale()) + delta_kernel = torch.ops.quantized.mul(self.quantized_sigma_weight, eps_kernel, new_scale, 0) + + bias = None + if self.bias: + eps_bias = self.eps_bias.data.normal_() + bias = (self.quantized_sigma_bias * eps_bias) + + # perturbed feedforward + x = torch.ops.quantized.mul(x, sign_input, default_scale, default_zero_point) + + perturbed_outputs = torch.nn.quantized.functional.conv1d(x, + weight=delta_kernel, bias=bias, stride=self.stride, padding=self.padding, + dilation=self.dilation, groups=self.groups, scale=default_scale, zero_point=default_zero_point) + perturbed_outputs = torch.ops.quantized.mul(perturbed_outputs, sign_output, default_scale, default_zero_point) + out = torch.ops.quantized.add(outputs, perturbed_outputs, default_scale, default_zero_point) + + if return_kl: + return out, 0 + + return out + + +class QuantizedConv2dFlipout(Conv2dFlipout): + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=False): # be aware of bias + """ + + """ + super(QuantizedConv2dFlipout, self).__init__( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + groups, + bias=bias) + + # for conv bn fusion + self.bn_weight = None + self.bn_bias = None + self.bn_running_mean = None + self.bn_running_var = None + self.bn_eps = None + + self.is_dequant = False + + def get_scale_and_zero_point(self, x, upper_bound=100, target_range=255): + """ An implementation for symmetric quantization + + Parameters + ---------- + x: tensor + Input tensor. + upper_bound: int, optional + Restrict the maximum value of the original tensor (select 100 empirically). + target_range: int, optional + The range of target data type (255 for int8) + + Returns + ---------- + scale: float + + zero_point: int + + """ + # + scale = torch.zeros(1).to(x.device) # initialize + zero_point = torch.zeros(1).to(x.device) # zero point is zero since we only consider symmetric quantization + xmax = torch.clamp(x.abs().max(), 0, upper_bound) # determine and restrict the maximum value (minimum value should be 0 since the absolute value is always non-negative) + scale = xmax*2/target_range # original range divided by target range + return scale, zero_point + + def get_quantized_tensor(self, x, default_scale=0.1): + """ Quantize tensors + + Parameters + ---------- + x: tensors + Input tensor. + + default_scale: float, optional + Default scale for the case that the computed scale is zero. 
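The quantize() method that follows folds a trailing BatchNorm into the kernel before quantizing it; below is a float32 sketch of that folding, assuming per-channel bn statistics have already been copied onto the layer (shapes are for a 2d kernel):

import torch

def fuse_conv_bn(mu_kernel, mu_bias, bn_w, bn_b, bn_mean, bn_var, bn_eps=1e-5):
    # per-channel coefficient gamma / sqrt(var + eps), broadcast over (out_channels, ...)
    coef = bn_w / torch.sqrt(bn_var + bn_eps)
    fused_kernel = mu_kernel * coef.view(-1, 1, 1, 1)
    fused_bias = (mu_bias - bn_mean) * coef + bn_b
    return fused_kernel, fused_bias

mu_kernel = torch.randn(8, 3, 3, 3)
mu_bias = torch.zeros(8)
bn_w, bn_b = torch.ones(8), torch.zeros(8)
bn_mean, bn_var = torch.zeros(8), torch.ones(8)
k, b = fuse_conv_bn(mu_kernel, mu_bias, bn_w, bn_b, bn_mean, bn_var)
print(k.shape, b.shape)   # torch.Size([8, 3, 3, 3]) torch.Size([8])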
+ + + Returns + ---------- + quantized_x: tensors + + + """ + scale, zero_point = self.get_scale_and_zero_point(x) + if scale == 0: + scale = torch.tensor([default_scale]) # avoid zero scale + quantized_x = torch.quantize_per_tensor(x, scale, zero_point, torch.qint8) + + return quantized_x + + def get_dequantized_tensor(self, x): + dequantized_x = x.dequantize() + + return dequantized_x + + + def quantize(self): + if self.bn_weight is None: # has batchnorm layer, no bn fusion + self.quantized_mu_weight = Parameter(self.get_quantized_tensor(self.mu_kernel), requires_grad=False).cpu() + self.quantized_sigma_weight = Parameter(self.get_quantized_tensor(torch.log1p(torch.exp(self.rho_kernel))), requires_grad=False).cpu() + else: # fuse conv and bn + bn_coef = self.bn_weight/torch.sqrt(self.bn_running_var+self.bn_eps) + self.quantized_mu_weight = Parameter(self.get_quantized_tensor(self.mu_kernel*(bn_coef.view(-1,1,1,1).expand(self.mu_kernel.shape))), requires_grad=False).cpu() + self.quantized_sigma_weight = Parameter(self.get_quantized_tensor(torch.log1p(torch.exp(self.rho_kernel))*(bn_coef.view(-1,1,1,1).expand(self.rho_kernel.shape))), requires_grad=False).cpu() + delattr(self, "mu_kernel") + delattr(self, "rho_kernel") + + + ## DO NOT QUANTIZE BIAS!!!! Bias should be in fp32 format. + ## Variable names may be confusing. We don't quantize them. + ## TODO: rename variables + if self.bias: # if has bias + if self.bn_weight is None: # if no bn fusion + self.quantized_mu_bias = Parameter(self.mu_bias, requires_grad=False).cpu() + self.quantized_sigma_bias = Parameter(torch.log1p(torch.exp(self.rho_bias)), requires_grad=False).cpu() + else: # if apply bn fusion + bn_coef = self.bn_weight/torch.sqrt(self.bn_running_var+self.bn_eps) + self.quantized_mu_bias = Parameter((self.mu_bias-self.bn_running_mean)*bn_coef+self.bn_bias, requires_grad=False).cpu() + self.quantized_sigma_bias = Parameter(torch.log1p(torch.exp(self.rho_bias))*bn_coef, requires_grad=False).cpu() + delattr(self, "mu_bias") + delattr(self, "rho_bias") + else: + if self.bn_weight is not None: # if no bias but apply bn fusion + self.bias = True + bn_coef = self.bn_weight/torch.sqrt(self.bn_running_var+self.bn_eps)*(-self.bn_running_mean)+self.bn_bias + self.quantized_mu_bias = Parameter(bn_coef, requires_grad=False).cpu() + self.quantized_sigma_bias = None + + delattr(self, "bn_weight") + delattr(self, "bn_bias") + delattr(self, "bn_running_mean") + delattr(self, "bn_running_var") + delattr(self, "bn_eps") + + def dequantize(self): + self.mu_kernel = self.get_dequantized_tensor(self.quantized_mu_weight) + self.sigma_weight = self.get_dequantized_tensor(self.quantized_sigma_weight) + + if self.bias: + self.mu_bias = self.get_dequantized_tensor(self.quantized_mu_bias) + self.sigma_bias = self.get_dequantized_tensor(self.quantized_sigma_bias) + + return + + def forward(self, x, normal_scale=6/255, default_scale=0.1, default_zero_point=128, return_kl=True): + """ Forward pass + + Parameters + ---------- + x: tensors + Input tensor. + + normal_scale: float, optional + Scale for quantized tensor sampled from normal distribution. + since 99.7% values will lie within 3 standard deviations, the original range is set as 6. + + default_scale: float, optional + Default scale for quantized input tensor and quantized output tensor. + Set to 0.1 by grid search. + + default_zero_point: int, optional + Default zero point for quantized input tensor and quantized output tensor. + Set to 128 for quint8 tensor. 
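The elementwise quantized ops used in the body below carry an explicit output scale and zero point; here is a small sketch of that scale bookkeeping, with made-up tensors and scales, assuming a CPU quantization backend that supports qint8 elementwise ops (the same assumption these layers make):

import torch

mu_q    = torch.quantize_per_tensor(torch.randn(8, 3, 3, 3), 0.05, 0, torch.qint8)
sigma_q = torch.quantize_per_tensor(torch.rand(8, 3, 3, 3) * 0.05, 0.001, 0, torch.qint8)
eps_q   = torch.quantize_per_tensor(torch.randn(8, 3, 3, 3), 6 / 255, 0, torch.qint8)

new_scale = sigma_q.q_scale() * eps_q.q_scale()          # product of the input scales
delta = torch.ops.quantized.mul(sigma_q, eps_q, new_scale, 0)
new_scale = max(new_scale, mu_q.q_scale())               # sum re-uses the coarser of the two scales
weight = torch.ops.quantized.add(delta, mu_q, new_scale, 0)
print(weight.q_scale(), weight.dtype)                    # chosen output scale, torch.qint8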
+ + + + Returns + ---------- + out: tensors + Output tensor. + KL: float + set to 0 since we diable KL divergence computation in quantized layers. + + + """ + + if self.dnn_to_bnn_flag: + return_kl = False + + if x.dtype!=torch.quint8: + x = torch.quantize_per_tensor(x, default_scale, default_zero_point, torch.quint8) + + bias = None + if self.bias: + bias = self.quantized_mu_bias + + outputs = torch.nn.quantized.functional.conv2d(x, self.quantized_mu_weight, bias, self.stride, self.padding, + self.dilation, self.groups, scale=default_scale, zero_point=default_zero_point) # input: quint8, weight: qint8, bias: fp32 + + # sampling perturbation signs + sign_input = torch.zeros(x.shape).uniform_(-1, 1).sign() + sign_output = torch.zeros(outputs.shape).uniform_(-1, 1).sign() + sign_input = torch.quantize_per_tensor(sign_input, default_scale, default_zero_point, torch.quint8) + sign_output = torch.quantize_per_tensor(sign_output, default_scale, default_zero_point, torch.quint8) + + # getting perturbation weights + eps_kernel = torch.quantize_per_tensor(self.eps_kernel.data.normal_(), normal_scale, 0, torch.qint8) + new_scale = (self.quantized_sigma_weight.q_scale())*(eps_kernel.q_scale()) + delta_kernel = torch.ops.quantized.mul(self.quantized_sigma_weight, eps_kernel, new_scale, 0) + + bias = None + if self.bias: + eps_bias = self.eps_bias.data.normal_() + bias = (self.quantized_sigma_bias * eps_bias) + + # perturbed feedforward + x = torch.ops.quantized.mul(x, sign_input, default_scale, default_zero_point) + + perturbed_outputs = torch.nn.quantized.functional.conv2d(x, + weight=delta_kernel, bias=bias, stride=self.stride, padding=self.padding, + dilation=self.dilation, groups=self.groups, scale=default_scale, zero_point=default_zero_point) + perturbed_outputs = torch.ops.quantized.mul(perturbed_outputs, sign_output, default_scale, default_zero_point) + out = torch.ops.quantized.add(outputs, perturbed_outputs, default_scale, default_zero_point) + + if return_kl: + return out, 0 + + return out + + +class QuantizedConv3dFlipout(Conv3dFlipout): + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=False): + """ + """ + super(QuantizedConv3dFlipout).__init__( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + groups, + bias=bias) + + + # for conv bn fusion + self.bn_weight = None + self.bn_bias = None + self.bn_running_mean = None + self.bn_running_var = None + self.bn_eps = None + + self.is_dequant = False + + def get_scale_and_zero_point(self, x, upper_bound=100, target_range=255): + """ An implementation for symmetric quantization + + Parameters + ---------- + x: tensor + Input tensor. + upper_bound: int, optional + Restrict the maximum value of the original tensor (select 100 empirically). + target_range: int, optional + The range of target data type (255 for int8) + + Returns + ---------- + scale: float + + zero_point: int + + """ + # + scale = torch.zeros(1).to(x.device) # initialize + zero_point = torch.zeros(1).to(x.device) # zero point is zero since we only consider symmetric quantization + xmax = torch.clamp(x.abs().max(), 0, upper_bound) # determine and restrict the maximum value (minimum value should be 0 since the absolute value is always non-negative) + scale = xmax*2/target_range # original range divided by target range + return scale, zero_point + + def get_quantized_tensor(self, x, default_scale=0.1): + """ Quantize tensors + + Parameters + ---------- + x: tensors + Input tensor. 
+ + default_scale: float, optional + Default scale for the case that the computed scale is zero. + + + Returns + ---------- + quantized_x: tensors + + + """ + scale, zero_point = self.get_scale_and_zero_point(x) + if scale == 0: + scale = torch.tensor([default_scale]) # avoid zero scale + quantized_x = torch.quantize_per_tensor(x, scale, zero_point, torch.qint8) + + return quantized_x + + def get_dequantized_tensor(self, x): + dequantized_x = x.dequantize() + + return dequantized_x + + + def quantize(self): + if self.bn_weight is None: # has batchnorm layer, no bn fusion + self.quantized_mu_weight = Parameter(self.get_quantized_tensor(self.mu_kernel), requires_grad=False).cpu() + self.quantized_sigma_weight = Parameter(self.get_quantized_tensor(torch.log1p(torch.exp(self.rho_kernel))), requires_grad=False).cpu() + else: # fuse conv and bn + bn_coef = self.bn_weight/torch.sqrt(self.bn_running_var+self.bn_eps) + self.quantized_mu_weight = Parameter(self.get_quantized_tensor(self.mu_kernel*(bn_coef.view(-1,1,1,1).expand(self.mu_kernel.shape))), requires_grad=False).cpu() + self.quantized_sigma_weight = Parameter(self.get_quantized_tensor(torch.log1p(torch.exp(self.rho_kernel))*(bn_coef.view(-1,1,1,1).expand(self.rho_kernel.shape))), requires_grad=False).cpu() + delattr(self, "mu_kernel") + delattr(self, "rho_kernel") + + + ## DO NOT QUANTIZE BIAS!!!! Bias should be in fp32 format. + ## Variable names may be confusing. We don't quantize them. + ## TODO: rename variables + if self.bias: # if has bias + if self.bn_weight is None: # if no bn fusion + self.quantized_mu_bias = Parameter(self.mu_bias, requires_grad=False).cpu() + self.quantized_sigma_bias = Parameter(torch.log1p(torch.exp(self.rho_bias)), requires_grad=False).cpu() + else: # if apply bn fusion + bn_coef = self.bn_weight/torch.sqrt(self.bn_running_var+self.bn_eps) + self.quantized_mu_bias = Parameter((self.mu_bias-self.bn_running_mean)*bn_coef+self.bn_bias, requires_grad=False).cpu() + self.quantized_sigma_bias = Parameter(torch.log1p(torch.exp(self.rho_bias))*bn_coef, requires_grad=False).cpu() + delattr(self, "mu_bias") + delattr(self, "rho_bias") + else: + if self.bn_weight is not None: # if no bias but apply bn fusion + self.bias = True + bn_coef = self.bn_weight/torch.sqrt(self.bn_running_var+self.bn_eps)*(-self.bn_running_mean)+self.bn_bias + self.quantized_mu_bias = Parameter(bn_coef, requires_grad=False).cpu() + self.quantized_sigma_bias = None + + delattr(self, "bn_weight") + delattr(self, "bn_bias") + delattr(self, "bn_running_mean") + delattr(self, "bn_running_var") + delattr(self, "bn_eps") + + def dequantize(self): + self.mu_kernel = self.get_dequantized_tensor(self.quantized_mu_weight) + self.sigma_weight = self.get_dequantized_tensor(self.quantized_sigma_weight) + + if self.bias: + self.mu_bias = self.get_dequantized_tensor(self.quantized_mu_bias) + self.sigma_bias = self.get_dequantized_tensor(self.quantized_sigma_bias) + + return + + def forward(self, x, normal_scale=6/255, default_scale=0.1, default_zero_point=128, return_kl=True): + """ Forward pass + + Parameters + ---------- + x: tensors + Input tensor. + + normal_scale: float, optional + Scale for quantized tensor sampled from normal distribution. + since 99.7% values will lie within 3 standard deviations, the original range is set as 6. + + default_scale: float, optional + Default scale for quantized input tensor and quantized output tensor. + Set to 0.1 by grid search. 
+ + default_zero_point: int, optional + Default zero point for quantized input tensor and quantized output tensor. + Set to 128 for quint8 tensor. + + + + Returns + ---------- + out: tensors + Output tensor. + KL: float + set to 0 since we diable KL divergence computation in quantized layers. + + + """ + + if self.dnn_to_bnn_flag: + return_kl = False + + if x.dtype!=torch.quint8: + x = torch.quantize_per_tensor(x, default_scale, default_zero_point, torch.quint8) + + bias = None + if self.bias: + bias = self.quantized_mu_bias + + outputs = torch.nn.quantized.functional.conv3d(x, self.quantized_mu_weight, bias, self.stride, self.padding, + self.dilation, self.groups, scale=default_scale, zero_point=default_zero_point) # input: quint8, weight: qint8, bias: fp32 + + # sampling perturbation signs + sign_input = torch.zeros(x.shape).uniform_(-1, 1).sign() + sign_output = torch.zeros(outputs.shape).uniform_(-1, 1).sign() + sign_input = torch.quantize_per_tensor(sign_input, default_scale, default_zero_point, torch.quint8) + sign_output = torch.quantize_per_tensor(sign_output, default_scale, default_zero_point, torch.quint8) + + # getting perturbation weights + eps_kernel = torch.quantize_per_tensor(self.eps_kernel.data.normal_(), normal_scale, 0, torch.qint8) + new_scale = (self.quantized_sigma_weight.q_scale())*(eps_kernel.q_scale()) + delta_kernel = torch.ops.quantized.mul(self.quantized_sigma_weight, eps_kernel, new_scale, 0) + + bias = None + if self.bias: + eps_bias = self.eps_bias.data.normal_() + bias = (self.quantized_sigma_bias * eps_bias) + + # perturbed feedforward + x = torch.ops.quantized.mul(x, sign_input, default_scale, default_zero_point) + + perturbed_outputs = torch.nn.quantized.functional.conv3d(x, + weight=delta_kernel, bias=bias, stride=self.stride, padding=self.padding, + dilation=self.dilation, groups=self.groups, scale=default_scale, zero_point=default_zero_point) + perturbed_outputs = torch.ops.quantized.mul(perturbed_outputs, sign_output, default_scale, default_zero_point) + out = torch.ops.quantized.add(outputs, perturbed_outputs, default_scale, default_zero_point) + + if return_kl: + return out, 0 + + return out + +class QuantizedConvTranspose1dFlipout(ConvTranspose1dFlipout): + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=False): + """ + """ + super(QuantizedConvTranspose1dFlipout).__init__( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + groups, + bias=bias) + + # for conv bn fusion + self.bn_weight = None + self.bn_bias = None + self.bn_running_mean = None + self.bn_running_var = None + self.bn_eps = None + + self.is_dequant = False + + if not hasattr(self, "output_padding"): + self.output_padding = 0 + + def get_scale_and_zero_point(self, x, upper_bound=100, target_range=255): + """ An implementation for symmetric quantization + + Parameters + ---------- + x: tensor + Input tensor. + upper_bound: int, optional + Restrict the maximum value of the original tensor (select 100 empirically). 
+ target_range: int, optional + The range of target data type (255 for int8) + + Returns + ---------- + scale: float + + zero_point: int + + """ + scale = torch.zeros(1).to(x.device) # initialize + zero_point = torch.zeros(1).to(x.device) # zero point is zero since we only consider symmetric quantization + xmax = torch.clamp(x.abs().max(), 0, upper_bound) # determine and restrict the maximum value (minimum value should be 0 since the absolute value is always non-negative) + scale = xmax*2/target_range # original range divided by target range + return scale, zero_point + + def get_quantized_tensor(self, x, default_scale=0.1): + """ Quantize tensors + + Parameters + ---------- + x: tensors + Input tensor. + + default_scale: float, optional + Default scale for the case that the computed scale is zero. + + + Returns + ---------- + quantized_x: tensors + + + """ + scale, zero_point = self.get_scale_and_zero_point(x) + if scale == 0: + scale = torch.tensor([default_scale]) # avoid zero scale + quantized_x = torch.quantize_per_tensor(x, scale, zero_point, torch.qint8) + + return quantized_x + + def get_dequantized_tensor(self, x): + dequantized_x = x.dequantize() + + return dequantized_x + + + def quantize(self): + if self.bn_weight is None: # has batchnorm layer, no bn fusion + self.quantized_mu_weight = Parameter(self.get_quantized_tensor(self.mu_kernel), requires_grad=False).cpu() + self.quantized_sigma_weight = Parameter(self.get_quantized_tensor(torch.log1p(torch.exp(self.rho_kernel))), requires_grad=False).cpu() + else: # fuse conv and bn + bn_coef = self.bn_weight/torch.sqrt(self.bn_running_var+self.bn_eps) + self.quantized_mu_weight = Parameter(self.get_quantized_tensor(self.mu_kernel*(bn_coef.view(-1,1,1,1).expand(self.mu_kernel.shape))), requires_grad=False).cpu() + self.quantized_sigma_weight = Parameter(self.get_quantized_tensor(torch.log1p(torch.exp(self.rho_kernel))*(bn_coef.view(-1,1,1,1).expand(self.rho_kernel.shape))), requires_grad=False).cpu() + delattr(self, "mu_kernel") + delattr(self, "rho_kernel") + + + ## DO NOT QUANTIZE BIAS!!!! Bias should be in fp32 format. + ## Variable names may be confusing. We don't quantize them. + ## TODO: rename variables + if self.bias: # if has bias + if self.bn_weight is None: # if no bn fusion + self.quantized_mu_bias = Parameter(self.mu_bias, requires_grad=False).cpu() + self.quantized_sigma_bias = Parameter(torch.log1p(torch.exp(self.rho_bias)), requires_grad=False).cpu() + else: # if apply bn fusion + bn_coef = self.bn_weight/torch.sqrt(self.bn_running_var+self.bn_eps) + self.quantized_mu_bias = Parameter((self.mu_bias-self.bn_running_mean)*bn_coef+self.bn_bias, requires_grad=False).cpu() + self.quantized_sigma_bias = Parameter(torch.log1p(torch.exp(self.rho_bias))*bn_coef, requires_grad=False).cpu() + delattr(self, "mu_bias") + delattr(self, "rho_bias") + else: + if self.bn_weight is not None: # if no bias but apply bn fusion + self.bias = True + bn_coef = self.bn_weight/torch.sqrt(self.bn_running_var+self.bn_eps)*(-self.bn_running_mean)+self.bn_bias + self.quantized_mu_bias = Parameter(bn_coef, requires_grad=False).cpu() + self.quantized_sigma_bias = None + + delattr(self, "bn_weight") + delattr(self, "bn_bias") + delattr(self, "bn_running_mean") + delattr(self, "bn_running_var") + delattr(self, "bn_eps") + + def forward(self, x, normal_scale=6/255, default_scale=0.1, default_zero_point=128, return_kl=True): + """ Forward pass + + Parameters + ---------- + x: tensors + Input tensor. 
+ + normal_scale: float, optional + Scale for quantized tensor sampled from normal distribution. + since 99.7% values will lie within 3 standard deviations, the original range is set as 6. + + default_scale: float, optional + Default scale for quantized input tensor and quantized output tensor. + Set to 0.1 by grid search. + + default_zero_point: int, optional + Default zero point for quantized input tensor and quantized output tensor. + Set to 128 for quint8 tensor. + + + + Returns + ---------- + out: tensors + Output tensor. + KL: float + set to 0 since we diable KL divergence computation in quantized layers. + + + """ + + if self.dnn_to_bnn_flag: + return_kl = False + + if x.dtype!=torch.quint8: + x = torch.quantize_per_tensor(x, default_scale, default_zero_point, torch.quint8) + + bias = None + if self.bias: + bias = self.quantized_mu_bias + + self._packed_params = torch.ops.quantized.conv_transpose1d_prepack(self.quantized_mu_weight, bias, self.stride, + self.padding, self.output_padding, + self.dilation, self.groups) + + outputs = torch.ops.quantized.conv_transpose1d(x, self._packed_params, scale=default_scale, zero_point=default_zero_point) + + # sampling perturbation signs + sign_input = torch.zeros(x.shape).uniform_(-1, 1).sign() + sign_output = torch.zeros(outputs.shape).uniform_(-1, 1).sign() + sign_input = torch.quantize_per_tensor(sign_input, default_scale, default_zero_point, torch.quint8) + sign_output = torch.quantize_per_tensor(sign_output, default_scale, default_zero_point, torch.quint8) + + # getting perturbation weights + eps_kernel = torch.quantize_per_tensor(self.eps_kernel.data.normal_(), normal_scale, 0, torch.qint8) + new_scale = (self.quantized_sigma_weight.q_scale())*(eps_kernel.q_scale()) + delta_kernel = torch.ops.quantized.mul(self.quantized_sigma_weight, eps_kernel, new_scale, 0) + + bias = None + if self.bias: + eps_bias = self.eps_bias.data.normal_() + bias = (self.quantized_sigma_bias * eps_bias) + + # perturbed feedforward + x = torch.ops.quantized.mul(x, sign_input, default_scale, default_zero_point) + + self._packed_params = torch.ops.quantized.conv_transpose1d_prepack(delta_kernel, bias, self.stride, + self.padding, self.output_padding, + self.dilation, self.groups) + perturbed_outputs = torch.ops.quantized.conv_transpose1d(x, self._packed_params, scale=default_scale, zero_point=default_zero_point) + + perturbed_outputs = torch.ops.quantized.mul(perturbed_outputs, sign_output, default_scale, default_zero_point) + out = torch.ops.quantized.add(outputs, perturbed_outputs, default_scale, default_zero_point) + + if return_kl: + return out, 0 + + return out + +class QuantizedConvTranspose2dFlipout(ConvTranspose2dFlipout): + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=False): + """ + """ + super(QuantizedConvTranspose2dFlipout).__init__( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + groups, + bias=bias) + + # for conv bn fusion + self.bn_weight = None + self.bn_bias = None + self.bn_running_mean = None + self.bn_running_var = None + self.bn_eps = None + + self.is_dequant = False + + if not hasattr(self, "output_padding"): + self.output_padding = 0 + + def get_scale_and_zero_point(self, x, upper_bound=100, target_range=255): + """ An implementation for symmetric quantization + + Parameters + ---------- + x: tensor + Input tensor. + upper_bound: int, optional + Restrict the maximum value of the original tensor (select 100 empirically). 
+ target_range: int, optional + The range of target data type (255 for int8) + + Returns + ---------- + scale: float + + zero_point: int + + """ + scale = torch.zeros(1).to(x.device) # initialize + zero_point = torch.zeros(1).to(x.device) # zero point is zero since we only consider symmetric quantization + xmax = torch.clamp(x.abs().max(), 0, upper_bound) # determine and restrict the maximum value (minimum value should be 0 since the absolute value is always non-negative) + scale = xmax*2/target_range # original range divided by target range + return scale, zero_point + + def get_quantized_tensor(self, x, default_scale=0.1): + """ Quantize tensors + + Parameters + ---------- + x: tensors + Input tensor. + + default_scale: float, optional + Default scale for the case that the computed scale is zero. + + + Returns + ---------- + quantized_x: tensors + + + """ + scale, zero_point = self.get_scale_and_zero_point(x) + if scale == 0: + scale = torch.tensor([default_scale]) # avoid zero scale + quantized_x = torch.quantize_per_tensor(x, scale, zero_point, torch.qint8) + + return quantized_x + + def get_dequantized_tensor(self, x): + dequantized_x = x.dequantize() + + return dequantized_x + + + def quantize(self): + if self.bn_weight is None: # has batchnorm layer, no bn fusion + self.quantized_mu_weight = Parameter(self.get_quantized_tensor(self.mu_kernel), requires_grad=False).cpu() + self.quantized_sigma_weight = Parameter(self.get_quantized_tensor(torch.log1p(torch.exp(self.rho_kernel))), requires_grad=False).cpu() + else: # fuse conv and bn + bn_coef = self.bn_weight/torch.sqrt(self.bn_running_var+self.bn_eps) + self.quantized_mu_weight = Parameter(self.get_quantized_tensor(self.mu_kernel*(bn_coef.view(-1,1,1,1).expand(self.mu_kernel.shape))), requires_grad=False).cpu() + self.quantized_sigma_weight = Parameter(self.get_quantized_tensor(torch.log1p(torch.exp(self.rho_kernel))*(bn_coef.view(-1,1,1,1).expand(self.rho_kernel.shape))), requires_grad=False).cpu() + delattr(self, "mu_kernel") + delattr(self, "rho_kernel") + + + ## DO NOT QUANTIZE BIAS!!!! Bias should be in fp32 format. + ## Variable names may be confusing. We don't quantize them. + ## TODO: rename variables + if self.bias: # if has bias + if self.bn_weight is None: # if no bn fusion + self.quantized_mu_bias = Parameter(self.mu_bias, requires_grad=False).cpu() + self.quantized_sigma_bias = Parameter(torch.log1p(torch.exp(self.rho_bias)), requires_grad=False).cpu() + else: # if apply bn fusion + bn_coef = self.bn_weight/torch.sqrt(self.bn_running_var+self.bn_eps) + self.quantized_mu_bias = Parameter((self.mu_bias-self.bn_running_mean)*bn_coef+self.bn_bias, requires_grad=False).cpu() + self.quantized_sigma_bias = Parameter(torch.log1p(torch.exp(self.rho_bias))*bn_coef, requires_grad=False).cpu() + delattr(self, "mu_bias") + delattr(self, "rho_bias") + else: + if self.bn_weight is not None: # if no bias but apply bn fusion + self.bias = True + bn_coef = self.bn_weight/torch.sqrt(self.bn_running_var+self.bn_eps)*(-self.bn_running_mean)+self.bn_bias + self.quantized_mu_bias = Parameter(bn_coef, requires_grad=False).cpu() + self.quantized_sigma_bias = None + + delattr(self, "bn_weight") + delattr(self, "bn_bias") + delattr(self, "bn_running_mean") + delattr(self, "bn_running_var") + delattr(self, "bn_eps") + + def forward(self, x, normal_scale=6/255, default_scale=0.1, default_zero_point=128, return_kl=True): + """ Forward pass + + Parameters + ---------- + x: tensors + Input tensor. 
+ + normal_scale: float, optional + Scale for quantized tensor sampled from normal distribution. + since 99.7% values will lie within 3 standard deviations, the original range is set as 6. + + default_scale: float, optional + Default scale for quantized input tensor and quantized output tensor. + Set to 0.1 by grid search. + + default_zero_point: int, optional + Default zero point for quantized input tensor and quantized output tensor. + Set to 128 for quint8 tensor. + + + + Returns + ---------- + out: tensors + Output tensor. + KL: float + set to 0 since we diable KL divergence computation in quantized layers. + + + """ + + if self.dnn_to_bnn_flag: + return_kl = False + + if x.dtype!=torch.quint8: + x = torch.quantize_per_tensor(x, default_scale, default_zero_point, torch.quint8) + + bias = None + if self.bias: + bias = self.quantized_mu_bias + + self._packed_params = torch.ops.quantized.conv_transpose2d_prepack(self.quantized_mu_weight, bias, self.stride, + self.padding, self.output_padding, + self.dilation, self.groups) + + outputs = torch.ops.quantized.conv_transpose2d(x, self._packed_params, scale=default_scale, zero_point=default_zero_point) + + # sampling perturbation signs + sign_input = torch.zeros(x.shape).uniform_(-1, 1).sign() + sign_output = torch.zeros(outputs.shape).uniform_(-1, 1).sign() + sign_input = torch.quantize_per_tensor(sign_input, default_scale, default_zero_point, torch.quint8) + sign_output = torch.quantize_per_tensor(sign_output, default_scale, default_zero_point, torch.quint8) + + # getting perturbation weights + eps_kernel = torch.quantize_per_tensor(self.eps_kernel.data.normal_(), normal_scale, 0, torch.qint8) + new_scale = (self.quantized_sigma_weight.q_scale())*(eps_kernel.q_scale()) + delta_kernel = torch.ops.quantized.mul(self.quantized_sigma_weight, eps_kernel, new_scale, 0) + + bias = None + if self.bias: + eps_bias = self.eps_bias.data.normal_() + bias = (self.quantized_sigma_bias * eps_bias) + + # perturbed feedforward + x = torch.ops.quantized.mul(x, sign_input, default_scale, default_zero_point) + + self._packed_params = torch.ops.quantized.conv_transpose2d_prepack(delta_kernel, bias, self.stride, + self.padding, self.output_padding, + self.dilation, self.groups) + perturbed_outputs = torch.ops.quantized.conv_transpose2d(x, self._packed_params, scale=default_scale, zero_point=default_zero_point) + + perturbed_outputs = torch.ops.quantized.mul(perturbed_outputs, sign_output, default_scale, default_zero_point) + out = torch.ops.quantized.add(outputs, perturbed_outputs, default_scale, default_zero_point) + + if return_kl: + return out, 0 + + return out + +class QuantizedConvTranspose3dFlipout(ConvTranspose3dFlipout): + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=False): + """ + """ + super(QuantizedConvTranspose3dFlipout).__init__( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + groups, + bias=bias) + + # for conv bn fusion + self.bn_weight = None + self.bn_bias = None + self.bn_running_mean = None + self.bn_running_var = None + self.bn_eps = None + + self.is_dequant = False + + if not hasattr(self, "output_padding"): + self.output_padding = 0 + + def get_scale_and_zero_point(self, x, upper_bound=100, target_range=255): + """ An implementation for symmetric quantization + + Parameters + ---------- + x: tensor + Input tensor. + upper_bound: int, optional + Restrict the maximum value of the original tensor (select 100 empirically). 
+ target_range: int, optional + The range of target data type (255 for int8) + + Returns + ---------- + scale: float + + zero_point: int + + """ + scale = torch.zeros(1).to(x.device) # initialize + zero_point = torch.zeros(1).to(x.device) # zero point is zero since we only consider symmetric quantization + xmax = torch.clamp(x.abs().max(), 0, upper_bound) # determine and restrict the maximum value (minimum value should be 0 since the absolute value is always non-negative) + scale = xmax*2/target_range # original range divided by target range + return scale, zero_point + + def get_quantized_tensor(self, x, default_scale=0.1): + """ Quantize tensors + + Parameters + ---------- + x: tensors + Input tensor. + + default_scale: float, optional + Default scale for the case that the computed scale is zero. + + + Returns + ---------- + quantized_x: tensors + + + """ + scale, zero_point = self.get_scale_and_zero_point(x) + if scale == 0: + scale = torch.tensor([default_scale]) # avoid zero scale + quantized_x = torch.quantize_per_tensor(x, scale, zero_point, torch.qint8) + + return quantized_x + + def get_dequantized_tensor(self, x): + dequantized_x = x.dequantize() + + return dequantized_x + + + def quantize(self): + if self.bn_weight is None: # has batchnorm layer, no bn fusion + self.quantized_mu_weight = Parameter(self.get_quantized_tensor(self.mu_kernel), requires_grad=False).cpu() + self.quantized_sigma_weight = Parameter(self.get_quantized_tensor(torch.log1p(torch.exp(self.rho_kernel))), requires_grad=False).cpu() + else: # fuse conv and bn + bn_coef = self.bn_weight/torch.sqrt(self.bn_running_var+self.bn_eps) + self.quantized_mu_weight = Parameter(self.get_quantized_tensor(self.mu_kernel*(bn_coef.view(-1,1,1,1).expand(self.mu_kernel.shape))), requires_grad=False).cpu() + self.quantized_sigma_weight = Parameter(self.get_quantized_tensor(torch.log1p(torch.exp(self.rho_kernel))*(bn_coef.view(-1,1,1,1).expand(self.rho_kernel.shape))), requires_grad=False).cpu() + delattr(self, "mu_kernel") + delattr(self, "rho_kernel") + + + ## DO NOT QUANTIZE BIAS!!!! Bias should be in fp32 format. + ## Variable names may be confusing. We don't quantize them. + ## TODO: rename variables + if self.bias: # if has bias + if self.bn_weight is None: # if no bn fusion + self.quantized_mu_bias = Parameter(self.mu_bias, requires_grad=False).cpu() + self.quantized_sigma_bias = Parameter(torch.log1p(torch.exp(self.rho_bias)), requires_grad=False).cpu() + else: # if apply bn fusion + bn_coef = self.bn_weight/torch.sqrt(self.bn_running_var+self.bn_eps) + self.quantized_mu_bias = Parameter((self.mu_bias-self.bn_running_mean)*bn_coef+self.bn_bias, requires_grad=False).cpu() + self.quantized_sigma_bias = Parameter(torch.log1p(torch.exp(self.rho_bias))*bn_coef, requires_grad=False).cpu() + delattr(self, "mu_bias") + delattr(self, "rho_bias") + else: + if self.bn_weight is not None: # if no bias but apply bn fusion + self.bias = True + bn_coef = self.bn_weight/torch.sqrt(self.bn_running_var+self.bn_eps)*(-self.bn_running_mean)+self.bn_bias + self.quantized_mu_bias = Parameter(bn_coef, requires_grad=False).cpu() + self.quantized_sigma_bias = None + + delattr(self, "bn_weight") + delattr(self, "bn_bias") + delattr(self, "bn_running_mean") + delattr(self, "bn_running_var") + delattr(self, "bn_eps") + + def forward(self, x, normal_scale=6/255, default_scale=0.1, default_zero_point=128, return_kl=True): + """ Forward pass + + Parameters + ---------- + x: tensors + Input tensor. 
+ + normal_scale: float, optional + Scale for quantized tensor sampled from normal distribution. + since 99.7% values will lie within 3 standard deviations, the original range is set as 6. + + default_scale: float, optional + Default scale for quantized input tensor and quantized output tensor. + Set to 0.1 by grid search. + + default_zero_point: int, optional + Default zero point for quantized input tensor and quantized output tensor. + Set to 128 for quint8 tensor. + + + + Returns + ---------- + out: tensors + Output tensor. + KL: float + set to 0 since we diable KL divergence computation in quantized layers. + + + """ + + if self.dnn_to_bnn_flag: + return_kl = False + + if x.dtype!=torch.quint8: + x = torch.quantize_per_tensor(x, default_scale, default_zero_point, torch.quint8) + + bias = None + if self.bias: + bias = self.quantized_mu_bias + + self._packed_params = torch.ops.quantized.conv_transpose3d_prepack(self.quantized_mu_weight, bias, self.stride, + self.padding, self.output_padding, + self.dilation, self.groups) + + outputs = torch.ops.quantized.conv_transpose3d(x, self._packed_params, scale=default_scale, zero_point=default_zero_point) + + # sampling perturbation signs + sign_input = torch.zeros(x.shape).uniform_(-1, 1).sign() + sign_output = torch.zeros(outputs.shape).uniform_(-1, 1).sign() + sign_input = torch.quantize_per_tensor(sign_input, default_scale, default_zero_point, torch.quint8) + sign_output = torch.quantize_per_tensor(sign_output, default_scale, default_zero_point, torch.quint8) + + # getting perturbation weights + eps_kernel = torch.quantize_per_tensor(self.eps_kernel.data.normal_(), normal_scale, 0, torch.qint8) + new_scale = (self.quantized_sigma_weight.q_scale())*(eps_kernel.q_scale()) + delta_kernel = torch.ops.quantized.mul(self.quantized_sigma_weight, eps_kernel, new_scale, 0) + + bias = None + if self.bias: + eps_bias = self.eps_bias.data.normal_() + bias = (self.quantized_sigma_bias * eps_bias) + + # perturbed feedforward + x = torch.ops.quantized.mul(x, sign_input, default_scale, default_zero_point) + + self._packed_params = torch.ops.quantized.conv_transpose3d_prepack(delta_kernel, bias, self.stride, + self.padding, self.output_padding, + self.dilation, self.groups) + perturbed_outputs = torch.ops.quantized.conv_transpose3d(x, self._packed_params, scale=default_scale, zero_point=default_zero_point) + + perturbed_outputs = torch.ops.quantized.mul(perturbed_outputs, sign_output, default_scale, default_zero_point) + out = torch.ops.quantized.add(outputs, perturbed_outputs, default_scale, default_zero_point) + + if return_kl: + return out, 0 + + return out \ No newline at end of file diff --git a/bayesian_torch/ao/nn/quantized/modules/quantized_linear_flipout.py b/bayesian_torch/ao/nn/quantized/modules/quantized_linear_flipout.py new file mode 100644 index 0000000..289da98 --- /dev/null +++ b/bayesian_torch/ao/nn/quantized/modules/quantized_linear_flipout.py @@ -0,0 +1,206 @@ +# Copyright (C) 2021 Intel Labs +# +# BSD-3-Clause License +# +# Redistribution and use in source and binary forms, with or without modification, +# are permitted provided that the following conditions are met: +# 1. Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# 3. 
Neither the name of the copyright holder nor the names of its contributors +# may be used to endorse or promote products derived from this software +# without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, +# THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS +# BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, +# OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT +# OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; +# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE +# OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, +# EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +# +# Linear Flipout Layers with flipout weight estimator to perform +# variational inference in Bayesian neural networks. Variational layers +# enables Monte Carlo approximation of the distribution over the weights +# +# @authors: Jun-Liang Lin +# +# ====================================================================================== +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import Module, Parameter +from torch.distributions.normal import Normal +from torch.distributions.uniform import Uniform + +from .linear_flipout import LinearFlipout + +__all__ = ["QuantizedLinearFlipout"] + +class QuantizedLinearFlipout(LinearFlipout): + def __init__(self, + in_features, + out_features): + + super(QuantizedLinearFlipout, self).__init__( + in_features, + out_features) + + self.is_dequant = False + + def get_scale_and_zero_point(self, x, upper_bound=100, target_range=255): + """ An implementation for symmetric quantization + + Parameters + ---------- + x: tensor + Input tensor. + upper_bound: int, optional + Restrict the maximum value of the original tensor (select 100 empirically). + target_range: int, optional + The range of target data type (255 for int8) + + Returns + ---------- + scale: float + + zero_point: int + + """ + scale = torch.zeros(1).to(x.device) # initialize + zero_point = torch.zeros(1).to(x.device) # zero point is zero since we only consider symmetric quantization + xmax = torch.clamp(x.abs().max(), 0, upper_bound) # determine and restrict the maximum value (minimum value should be 0 since the absolute value is always non-negative) + scale = xmax*2/target_range # original range divided by target range + return scale, zero_point + + def get_quantized_tensor(self, x, default_scale=0.1): + """ Quantize tensors + + Parameters + ---------- + x: tensors + Input tensor. + + default_scale: float, optional + Default scale for the case that the computed scale is zero. 
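quantize() below stores sigma = log1p(exp(rho)) (softplus) rather than rho itself; the following is a brief sketch of that transform and of the int8 round-trip error it introduces, assuming a hypothetical trained rho tensor:

import torch
import torch.nn.functional as F

rho = torch.randn(128, 64) - 3.0                  # typical posterior_rho_init is around -3
sigma = torch.log1p(torch.exp(rho))               # softplus keeps sigma strictly positive
scale = float(sigma.abs().max() * 2 / 255)        # symmetric scale as in get_scale_and_zero_point
q_sigma = torch.quantize_per_tensor(sigma, scale, 0, torch.qint8)
roundtrip = q_sigma.dequantize()
print(F.l1_loss(roundtrip, sigma).item())         # mean absolute int8 quantization error on sigma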
+ + + Returns + ---------- + quantized_x: tensors + + + """ + scale, zero_point = self.get_scale_and_zero_point(x) + if scale == 0: + scale = torch.tensor([default_scale]) # avoid zero scale + quantized_x = torch.quantize_per_tensor(x, scale, zero_point, torch.qint8) + + return quantized_x + + def get_dequantized_tensor(self, x): + + dequantized_x = x.dequantize() + + return dequantized_x + + + def quantize(self): + self.quantized_mu_weight = Parameter(self.get_quantized_tensor(self.mu_weight), requires_grad=False) + self.quantized_sigma_weight = Parameter(self.get_quantized_tensor(torch.log1p(torch.exp(self.rho_weight))), requires_grad=False) + delattr(self, "mu_weight") + delattr(self, "rho_weight") + + self.quantized_mu_bias = Parameter(self.get_quantized_tensor(self.mu_bias), requires_grad=False) + self.quantized_sigma_bias = Parameter(self.get_quantized_tensor(torch.log1p(torch.exp(self.rho_bias))), requires_grad=False) + delattr(self, "mu_bias") + delattr(self, "rho_bias") + + def dequantize(self): + self.mu_weight = self.get_dequantized_tensor(self.quantized_mu_weight) + self.sigma_weight = self.get_dequantized_tensor(self.quantized_sigma_weight) + + self.mu_bias = self.get_dequantized_tensor(self.quantized_mu_bias) + self.sigma_bias = self.get_dequantized_tensor(self.quantized_sigma_bias) + return + + def forward(self, x, normal_scale=6/255, default_scale=0.1, default_zero_point=128, return_kl=True): + """ Forward pass + + Parameters + ---------- + x: tensors + Input tensor. + + normal_scale: float, optional + Scale for quantized tensor sampled from normal distribution. + since 99.7% values will lie within 3 standard deviations, the original range is set as 6. + + default_scale: float, optional + Default scale for quantized input tensor and quantized output tensor. + Set to 0.1 by grid search. + + default_zero_point: int, optional + Default zero point for quantized input tensor and quantized output tensor. + Set to 128 for quint8 tensor. + + + + Returns + ---------- + out: tensors + Output tensor. Already dequantized. + KL: float + set to 0 since we diable KL divergence computation in quantized layers. 
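Most of the work in the forward pass below is a single torch.nn.quantized.functional.linear call with a quint8 activation and a qint8 weight; here is a minimal sketch of that call, reusing the same default scale and zero point as this file and assuming a quantized engine such as fbgemm or onednn is available (all tensors are made up):

import torch

x = torch.quantize_per_tensor(torch.randn(4, 64), 0.1, 128, torch.quint8)   # activation: quint8
w = torch.quantize_per_tensor(torch.randn(32, 64), 0.02, 0, torch.qint8)    # weight: qint8
b = torch.zeros(32)                                                          # bias stays fp32

out = torch.nn.quantized.functional.linear(x, w, b, scale=0.1, zero_point=128)
print(out.shape, out.dtype)        # torch.Size([4, 32]) torch.quint8
print(out.dequantize()[:1])        # dequantize for downstream fp32 consumers, as forward() does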
+ + + """ + + if self.dnn_to_bnn_flag: + return_kl = False + + bias = None + if self.quantized_mu_bias is not None: + if not self.is_dequant: + self.dequantize() + self.is_dequant = True + bias = self.mu_bias + + outputs = torch.nn.quantized.functional.linear(x, self.quantized_mu_weight, bias, scale=default_scale, zero_point=default_zero_point) # input: quint8, weight: qint8, bias: fp32 + + # sampling perturbation signs + sign_input = torch.zeros(x.shape).uniform_(-1, 1).sign() + sign_output = torch.zeros(outputs.shape).uniform_(-1, 1).sign() + sign_input = torch.quantize_per_tensor(sign_input, default_scale, default_zero_point, torch.quint8) + sign_output = torch.quantize_per_tensor(sign_output, default_scale, default_zero_point, torch.quint8) + + # getting perturbation weights + eps_weight = torch.quantize_per_tensor(self.eps_weight.data.normal_(), normal_scale, 0, torch.qint8) + new_scale = (self.quantized_sigma_weight.q_scale())*(eps_weight.q_scale()) + delta_weight = torch.ops.quantized.mul(self.quantized_sigma_weight, eps_weight, new_scale, 0) + + bias = None + if self.quantized_sigma_bias is not None: + eps_bias = self.eps_bias.data.normal_() + bias = (self.sigma_bias * eps_bias) + + # perturbed feedforward + x = torch.ops.quantized.mul(x, sign_input, default_scale, default_zero_point) + + perturbed_outputs = torch.nn.quantized.functional.linear(x, + weight=delta_weight, bias=bias, scale=default_scale, zero_point=default_zero_point) + perturbed_outputs = torch.ops.quantized.mul(perturbed_outputs, sign_output, default_scale, default_zero_point) + out = torch.ops.quantized.add(outputs, perturbed_outputs, default_scale, default_zero_point) + out = out.dequantize() + + if return_kl: + return out, 0 + + return out diff --git a/bayesian_torch/ao/quantization/__init__.py b/bayesian_torch/ao/quantization/__init__.py new file mode 100644 index 0000000..dab2378 --- /dev/null +++ b/bayesian_torch/ao/quantization/__init__.py @@ -0,0 +1,3 @@ +## bayesian_torch.quantization.prepare +## bayesian_torch.quantization.convert +from .quantize import * \ No newline at end of file diff --git a/bayesian_torch/ao/quantization/quantize.py b/bayesian_torch/ao/quantization/quantize.py new file mode 100644 index 0000000..06fa99f --- /dev/null +++ b/bayesian_torch/ao/quantization/quantize.py @@ -0,0 +1,163 @@ +# Copyright (C) 2021 Intel Labs +# +# BSD-3-Clause License +# +# Redistribution and use in source and binary forms, with or without modification, +# are permitted provided that the following conditions are met: +# 1. Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# 3. Neither the name of the copyright holder nor the names of its contributors +# may be used to endorse or promote products derived from this software +# without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, +# THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS +# BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, +# OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT +# OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; +# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE +# OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, +# EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +# Define prepare and convert function +# + +import torch +import torch.nn as nn +from bayesian_torch.models.bayesian.resnet_variational_large import ( + BasicBlock, + Bottleneck, + ResNet, +) +from typing import Any, List, Optional, Type, Union +from torch import Tensor +from bayesian_torch.models.bnn_to_qbnn import bnn_to_qbnn +# import copy + +__all__ = [ + "prepare", + "convert", +] + +class QuantizableBasicBlock(BasicBlock): + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self.add_relu = torch.nn.quantized.FloatFunctional() + + def forward(self, x: Tensor) -> Tensor: + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out = self.add_relu.add_relu(out, identity) + + return out + + +class QuantizableBottleneck(Bottleneck): + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self.skip_add_relu = nn.quantized.FloatFunctional() + self.relu1 = nn.ReLU(inplace=False) + self.relu2 = nn.ReLU(inplace=False) + + def forward(self, x: Tensor) -> Tensor: + identity = x + out = self.conv1(x) + out = self.bn1(out) + out = self.relu1(out) + out = self.conv2(out) + out = self.bn2(out) + out = self.relu2(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + identity = self.downsample(x) + out = self.skip_add_relu.add_relu(out, identity) + + return out + + +class QuantizableResNet(ResNet): + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + + self.quant = torch.ao.quantization.QuantStub() + self.dequant = torch.ao.quantization.DeQuantStub() + + def forward(self, x: Tensor) -> Tensor: + x = self.quant(x) + + x= self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + for layer in self.layer1: + x=layer(x) + + for layer in self.layer2: + x = layer(x) + + for layer in self.layer3: + x = layer(x) + + for layer in self.layer4: + x = layer(x) + + x = self.avgpool(x) + x = torch.flatten(x, 1) + x = self.fc(x) + + + # x = self.dequant(x) + return x + + + +def enable_prepare(m): + for name, value in list(m._modules.items()): + if m._modules[name]._modules: + enable_prepare(m._modules[name]) + elif "Reparameterization" in m._modules[name].__class__.__name__ or "Flipout" in m._modules[name].__class__.__name__: + prepare = getattr(m._modules[name], "prepare", None) + if callable(prepare): + m._modules[name].prepare() + m._modules[name].dnn_to_bnn_flag=True + + +def prepare(model): + """ + 1. construct quantizable model + 2. traverse the model to enable the prepare function in each layer + 3. 
run torch.quantize.prepare() + """ + qmodel = QuantizableResNet(QuantizableBottleneck, [3, 4, 6, 3]) + qmodel.load_state_dict(model.state_dict()) + qmodel.eval() + enable_prepare(qmodel) + qmodel.qconfig = torch.quantization.get_default_qconfig("fbgemm") + qmodel = torch.quantization.prepare(qmodel) + + return qmodel + +def convert(model): + qmodel = torch.quantization.convert(model) # torch layers + bnn_to_qbnn(qmodel) # bayesian layers + return qmodel \ No newline at end of file diff --git a/bayesian_torch/examples/main_bayesian_imagenet_bnn2qbnn.py b/bayesian_torch/examples/main_bayesian_imagenet_bnn2qbnn.py new file mode 100644 index 0000000..73dea9b --- /dev/null +++ b/bayesian_torch/examples/main_bayesian_imagenet_bnn2qbnn.py @@ -0,0 +1,319 @@ +import argparse +import os +import shutil +import time + +import torch +import torch.nn as nn +import torch.nn.parallel +import torch.backends.cudnn as cudnn +import torch.optim +import torch.utils.data +from torch.utils.tensorboard import SummaryWriter +import torchvision +import torchvision.transforms as transforms +import torchvision.datasets as datasets + +import bayesian_torch +import bayesian_torch.models.bayesian.resnet_variational_large as resnet +import numpy as np +from bayesian_torch.models.bnn_to_qbnn import bnn_to_qbnn +from bayesian_torch.models.dnn_to_bnn import dnn_to_bnn +# import bayesian_torch.models.bayesian.quantized_resnet_variational_large as qresnet +import bayesian_torch.models.bayesian.quantized_resnet_flipout_large as qresnet + +torch.cuda.is_available = lambda : False +os.environ["CUDA_VISIBLE_DEVICES"] = "-1" +torch.backends.quantized.engine='onednn' +model_names = sorted( + name + for name in resnet.__dict__ + if name.islower() and not name.startswith("__") and name.startswith("resnet") and callable(resnet.__dict__[name]) +) + +print(model_names) +best_acc1 = 0 +len_trainset = 1281167 +len_valset = 50000 + + +parser = argparse.ArgumentParser(description="ImageNet") +parser.add_argument('data', + metavar='DIR', + default='data/imagenet', + help='path to dataset') +parser.add_argument( + "--arch", + "-a", + metavar="ARCH", + default="resnet50", + choices=model_names, + help="model architecture: " + " | ".join(model_names) + " (default: resnet50)", +) +parser.add_argument( + "-j", "--workers", default=8, type=int, metavar="N", help="number of data loading workers (default: 8)" +) +parser.add_argument("--epochs", default=200, type=int, metavar="N", help="number of total epochs to run") +parser.add_argument("--start-epoch", default=0, type=int, metavar="N", help="manual epoch number (useful on restarts)") +parser.add_argument("-b", "--batch-size", default=1000, type=int, metavar="N", help="mini-batch size (default: 512)") +parser.add_argument('--val_batch_size', default=1000, type=int) +parser.add_argument("--lr", "--learning-rate", default=0.001, type=float, metavar="LR", help="initial learning rate") +parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum") +parser.add_argument( + "--weight-decay", "--wd", default=1e-4, type=float, metavar="W", help="weight decay (default: 5e-4)" +) +parser.add_argument("--print-freq", "-p", default=50, type=int, metavar="N", help="print frequency (default: 20)") +parser.add_argument("--resume", default="", type=str, metavar="PATH", help="path to latest checkpoint (default: none)") +parser.add_argument("-e", "--evaluate", dest="evaluate", action="store_true", help="evaluate model on validation set") +parser.add_argument("--pretrained", 
dest="pretrained", action="store_true", help="use pre-trained model") +parser.add_argument("--half", dest="half", action="store_true", help="use half-precision(16-bit) ") +parser.add_argument( + "--save-dir", + dest="save_dir", + help="The directory used to save the trained models", + default="../../bayesian-torch-20221214/bayesian_torch/checkpoint/bayesian", + type=str, +) +parser.add_argument( + "--moped-init-model", + dest="moped_init_model", + help="DNN model to intialize MOPED method", + default="", + type=str, +) +parser.add_argument( + "--moped-delta-factor", + dest="moped_delta_factor", + help="MOPED delta scale factor", + default=0.2, + type=float, +) + +parser.add_argument( + "--bnn-rho-init", + dest="bnn_rho_init", + help="rho init for bnn layers", + default=-3.0, + type=float, +) + +parser.add_argument( + "--use-flipout-layers", + type=bool, + default=False, + metavar="use_flipout_layers", + help="Use Flipout layers for BNNs, default is Reparameterization layers", +) + +parser.add_argument( + "--save-every", + dest="save_every", + help="Saves checkpoints at every specified number of epochs", + type=int, + default=10, +) +parser.add_argument("--mode", type=str, required=True, help="train | test") + +parser.add_argument( + "--num_monte_carlo", + type=int, + default=20, + metavar="N", + help="number of Monte Carlo samples to be drawn during inference", +) +parser.add_argument("--num_mc", type=int, default=1, metavar="N", help="number of Monte Carlo runs during training") +parser.add_argument( + "--tensorboard", + type=bool, + default=True, + metavar="N", + help="use tensorboard for logging and visualization of training progress", +) +parser.add_argument( + "--log_dir", + type=str, + default="./logs/cifar/bayesian", + metavar="N", + help="use tensorboard for logging and visualization of training progress", +) + +def evaluate(args, model, val_loader, calibration=False): + pred_probs_mc = [] + test_loss = 0 + correct = 0 + output_list = [] + labels_list = [] + model.eval() + with torch.no_grad(): + begin = time.time() + i=0 + for data, target in val_loader: + if torch.cuda.is_available(): + data, target = data.cuda(), target.cuda() + else: + data, target = data.cpu(), target.cpu() + output_mc = [] + for mc_run in range(args.num_monte_carlo): + output = model.forward(data) + output_mc.append(output) + output_ = torch.stack(output_mc) + output_list.append(output_) + labels_list.append(target) + i+=1 + end = time.time() + print("inference throughput: ", i*args.val_batch_size / (end - begin), " images/s") + # break + if calibration and i==3: + break + + output = torch.cat(output_list, 1) + output = torch.nn.functional.softmax(output, dim=2) + labels = torch.cat(labels_list) + pred_mean = output.mean(dim=0) + Y_pred = torch.argmax(pred_mean, axis=1) + print("Test accuracy:", (Y_pred.data.cpu().numpy() == labels.data.cpu().numpy()).mean() * 100) + + +def save_checkpoint(state, is_best, filename="checkpoint.pth.tar"): + """ + Save the training model + """ + torch.save(state, filename) + + +class AverageMeter(object): + """Computes and stores the average and current value""" + + def __init__(self): + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + +def accuracy(output, target, topk=(1,)): + """Computes the precision@k for the specified values of k""" + maxk = max(topk) + batch_size = target.size(0) + + _, pred = 
output.topk(maxk, 1, True, True) + pred = pred.t() + correct = pred.eq(target.view(1, -1).expand_as(pred)) + + res = [] + for k in topk: + correct_k = correct[:k].view(-1).float().sum(0) + res.append(correct_k.mul_(100.0 / batch_size)) + return res + +best_prec1 = 0 + +def main(): + global args, best_prec1 + args = parser.parse_args() + moped_enable = False + if len(args.moped_init_model) > 0: # use moped method if trained dnn model weights are provided + moped_enable = True + + # Check the save_dir exists or not + if not os.path.exists(args.save_dir): + os.makedirs(args.save_dir) + + model = torch.nn.DataParallel(resnet.__dict__[args.arch]()) + if moped_enable: + checkpoint = torch.load(args.moped_init_model) + if "state_dict" in checkpoint.keys(): + model.load_state_dict(checkpoint["state_dict"]) + else: + model.load_state_dict(checkpoint) + + tb_writer = None + + valdir = os.path.join(args.data, 'val') + normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]) + val_dataset = datasets.ImageFolder( + valdir, + transforms.Compose([ + transforms.Resize(256), + transforms.CenterCrop(224), + transforms.ToTensor(), + normalize, + ])) + + val_loader = torch.utils.data.DataLoader(val_dataset, + batch_size=args.val_batch_size, + shuffle=False, + num_workers=args.workers, + pin_memory=True) + + print('len valset: ', len(val_dataset)) + + if not os.path.exists(args.save_dir): + os.makedirs(args.save_dir) + + if args.mode == "test": + const_bnn_prior_parameters = { + "prior_mu": 0.0, + "prior_sigma": 1.0, + "posterior_mu_init": 0.0, + "posterior_rho_init": args.bnn_rho_init, + "type": "Flipout" if args.use_flipout_layers else "Reparameterization", # Flipout or Reparameterization + "moped_enable": moped_enable, # initialize mu/sigma from the dnn weights + "moped_delta": args.moped_delta_factor, + } + quantizable_model = torchvision.models.quantization.resnet50() + dnn_to_bnn(quantizable_model, const_bnn_prior_parameters) + model = torch.nn.DataParallel(quantizable_model) + + + checkpoint_file = args.save_dir + "/bayesian_{}_imagenet.pth".format(args.arch) + + checkpoint = torch.load(checkpoint_file, map_location=torch.device("cpu")) + model.load_state_dict(checkpoint["state_dict"]) + model.module = model.module.cpu() + + mp = bayesian_torch.quantization.prepare(model) + evaluate(args, mp, val_loader, calibration=True) # calibration + qmodel = bayesian_torch.quantization.convert(mp) + evaluate(args, qmodel, val_loader) + + # save weights + save_checkpoint( + { + 'epoch': None, + 'state_dict': qmodel.state_dict(), + 'best_prec1': None, + }, + True, + filename=os.path.join( + args.save_dir, + 'quantized_bayesian_{}_imagenetv2.pth'.format(args.arch))) + + # reconstruct (no calibration) + quantizable_model = torchvision.models.quantization.resnet50() + dnn_to_bnn(quantizable_model, const_bnn_prior_parameters) + model = torch.nn.DataParallel(quantizable_model) + mp = bayesian_torch.quantization.prepare(model) + qmodel1 = bayesian_torch.quantization.convert(mp) + + # load + checkpoint_file = args.save_dir + "/quantized_bayesian_{}_imagenetv2.pth".format(args.arch) + checkpoint = torch.load(checkpoint_file, map_location=torch.device("cpu")) + qmodel1.load_state_dict(checkpoint["state_dict"]) + evaluate(args, qmodel1, val_loader) + + + return mp, qmodel, qmodel1 + +if __name__ == "__main__": + mp, qmodel, qmodel1 = main() diff --git a/bayesian_torch/examples/main_bayesian_imagenet_dnn2bnn.py b/bayesian_torch/examples/main_bayesian_imagenet_dnn2bnn.py new file mode 100644 
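The `evaluate` routine in the example above stacks `num_monte_carlo` stochastic forward passes, applies softmax, and averages before taking the argmax. A minimal, self-contained sketch of that predictive averaging (the helper name `mc_predict` is illustrative, not part of the repository, and it assumes the model's forward returns logits only):

```python
import torch
import torch.nn.functional as F

def mc_predict(model, data, num_monte_carlo=20):
    """Approximate the predictive distribution of a Bayesian model.

    Every forward pass of a Reparameterization/Flipout model resamples the
    weights, so averaging softmax outputs over several passes marginalizes
    over the approximate weight posterior.
    """
    model.eval()
    with torch.no_grad():
        probs = torch.stack([F.softmax(model(data), dim=1)
                             for _ in range(num_monte_carlo)])  # [mc, batch, classes]
    pred_mean = probs.mean(dim=0)          # predictive mean over MC samples
    return pred_mean.argmax(dim=1), pred_mean
```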
index 0000000..28e03ae --- /dev/null +++ b/bayesian_torch/examples/main_bayesian_imagenet_dnn2bnn.py @@ -0,0 +1,551 @@ +import argparse +import os +import shutil +import time + +import torch +import torch.nn as nn +import torch.nn.parallel +import torch.backends.cudnn as cudnn +import torch.optim +import torch.utils.data +from torch.utils.tensorboard import SummaryWriter +import torchvision.transforms as transforms +import torchvision.datasets as datasets + +import bayesian_torch.models.deterministic.resnet_large as resnet +import numpy as np +from bayesian_torch.models.dnn_to_bnn import dnn_to_bnn, get_kl_loss + +model_names = sorted( + name + for name in resnet.__dict__ + if name.islower() and not name.startswith("__") and name.startswith("resnet") and callable(resnet.__dict__[name]) +) + +print(model_names) +best_acc1 = 0 +len_trainset = 1281167 +len_valset = 50000 + + +parser = argparse.ArgumentParser(description="ImageNet") +parser.add_argument('data', + metavar='DIR', + default='data/imagenet', + help='path to dataset') +parser.add_argument( + "--arch", + "-a", + metavar="ARCH", + default="resnet50", + choices=model_names, + help="model architecture: " + " | ".join(model_names) + " (default: resnet50)", +) +parser.add_argument( + "-j", "--workers", default=8, type=int, metavar="N", help="number of data loading workers (default: 8)" +) +parser.add_argument("--epochs", default=90, type=int, metavar="N", help="number of total epochs to run") +parser.add_argument("--start-epoch", default=0, type=int, metavar="N", help="manual epoch number (useful on restarts)") +parser.add_argument("-b", "--batch-size", default=128, type=int, metavar="N", help="mini-batch size (default: 128)") +parser.add_argument('--val_batch_size', default=1000, type=int) +parser.add_argument("--lr", "--learning-rate", default=0.001, type=float, metavar="LR", help="initial learning rate") +parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum") +parser.add_argument( + "--weight-decay", "--wd", default=1e-4, type=float, metavar="W", help="weight decay (default: 5e-4)" +) +parser.add_argument("--print-freq", "-p", default=50, type=int, metavar="N", help="print frequency (default: 20)") +parser.add_argument("--resume", default="", type=str, metavar="PATH", help="path to latest checkpoint (default: none)") +parser.add_argument("-e", "--evaluate", dest="evaluate", action="store_true", help="evaluate model on validation set") +parser.add_argument("--pretrained", dest="pretrained", action="store_true", help="use pre-trained model") +parser.add_argument("--half", dest="half", action="store_true", help="use half-precision(16-bit) ") +parser.add_argument( + "--save-dir", + dest="save_dir", + help="The directory used to save the trained models", + default="./checkpoint/bayesian", + type=str, +) +parser.add_argument( + "--moped-init-model", + dest="moped_init_model", + help="DNN model to intialize MOPED method", + default="", + type=str, +) +parser.add_argument( + "--moped-delta-factor", + dest="moped_delta_factor", + help="MOPED delta scale factor", + default=0.001, + type=float, +) + +parser.add_argument( + "--bnn-rho-init", + dest="bnn_rho_init", + help="rho init for bnn layers", + default=-10.0, + type=float, +) + +parser.add_argument( + "--use-flipout-layers", + type=bool, + default=False, + metavar="use_flipout_layers", + help="Use Flipout layers for BNNs, default is Reparameterization layers", +) + +parser.add_argument( + "--save-every", + dest="save_every", + help="Saves checkpoints at 
every specified number of epochs", + type=int, + default=10, +) +parser.add_argument("--mode", type=str, required=True, help="train | test") + +parser.add_argument( + "--num_monte_carlo", + type=int, + default=20, + metavar="N", + help="number of Monte Carlo samples to be drawn during inference", +) +parser.add_argument("--num_mc", type=int, default=1, metavar="N", help="number of Monte Carlo runs during training") +parser.add_argument( + "--tensorboard", + type=bool, + default=True, + metavar="N", + help="use tensorboard for logging and visualization of training progress", +) +parser.add_argument( + "--log_dir", + type=str, + default="./logs/imagenet/bayesian", + metavar="N", + help="use tensorboard for logging and visualization of training progress", +) + +best_prec1 = 0 + + +def main(): + global args, best_prec1 + args = parser.parse_args() + moped_enable = False + if len(args.moped_init_model) > 0: # use moped method if trained dnn model weights are provided + moped_enable = True + + const_bnn_prior_parameters = { + "prior_mu": 0.0, + "prior_sigma": 1.0, + "posterior_mu_init": 0.0, + "posterior_rho_init": args.bnn_rho_init, + "type": "Flipout" if args.use_flipout_layers else "Reparameterization", # Flipout or Reparameterization + "moped_enable": moped_enable, # initialize mu/sigma from the dnn weights + "moped_delta": args.moped_delta_factor, + } + + # Check the save_dir exists or not + if not os.path.exists(args.save_dir): + os.makedirs(args.save_dir) + + model = torch.nn.DataParallel(resnet.__dict__[args.arch](pretrained=True)) + model.cuda() if torch.cuda.is_available() else model.cpu() + if moped_enable: + checkpoint = torch.load(args.moped_init_model) + if "state_dict" in checkpoint.keys(): + model.load_state_dict(checkpoint["state_dict"]) + else: + model.load_state_dict(checkpoint) + + const_bnn_prior_parameters["moped_enable"]=True + dnn_to_bnn(model, const_bnn_prior_parameters) # only replaces linear and conv layers + + save_checkpoint( + { + "epoch": 0, + "state_dict": model.state_dict(), + "best_prec1": best_prec1, + }, + False, + filename=os.path.join(args.save_dir, "bayesian_{}_imagenet.pth".format(args.arch)), + ) + + if torch.cuda.is_available(): + model.cuda() + else: + model.cpu() + + # optionally resume from a checkpoint + if args.resume: + if os.path.isfile(args.resume): + print("=> loading checkpoint '{}'".format(args.resume)) + checkpoint = torch.load(args.resume) + args.start_epoch = checkpoint["epoch"] + best_prec1 = checkpoint["best_prec1"] + model.load_state_dict(checkpoint) + print("=> loaded checkpoint '{}' (epoch {})".format(args.evaluate, checkpoint["epoch"])) + else: + print("=> no checkpoint found at '{}'".format(args.resume)) + + cudnn.benchmark = True + + tb_writer = None + if args.tensorboard: + logger_dir = os.path.join(args.log_dir, "tb_logger") + if not os.path.exists(logger_dir): + os.makedirs(logger_dir) + tb_writer = SummaryWriter(logger_dir) + + valdir = os.path.join(args.data, 'val') #Imagenet_2012Val + normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]) + + # train_loader = torch.utils.data.DataLoader( + # datasets.CIFAR10( + # root="./data", + # train=True, + # transform=transforms.Compose( + # [ + # transforms.RandomHorizontalFlip(), + # transforms.RandomCrop(32, 4), + # transforms.ToTensor(), + # normalize, + # ] + # ), + # download=True, + # ), + # batch_size=args.batch_size, + # shuffle=True, + # num_workers=args.workers, + # pin_memory=True, + # ) + + val_dataset = datasets.ImageFolder( + valdir, + 
transforms.Compose([ + transforms.Resize(256), + transforms.CenterCrop(224), + transforms.ToTensor(), + normalize, + ])) + val_loader = torch.utils.data.DataLoader(val_dataset, + batch_size=args.val_batch_size, + shuffle=False, + num_workers=args.workers, + pin_memory=True) + + print('len valset: ', len(val_dataset)) + + if not os.path.exists(args.save_dir): + os.makedirs(args.save_dir) + + if torch.cuda.is_available(): + criterion = nn.CrossEntropyLoss().cuda() + else: + criterion = nn.CrossEntropyLoss().cpu() + + if args.half: + model.half() + criterion.half() + + if args.arch in ["resnet110"]: + for param_group in optimizer.param_groups: + param_group["lr"] = args.lr * 0.1 + + if args.evaluate: + validate(val_loader, model, criterion) + return + + if args.mode == "train": + pass + + for epoch in range(args.start_epoch, args.epochs): + + lr = args.lr + if epoch >= 80 and epoch < 120: + lr = 0.1 * args.lr + elif epoch >= 120 and epoch < 160: + lr = 0.01 * args.lr + elif epoch >= 160 and epoch < 180: + lr = 0.001 * args.lr + elif epoch >= 180: + lr = 0.0005 * args.lr + + optimizer = torch.optim.Adam(model.parameters(), lr) + + # train for one epoch + print("current lr {:.5e}".format(optimizer.param_groups[0]["lr"])) + train(args, train_loader, model, criterion, optimizer, epoch, tb_writer) + + prec1 = validate(args, val_loader, model, criterion, epoch, tb_writer) + + is_best = prec1 > best_prec1 + best_prec1 = max(prec1, best_prec1) + + if is_best: + save_checkpoint( + { + "epoch": epoch + 1, + "state_dict": model.state_dict(), + "best_prec1": best_prec1, + }, + is_best, + filename=os.path.join(args.save_dir, "bayesian_{}_imagenet.pth".format(args.arch)), + ) + + elif args.mode == "test": + checkpoint_file = args.save_dir + "/bayesian_{}_imagenet.pth".format(args.arch) + if torch.cuda.is_available(): + checkpoint = torch.load(checkpoint_file) + else: + checkpoint = torch.load(checkpoint_file, map_location=torch.device("cpu")) + model.load_state_dict(checkpoint["state_dict"]) + evaluate(args, model, val_loader) + + +def train(args, train_loader, model, criterion, optimizer, epoch, tb_writer=None): + batch_time = AverageMeter() + data_time = AverageMeter() + losses = AverageMeter() + top1 = AverageMeter() + + # switch to train mode + model.train() + + end = time.time() + for i, (input, target) in enumerate(train_loader): + + # measure data loading time + data_time.update(time.time() - end) + + if torch.cuda.is_available(): + target = target.cuda() + input_var = input.cuda() + target_var = target + else: + target = target.cpu() + input_var = input.cpu() + target_var = target + + if args.half: + input_var = input_var.half() + + # compute output + output_ = [] + kl_ = [] + for mc_run in range(args.num_mc): + output = model(input_var) + kl = get_kl_loss(model) + output_.append(output) + kl_.append(kl) + output = torch.mean(torch.stack(output_), dim=0) + kl = torch.mean(torch.stack(kl_), dim=0) + cross_entropy_loss = criterion(output, target_var) + scaled_kl = kl / args.batch_size + + # ELBO loss + loss = cross_entropy_loss + scaled_kl + + # compute gradient and do SGD step + optimizer.zero_grad() + loss.backward() + optimizer.step() + + output = output.float() + loss = loss.float() + # measure accuracy and record loss + prec1 = accuracy(output.data, target)[0] + losses.update(loss.item(), input.size(0)) + top1.update(prec1.item(), input.size(0)) + + # measure elapsed time + batch_time.update(time.time() - end) + end = time.time() + + if i % args.print_freq == 0: + print( + "Epoch: 
[{0}][{1}/{2}]\t" + "Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t" + "Data {data_time.val:.3f} ({data_time.avg:.3f})\t" + "Loss {loss.val:.4f} ({loss.avg:.4f})\t" + "Prec@1 {top1.val:.3f} ({top1.avg:.3f})".format( + epoch, i, len(train_loader), batch_time=batch_time, data_time=data_time, loss=losses, top1=top1 + ) + ) + + if tb_writer is not None: + tb_writer.add_scalar("train/cross_entropy_loss", cross_entropy_loss.item(), epoch) + tb_writer.add_scalar("train/kl_div", scaled_kl.item(), epoch) + tb_writer.add_scalar("train/elbo_loss", loss.item(), epoch) + tb_writer.add_scalar("train/accuracy", prec1.item(), epoch) + tb_writer.flush() + + +def validate(args, val_loader, model, criterion, epoch, tb_writer=None): + batch_time = AverageMeter() + losses = AverageMeter() + top1 = AverageMeter() + + # switch to evaluate mode + model.eval() + + end = time.time() + with torch.no_grad(): + for i, (input, target) in enumerate(val_loader): + if torch.cuda.is_available(): + target = target.cuda() + input_var = input.cuda() + target_var = target.cuda() + else: + target = target.cpu() + input_var = input.cpu() + target_var = target.cpu() + + if args.half: + input_var = input_var.half() + + # compute output + output_ = [] + kl_ = [] + for mc_run in range(args.num_mc): + output = model(input_var) + kl = get_kl_loss(model) + output_.append(output) + kl_.append(kl) + output = torch.mean(torch.stack(output_), dim=0) + kl = torch.mean(torch.stack(kl_), dim=0) + cross_entropy_loss = criterion(output, target_var) + # scaled_kl = kl / len_trainset + scaled_kl = kl / args.batch_size + # scaled_kl = 0.2 * (kl / len_trainset) + + # ELBO loss + loss = cross_entropy_loss + scaled_kl + + output = output.float() + loss = loss.float() + + # measure accuracy and record loss + prec1 = accuracy(output.data, target)[0] + losses.update(loss.item(), input.size(0)) + top1.update(prec1.item(), input.size(0)) + + # measure elapsed time + batch_time.update(time.time() - end) + end = time.time() + + if i % args.print_freq == 0: + print( + "Test: [{0}/{1}]\t" + "Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t" + "Loss {loss.val:.4f} ({loss.avg:.4f})\t" + "Prec@1 {top1.val:.3f} ({top1.avg:.3f})".format( + i, len(val_loader), batch_time=batch_time, loss=losses, top1=top1 + ) + ) + + if tb_writer is not None: + tb_writer.add_scalar("val/cross_entropy_loss", cross_entropy_loss.item(), epoch) + tb_writer.add_scalar("val/kl_div", scaled_kl.item(), epoch) + tb_writer.add_scalar("val/elbo_loss", loss.item(), epoch) + tb_writer.add_scalar("val/accuracy", prec1.item(), epoch) + tb_writer.flush() + + print(" * Prec@1 {top1.avg:.3f}".format(top1=top1)) + + return top1.avg + + +def evaluate(args, model, val_loader): + pred_probs_mc = [] + test_loss = 0 + correct = 0 + output_list = [] + labels_list = [] + model.eval() + with torch.no_grad(): + begin = time.time() + i=0 + for data, target in val_loader: + if torch.cuda.is_available(): + data, target = data.cuda(), target.cuda() + else: + data, target = data.cpu(), target.cpu() + output_mc = [] + for mc_run in range(args.num_monte_carlo): + output = model.forward(data) + output_mc.append(output) + output_ = torch.stack(output_mc) + output_list.append(output_) + labels_list.append(target) + i+=1 + # if i==10: + # break + end = time.time() + print("inference throughput: ", 50000 / (end - begin), " images/s") + + # output = torch.stack(output_list) + # output = output.permute(1, 0, 2, 3) + # output = output.contiguous().view(args.num_monte_carlo, len_valset, -1) + output = 
torch.cat(output_list, 1) + output = torch.nn.functional.softmax(output, dim=2) + labels = torch.cat(labels_list) + pred_mean = output.mean(dim=0) + Y_pred = torch.argmax(pred_mean, axis=1) + + np.save("./probs_cifar_mc.npy", output.data.cpu().numpy()) + np.save("./cifar_test_labels_mc.npy", labels.data.cpu().numpy()) + print(Y_pred.shape, labels.shape) + print(Y_pred[:100], labels[:100]) + print("Test accuracy:", (Y_pred.data.cpu().numpy() == labels.data.cpu().numpy()).mean() * 100) + + +def save_checkpoint(state, is_best, filename="checkpoint.pth.tar"): + """ + Save the training model + """ + torch.save(state, filename) + + +class AverageMeter(object): + """Computes and stores the average and current value""" + + def __init__(self): + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + +def accuracy(output, target, topk=(1,)): + """Computes the precision@k for the specified values of k""" + maxk = max(topk) + batch_size = target.size(0) + + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + correct = pred.eq(target.view(1, -1).expand_as(pred)) + + res = [] + for k in topk: + correct_k = correct[:k].view(-1).float().sum(0) + res.append(correct_k.mul_(100.0 / batch_size)) + return res + + +if __name__ == "__main__": + main() diff --git a/bayesian_torch/examples/quantization_test.py b/bayesian_torch/examples/quantization_test.py new file mode 100644 index 0000000..bc18c25 --- /dev/null +++ b/bayesian_torch/examples/quantization_test.py @@ -0,0 +1,34 @@ +# import torch +# import bayesian_torch +# from bayesian_torch.ao.quantization import prepare, convert +# import bayesian_torch.models.bayesian.resnet_variational_large as resnet +# from bayesian_torch.models.bnn_to_qbnn import bnn_to_qbnn + +# model = resnet.__dict__['resnet50']() + +# input = torch.randn(1,3,224,224) +# mp = prepare(model) +# mp(input) # haven't replaced the batchnorm layer +# qmodel = torch.quantization.convert(mp) +# bnn_to_qbnn(qmodel) + + +import torch +import bayesian_torch +import bayesian_torch.models.bayesian.resnet_variational_large as resnet + +m = resnet.__dict__['resnet50']() +# alternative way to construct a bnn model +# from bayesian_torch.models.dnn_to_bnn import dnn_to_bnn +# m = torchvision.models.resnet50(weights="IMAGENET1K_V1") +# dnn_to_bnn(m) + + + +mp = bayesian_torch.quantization.prepare(m) +input = torch.randn(1,3,224,224) +mp(input) # calibration +mq = bayesian_torch.quantization.convert(mp) + + + diff --git a/bayesian_torch/layers/batchnorm.py b/bayesian_torch/layers/batchnorm.py index 145997c..25ab8f3 100644 --- a/bayesian_torch/layers/batchnorm.py +++ b/bayesian_torch/layers/batchnorm.py @@ -54,7 +54,6 @@ def _check_input_dim(self, input): input.dim())) def forward(self, input): - self._check_input_dim(input[0]) exponential_average_factor = 0.0 if self.training and self.track_running_stats: self.num_batches_tracked += 1 @@ -63,13 +62,21 @@ def forward(self, input): else: # use exponential moving average exponential_average_factor = self.momentum - out = F.batch_norm(input[0], self.running_mean, self.running_var, - self.weight, self.bias, self.training - or not self.track_running_stats, - exponential_average_factor, self.eps) - kl = 0 - return out, kl - + if len(input) == 2: + self._check_input_dim(input[0]) + out = F.batch_norm(input[0], self.running_mean, self.running_var, + self.weight, self.bias, self.training 
+ or not self.track_running_stats, + exponential_average_factor, self.eps) + kl = 0 + return out, kl + else: + out = F.batch_norm(input, self.running_mean, self.running_var, + self.weight, self.bias, self.training + or not self.track_running_stats, + exponential_average_factor, self.eps) + return out + class BatchNorm1dLayer(nn.Module): def __init__(self, diff --git a/bayesian_torch/layers/flipout_layers/__init__.py b/bayesian_torch/layers/flipout_layers/__init__.py index 3aeb698..b1b18c4 100644 --- a/bayesian_torch/layers/flipout_layers/__init__.py +++ b/bayesian_torch/layers/flipout_layers/__init__.py @@ -1,3 +1,6 @@ from .conv_flipout import * from .linear_flipout import * from .rnn_flipout import * +from .quantized_linear_flipout import * +from .quantized_conv_flipout import * +# from .quantize_rnn_flipout import * \ No newline at end of file diff --git a/bayesian_torch/layers/flipout_layers/conv_flipout.py b/bayesian_torch/layers/flipout_layers/conv_flipout.py index c92d24b..2ad0679 100644 --- a/bayesian_torch/layers/flipout_layers/conv_flipout.py +++ b/bayesian_torch/layers/flipout_layers/conv_flipout.py @@ -37,6 +37,9 @@ import torch.nn as nn import torch.nn.functional as F from ..base_variational_layer import BaseVariationalLayer_, get_kernel_size +from torch.quantization.observer import HistogramObserver, PerChannelMinMaxObserver, MinMaxObserver +from torch.quantization.qconfig import QConfig + from torch.distributions.normal import Normal from torch.distributions.uniform import Uniform @@ -136,6 +139,15 @@ def __init__(self, self.register_buffer('prior_bias_sigma', None, persistent=False) self.init_parameters() + self.quant_prepare=False + + def prepare(self): + self.qint_quant = nn.ModuleList([torch.quantization.QuantStub( + QConfig(weight=MinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric), activation=MinMaxObserver.with_args(dtype=torch.qint8,qscheme=torch.per_tensor_symmetric))) for _ in range(4)]) + self.quint_quant = nn.ModuleList([torch.quantization.QuantStub( + QConfig(weight=MinMaxObserver.with_args(dtype=torch.quint8), activation=MinMaxObserver.with_args(dtype=torch.quint8))) for _ in range(8)]) + self.dequant = torch.quantization.DeQuantStub() + self.quant_prepare=True def init_parameters(self): # prior values @@ -303,6 +315,15 @@ def __init__(self, self.register_buffer('prior_bias_sigma', None, persistent=False) self.init_parameters() + self.quant_prepare=False + + def prepare(self): + self.qint_quant = nn.ModuleList([torch.quantization.QuantStub( + QConfig(weight=MinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric), activation=MinMaxObserver.with_args(dtype=torch.qint8,qscheme=torch.per_tensor_symmetric))) for _ in range(4)]) + self.quint_quant = nn.ModuleList([torch.quantization.QuantStub( + QConfig(weight=MinMaxObserver.with_args(dtype=torch.quint8), activation=MinMaxObserver.with_args(dtype=torch.quint8))) for _ in range(8)]) + self.dequant = torch.quantization.DeQuantStub() + self.quant_prepare=True def init_parameters(self): # prior values @@ -365,18 +386,39 @@ def forward(self, x, return_kl=True): self.prior_bias_sigma) # perturbed feedforward - perturbed_outputs = F.conv2d(x * sign_input, + x_tmp = x * sign_input + perturbed_outputs_tmp = F.conv2d(x * sign_input, weight=delta_kernel, bias=bias, stride=self.stride, padding=self.padding, dilation=self.dilation, - groups=self.groups) * sign_output + groups=self.groups) + perturbed_outputs = perturbed_outputs_tmp * sign_output + out = outputs + 
perturbed_outputs + + if self.quant_prepare: + # quint8 quantstub + x = self.quint_quant[0](x) # input + outputs = self.quint_quant[1](outputs) # output + sign_input = self.quint_quant[2](sign_input) + sign_output = self.quint_quant[3](sign_output) + x_tmp = self.quint_quant[4](x_tmp) + perturbed_outputs_tmp = self.quint_quant[5](perturbed_outputs_tmp) # output + perturbed_outputs = self.quint_quant[6](perturbed_outputs) # output + out = self.quint_quant[7](out) # output + + # qint8 quantstub + sigma_weight = self.qint_quant[0](sigma_weight) # weight + mu_kernel = self.qint_quant[1](self.mu_kernel) # weight + eps_kernel = self.qint_quant[2](eps_kernel) # random variable + delta_kernel =self.qint_quant[3](delta_kernel) # multiply activation # returning outputs + perturbations if return_kl: - return outputs + perturbed_outputs, kl - return outputs + perturbed_outputs + return out, kl + return out + class Conv3dFlipout(BaseVariationalLayer_): diff --git a/bayesian_torch/layers/flipout_layers/linear_flipout.py b/bayesian_torch/layers/flipout_layers/linear_flipout.py index af34d5d..aa6f702 100644 --- a/bayesian_torch/layers/flipout_layers/linear_flipout.py +++ b/bayesian_torch/layers/flipout_layers/linear_flipout.py @@ -40,6 +40,8 @@ from torch.distributions.normal import Normal from torch.distributions.uniform import Uniform from ..base_variational_layer import BaseVariationalLayer_ +from torch.quantization.observer import HistogramObserver, PerChannelMinMaxObserver, MinMaxObserver +from torch.quantization.qconfig import QConfig __all__ = ["LinearFlipout"] @@ -107,6 +109,15 @@ def __init__(self, self.register_buffer('eps_bias', None, persistent=False) self.init_parameters() + self.quant_prepare=False + + def prepare(self): + self.qint_quant = nn.ModuleList([torch.quantization.QuantStub( + QConfig(weight=MinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric), activation=MinMaxObserver.with_args(dtype=torch.qint8,qscheme=torch.per_tensor_symmetric))) for _ in range(4)]) + self.quint_quant = nn.ModuleList([torch.quantization.QuantStub( + QConfig(weight=MinMaxObserver.with_args(dtype=torch.quint8), activation=MinMaxObserver.with_args(dtype=torch.quint8))) for _ in range(8)]) + self.dequant = torch.quantization.DeQuantStub() + self.quant_prepare=True def init_parameters(self): # init prior mu @@ -136,7 +147,9 @@ def forward(self, x, return_kl=True): return_kl = False # sampling delta_W sigma_weight = torch.log1p(torch.exp(self.rho_weight)) - delta_weight = (sigma_weight * self.eps_weight.data.normal_()) + eps_weight = self.eps_weight.data.normal_() + delta_weight = sigma_weight * eps_weight + # delta_weight = (sigma_weight * self.eps_weight.data.normal_()) # get kl divergence if return_kl: @@ -153,14 +166,33 @@ def forward(self, x, return_kl=True): # linear outputs outputs = F.linear(x, self.mu_weight, self.mu_bias) - sign_input = x.clone().uniform_(-1, 1).sign() sign_output = outputs.clone().uniform_(-1, 1).sign() - - perturbed_outputs = F.linear(x * sign_input, delta_weight, - bias) * sign_output + x_tmp = x * sign_input + perturbed_outputs_tmp = F.linear(x_tmp, delta_weight, bias) + perturbed_outputs = perturbed_outputs_tmp * sign_output + out = outputs + perturbed_outputs + + if self.quant_prepare: + # quint8 quantstub + x = self.quint_quant[0](x) # input + outputs = self.quint_quant[1](outputs) # output + sign_input = self.quint_quant[2](sign_input) + sign_output = self.quint_quant[3](sign_output) + x_tmp = self.quint_quant[4](x_tmp) + perturbed_outputs_tmp = 
self.quint_quant[5](perturbed_outputs_tmp) # output + perturbed_outputs = self.quint_quant[6](perturbed_outputs) # output + out = self.quint_quant[7](out) # output + + # qint8 quantstub + sigma_weight = self.qint_quant[0](sigma_weight) # weight + mu_weight = self.qint_quant[1](self.mu_weight) # weight + eps_weight = self.qint_quant[2](eps_weight) # random variable + delta_weight =self.qint_quant[3](delta_weight) # multiply activation + # returning outputs + perturbations if return_kl: - return outputs + perturbed_outputs, kl - return outputs + perturbed_outputs + return out, kl + return out + diff --git a/bayesian_torch/layers/flipout_layers/quantized_conv_flipout.py b/bayesian_torch/layers/flipout_layers/quantized_conv_flipout.py new file mode 100644 index 0000000..4be011a --- /dev/null +++ b/bayesian_torch/layers/flipout_layers/quantized_conv_flipout.py @@ -0,0 +1,1351 @@ +# Copyright (C) 2021 Intel Labs +# +# BSD-3-Clause License +# +# Redistribution and use in source and binary forms, with or without modification, +# are permitted provided that the following conditions are met: +# 1. Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# 3. Neither the name of the copyright holder nor the names of its contributors +# may be used to endorse or promote products derived from this software +# without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, +# THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS +# BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, +# OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT +# OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; +# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE +# OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, +# EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +# +# Convolutional layers with flipout Monte Carlo weight estimator to perform +# variational inference in Bayesian neural networks. 
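The flipout forward passes above share one sampled weight perturbation across the batch but decorrelate it per example with random ±1 sign vectors on the input and output. A minimal sketch of the estimator for a plain linear layer (illustrative only, not the library API; `mu_weight`/`sigma_weight` are assumed fp32 tensors):

```python
import torch
import torch.nn.functional as F

def flipout_linear(x, mu_weight, sigma_weight, bias=None):
    """Flipout estimator: out = x @ mu^T + ((x * s_in) @ (sigma * eps)^T) * s_out."""
    eps = torch.randn_like(sigma_weight)                 # shared Gaussian perturbation
    delta = sigma_weight * eps                           # sampled weight offset
    sign_in = torch.empty_like(x).uniform_(-1, 1).sign() # per-example input signs
    outputs = F.linear(x, mu_weight, bias)               # mean path
    sign_out = torch.empty_like(outputs).uniform_(-1, 1).sign()
    perturbed = F.linear(x * sign_in, delta) * sign_out  # decorrelated perturbation path
    return outputs + perturbed
```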
Variational layers +# enables Monte Carlo approximation of the distribution over the kernel +# +# +# ====================================================================================== +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import Parameter +from ..base_variational_layer import BaseVariationalLayer_ +from .conv_flipout import * +import random + +from torch.distributions.normal import Normal +from torch.distributions.uniform import Uniform + +__all__ = [ + 'QuantizedConv1dFlipout', + 'QuantizedConv2dFlipout', + 'QuantizedConv3dFlipout', + 'QuantizedConvTranspose1dFlipout', + 'QuantizedConvTranspose2dFlipout', + 'QuantizedConvTranspose3dFlipout', +] + + +class QuantizedConv1dFlipout(Conv1dFlipout): + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=False): + """ + """ + super(QuantizedConv1dFlipout).__init__( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + groups, + bias=bias) + + # for conv bn fusion + self.bn_weight = None + self.bn_bias = None + self.bn_running_mean = None + self.bn_running_var = None + self.bn_eps = None + + self.is_dequant = False + + def get_scale_and_zero_point(self, x, upper_bound=100, target_range=255): + """ An implementation for symmetric quantization + + Parameters + ---------- + x: tensor + Input tensor. + upper_bound: int, optional + Restrict the maximum value of the original tensor (select 100 empirically). + target_range: int, optional + The range of target data type (255 for int8) + + Returns + ---------- + scale: float + + zero_point: int + + """ + scale = torch.zeros(1).to(x.device) # initialize + zero_point = torch.zeros(1).to(x.device) # zero point is zero since we only consider symmetric quantization + xmax = torch.clamp(x.abs().max(), 0, upper_bound) # determine and restrict the maximum value (minimum value should be 0 since the absolute value is always non-negative) + scale = xmax*2/target_range # original range divided by target range + return scale, zero_point + + def get_quantized_tensor(self, x, default_scale=0.1): + """ Quantize tensors + + Parameters + ---------- + x: tensors + Input tensor. + + default_scale: float, optional + Default scale for the case that the computed scale is zero. 
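The symmetric scheme implemented by `get_scale_and_zero_point` maps the clipped range `[-xmax, xmax]` onto the 255 representable int8 steps with the zero point pinned at 0. A small numeric check of that mapping (values are illustrative):

```python
import torch

x = torch.randn(16, 3, 8) * 5.0
xmax = torch.clamp(x.abs().max(), 0, 100)   # clip to the empirical upper bound
scale = float(xmax * 2 / 255)               # symmetric: range 2*xmax over 255 steps
zero_point = 0                              # symmetric quantization keeps zero at 0
xq = torch.quantize_per_tensor(x, scale, zero_point, torch.qint8)
print(xq.q_scale(), xq.int_repr().min().item(), xq.int_repr().max().item())
print((xq.dequantize() - x).abs().max())    # round-off error bounded by ~scale/2
```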
+ + + Returns + ---------- + quantized_x: tensors + + + """ + scale, zero_point = self.get_scale_and_zero_point(x) + if scale == 0: + scale = torch.tensor([default_scale]) # avoid zero scale + quantized_x = torch.quantize_per_tensor(x, scale, zero_point, torch.qint8) + + return quantized_x + + def get_dequantized_tensor(self, x): + dequantized_x = x.dequantize() + + return dequantized_x + + + def quantize(self): + if self.bn_weight is None: # has batchnorm layer, no bn fusion + self.quantized_mu_weight = Parameter(self.get_quantized_tensor(self.mu_kernel), requires_grad=False).cpu() + self.quantized_sigma_weight = Parameter(self.get_quantized_tensor(torch.log1p(torch.exp(self.rho_kernel))), requires_grad=False).cpu() + else: # fuse conv and bn + bn_coef = self.bn_weight/torch.sqrt(self.bn_running_var+self.bn_eps) + self.quantized_mu_weight = Parameter(self.get_quantized_tensor(self.mu_kernel*(bn_coef.view(-1,1,1,1).expand(self.mu_kernel.shape))), requires_grad=False).cpu() + self.quantized_sigma_weight = Parameter(self.get_quantized_tensor(torch.log1p(torch.exp(self.rho_kernel))*(bn_coef.view(-1,1,1,1).expand(self.rho_kernel.shape))), requires_grad=False).cpu() + delattr(self, "mu_kernel") + delattr(self, "rho_kernel") + + + ## DO NOT QUANTIZE BIAS!!!! Bias should be in fp32 format. + ## Variable names may be confusing. We don't quantize them. + ## TODO: rename variables + if self.bias: # if has bias + if self.bn_weight is None: # if no bn fusion + self.quantized_mu_bias = Parameter(self.mu_bias, requires_grad=False).cpu() + self.quantized_sigma_bias = Parameter(torch.log1p(torch.exp(self.rho_bias)), requires_grad=False).cpu() + else: # if apply bn fusion + bn_coef = self.bn_weight/torch.sqrt(self.bn_running_var+self.bn_eps) + self.quantized_mu_bias = Parameter((self.mu_bias-self.bn_running_mean)*bn_coef+self.bn_bias, requires_grad=False).cpu() + self.quantized_sigma_bias = Parameter(torch.log1p(torch.exp(self.rho_bias))*bn_coef, requires_grad=False).cpu() + delattr(self, "mu_bias") + delattr(self, "rho_bias") + else: + if self.bn_weight is not None: # if no bias but apply bn fusion + self.bias = True + bn_coef = self.bn_weight/torch.sqrt(self.bn_running_var+self.bn_eps)*(-self.bn_running_mean)+self.bn_bias + self.quantized_mu_bias = Parameter(bn_coef, requires_grad=False).cpu() + self.quantized_sigma_bias = None + + delattr(self, "bn_weight") + delattr(self, "bn_bias") + delattr(self, "bn_running_mean") + delattr(self, "bn_running_var") + delattr(self, "bn_eps") + + def forward(self, x, normal_scale=6/255, default_scale=0.1, default_zero_point=128, return_kl=True): + """ Forward pass + + Parameters + ---------- + x: tensors + Input tensor. + + normal_scale: float, optional + Scale for quantized tensor sampled from normal distribution. + since 99.7% values will lie within 3 standard deviations, the original range is set as 6. + + default_scale: float, optional + Default scale for quantized input tensor and quantized output tensor. + Set to 0.1 by grid search. + + default_zero_point: int, optional + Default zero point for quantized input tensor and quantized output tensor. + Set to 128 for quint8 tensor. + + + + Returns + ---------- + out: tensors + Output tensor. + KL: float + set to 0 since we diable KL divergence computation in quantized layers. 
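`quantize()` folds a trailing BatchNorm into the convolution before quantizing, scaling both the weight mean and its sigma by `gamma / sqrt(running_var + eps)` and absorbing the BN shift into the bias. A sketch of the same fold for deterministic weights (the function name is illustrative):

```python
import torch

def fuse_conv_bn_weights(w, b, bn_w, bn_b, bn_mean, bn_var, bn_eps=1e-5):
    """Fold BatchNorm stats into conv weights: w' = w*coef, b' = (b - mean)*coef + beta."""
    if b is None:
        b = torch.zeros(w.shape[0])
    coef = bn_w / torch.sqrt(bn_var + bn_eps)              # per-output-channel scale
    w_fused = w * coef.view(-1, *([1] * (w.dim() - 1)))    # broadcast over kernel dims
    b_fused = (b - bn_mean) * coef + bn_b
    return w_fused, b_fused
```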
+ + + """ + + if self.dnn_to_bnn_flag: + return_kl = False + + if x.dtype!=torch.quint8: + x = torch.quantize_per_tensor(x, default_scale, default_zero_point, torch.quint8) + + bias = None + if self.bias: + bias = self.quantized_mu_bias + + outputs = torch.nn.quantized.functional.conv1d(x, self.quantized_mu_weight, bias, self.stride, self.padding, + self.dilation, self.groups, scale=default_scale, zero_point=default_zero_point) # input: quint8, weight: qint8, bias: fp32 + + # sampling perturbation signs + sign_input = torch.zeros(x.shape).uniform_(-1, 1).sign() + sign_output = torch.zeros(outputs.shape).uniform_(-1, 1).sign() + sign_input = torch.quantize_per_tensor(sign_input, default_scale, default_zero_point, torch.quint8) + sign_output = torch.quantize_per_tensor(sign_output, default_scale, default_zero_point, torch.quint8) + + # getting perturbation weights + eps_kernel = torch.quantize_per_tensor(self.eps_kernel.data.normal_(), normal_scale, 0, torch.qint8) + new_scale = (self.quantized_sigma_weight.q_scale())*(eps_kernel.q_scale()) + delta_kernel = torch.ops.quantized.mul(self.quantized_sigma_weight, eps_kernel, new_scale, 0) + + bias = None + if self.bias: + eps_bias = self.eps_bias.data.normal_() + bias = (self.quantized_sigma_bias * eps_bias) + + # perturbed feedforward + x = torch.ops.quantized.mul(x, sign_input, default_scale, default_zero_point) + + perturbed_outputs = torch.nn.quantized.functional.conv1d(x, + weight=delta_kernel, bias=bias, stride=self.stride, padding=self.padding, + dilation=self.dilation, groups=self.groups, scale=default_scale, zero_point=default_zero_point) + perturbed_outputs = torch.ops.quantized.mul(perturbed_outputs, sign_output, default_scale, default_zero_point) + out = torch.ops.quantized.add(outputs, perturbed_outputs, default_scale, default_zero_point) + + if return_kl: + return out, 0 + + return out + + +class QuantizedConv2dFlipout(Conv2dFlipout): + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=False): # be aware of bias + """ + + """ + super(QuantizedConv2dFlipout, self).__init__( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + groups, + bias=bias) + + # for conv bn fusion + self.bn_weight = None + self.bn_bias = None + self.bn_running_mean = None + self.bn_running_var = None + self.bn_eps = None + + self.is_dequant = False + self.quant_dict = None + self.presampled_input_perturb = None + self.presampled_output_perturb = None + + + def get_scale_and_zero_point(self, x, upper_bound=100, target_range=255): + """ An implementation for symmetric quantization + + Parameters + ---------- + x: tensor + Input tensor. + upper_bound: int, optional + Restrict the maximum value of the original tensor (select 100 empirically). + target_range: int, optional + The range of target data type (255 for int8) + + Returns + ---------- + scale: float + + zero_point: int + + """ + # + scale = torch.zeros(1).to(x.device) # initialize + zero_point = torch.zeros(1).to(x.device) # zero point is zero since we only consider symmetric quantization + xmax = torch.clamp(x.abs().max(), 0, upper_bound) # determine and restrict the maximum value (minimum value should be 0 since the absolute value is always non-negative) + scale = xmax*2/target_range # original range divided by target range + return scale, zero_point + + def get_quantized_tensor(self, x, default_scale=0.1): + """ Quantize tensors + + Parameters + ---------- + x: tensors + Input tensor. 
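In the quantized forward pass above everything stays in int8: the mean path goes through a quantized convolution and the perturbation path reuses `torch.ops.quantized.mul`/`add`, each of which takes the scale and zero point of its output tensor. A minimal sketch of chaining such ops (illustrative scale/zero-point values):

```python
import torch

scale, zp = 0.1, 128
a = torch.quantize_per_tensor(torch.randn(4, 8), scale, zp, torch.quint8)
b = torch.quantize_per_tensor(torch.randn(4, 8), scale, zp, torch.quint8)

# every quantized op is told the scale/zero_point of the tensor it produces
prod = torch.ops.quantized.mul(a, b, scale, zp)
summ = torch.ops.quantized.add(a, prod, scale, zp)
print(summ.dtype, summ.q_scale(), summ.q_zero_point())
```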
+ + default_scale: float, optional + Default scale for the case that the computed scale is zero. + + + Returns + ---------- + quantized_x: tensors + + + """ + scale, zero_point = self.get_scale_and_zero_point(x) + if scale == 0: + scale = torch.tensor([default_scale]) # avoid zero scale + quantized_x = torch.quantize_per_tensor(x, scale, zero_point, torch.qint8) + + return quantized_x + + def get_dequantized_tensor(self, x): + dequantized_x = x.dequantize() + + return dequantized_x + + + def quantize(self): + if self.bn_weight is None: # has batchnorm layer, no bn fusion + self.quantized_mu_weight = Parameter(self.get_quantized_tensor(self.mu_kernel), requires_grad=False).cpu() + self.quantized_sigma_weight = Parameter(self.get_quantized_tensor(torch.log1p(torch.exp(self.rho_kernel))), requires_grad=False).cpu() + else: # fuse conv and bn + bn_coef = self.bn_weight/torch.sqrt(self.bn_running_var+self.bn_eps) + self.quantized_mu_weight = Parameter(self.get_quantized_tensor(self.mu_kernel*(bn_coef.view(-1,1,1,1).expand(self.mu_kernel.shape))), requires_grad=False).cpu() + self.quantized_sigma_weight = Parameter(self.get_quantized_tensor(torch.log1p(torch.exp(self.rho_kernel))*(bn_coef.view(-1,1,1,1).expand(self.rho_kernel.shape))), requires_grad=False).cpu() + delattr(self, "mu_kernel") + delattr(self, "rho_kernel") + + + ## DO NOT QUANTIZE BIAS!!!! Bias should be in fp32 format. + ## Variable names may be confusing. We don't quantize them. + ## TODO: rename variables + if self.bias: # if has bias + if self.bn_weight is None: # if no bn fusion + self.quantized_mu_bias = Parameter(self.mu_bias, requires_grad=False).cpu() + self.quantized_sigma_bias = Parameter(torch.log1p(torch.exp(self.rho_bias)), requires_grad=False).cpu() + else: # if apply bn fusion + bn_coef = self.bn_weight/torch.sqrt(self.bn_running_var+self.bn_eps) + self.quantized_mu_bias = Parameter((self.mu_bias-self.bn_running_mean)*bn_coef+self.bn_bias, requires_grad=False).cpu() + self.quantized_sigma_bias = Parameter(torch.log1p(torch.exp(self.rho_bias))*bn_coef, requires_grad=False).cpu() + delattr(self, "mu_bias") + delattr(self, "rho_bias") + else: + if self.bn_weight is not None: # if no bias but apply bn fusion + self.bias = True + bn_coef = self.bn_weight/torch.sqrt(self.bn_running_var+self.bn_eps)*(-self.bn_running_mean)+self.bn_bias + self.quantized_mu_bias = Parameter(bn_coef, requires_grad=False).cpu() + self.quantized_sigma_bias = None + + delattr(self, "bn_weight") + delattr(self, "bn_bias") + delattr(self, "bn_running_mean") + delattr(self, "bn_running_var") + delattr(self, "bn_eps") + + def dequantize(self): + self.mu_kernel = self.get_dequantized_tensor(self.quantized_mu_weight) + self.sigma_weight = self.get_dequantized_tensor(self.quantized_sigma_weight) + + if self.bias: + self.mu_bias = self.get_dequantized_tensor(self.quantized_mu_bias) + self.sigma_bias = self.get_dequantized_tensor(self.quantized_sigma_bias) + + return + + def forward(self, x, normal_scale=6/255, default_scale=0.1, default_zero_point=128, return_kl=True): + """ Forward pass + + Parameters + ---------- + x: tensors + Input tensor. + + normal_scale: float, optional + Scale for quantized tensor sampled from normal distribution. + since 99.7% values will lie within 3 standard deviations, the original range is set as 6. + + default_scale: float, optional + Default scale for quantized input tensor and quantized output tensor. + Set to 0.1 by grid search. 
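`normal_scale=6/255` follows from the 3-sigma argument in the docstring: about 99.7% of standard-normal samples fall in [-3, 3], so a range of 6 divided over 255 int8 steps keeps almost every sampled `eps` representable. A quick check (illustrative):

```python
import torch

eps = torch.randn(1_000_000)                     # standard-normal samples
normal_scale = 6 / 255                           # range ~[-3, 3] over 255 int8 steps
q_eps = torch.quantize_per_tensor(eps, normal_scale, 0, torch.qint8)
print((eps.abs() <= 3).float().mean())           # ~0.997 of samples lie inside the range
print((q_eps.dequantize() - eps).abs().mean())   # small average quantization error
```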
+ + default_zero_point: int, optional + Default zero point for quantized input tensor and quantized output tensor. + Set to 128 for quint8 tensor. + + + + Returns + ---------- + out: tensors + Output tensor. + KL: float + set to 0 since we diable KL divergence computation in quantized layers. + + + """ + + if self.dnn_to_bnn_flag: + return_kl = False + + bias = None + if self.bias: + bias = self.quantized_mu_bias # TODO: check correctness + + if self.quant_dict is not None: + # getting perturbation weights + eps_kernel = torch.quantize_per_tensor(self.eps_kernel.data.normal_(), self.quant_dict[0]['scale'], self.quant_dict[0]['zero_point'], torch.qint8) + delta_kernel = torch.ops.quantized.mul(self.quantized_sigma_weight, eps_kernel, self.quant_dict[1]['scale'], self.quant_dict[1]['zero_point']) + + if x.dtype!=torch.quint8: # check if input has been quantized + x = torch.quantize_per_tensor(x, self.quant_dict[2]['scale'], self.quant_dict[2]['zero_point'], torch.quint8) # scale=0.1 by grid search; zero_point=128 for uint8 format + + outputs = torch.nn.quantized.functional.conv2d(x, self.quantized_mu_weight, bias, self.stride, self.padding, + self.dilation, self.groups, scale=self.quant_dict[3]['scale'], zero_point=self.quant_dict[3]['zero_point']) # input: quint8, weight: qint8, bias: fp32 + + # sampling perturbation signs + input_tsize = torch.prod(torch.tensor(x.shape))*1 + output_tsize = torch.prod(torch.tensor(outputs.shape))*1 + + if self.presampled_input_perturb is None: + self.presampled_input_perturb = torch.randint(0, 1, (input_tsize + torch.prod(torch.tensor(x.shape)),)).float() + self.presampled_input_perturb[self.presampled_input_perturb==0] = -1 + + if self.presampled_output_perturb is None: + self.presampled_output_perturb = torch.randint(0, 1, (output_tsize + torch.prod(torch.tensor(outputs.shape)),)).float() + self.presampled_output_perturb[self.presampled_output_perturb==0] = -1 + + st = random.randint(0, input_tsize) + sign_input = self.presampled_input_perturb[st:st+torch.prod(torch.tensor(x.shape))].reshape(x.shape) + + st = random.randint(0, output_tsize) + sign_output = self.presampled_output_perturb[st:st+torch.prod(torch.tensor(outputs.shape))].reshape(outputs.shape) + # sign_input = torch.zeros(x.shape).uniform_(-1, 1).sign() + # sign_output = torch.zeros(outputs.shape).uniform_(-1, 1).sign() + sign_input = torch.quantize_per_tensor(sign_input, self.quant_dict[4]['scale'], self.quant_dict[4]['zero_point'], torch.quint8) + sign_output = torch.quantize_per_tensor(sign_output, self.quant_dict[5]['scale'], self.quant_dict[5]['zero_point'], torch.quint8) + + # perturbed feedforward + x = torch.ops.quantized.mul(x, sign_input, self.quant_dict[6]['scale'], self.quant_dict[6]['zero_point']) + perturbed_outputs = torch.nn.quantized.functional.conv2d(x, + weight=delta_kernel, bias=bias, stride=self.stride, padding=self.padding, + dilation=self.dilation, groups=self.groups, scale=self.quant_dict[7]['scale'], zero_point=self.quant_dict[7]['zero_point']) + perturbed_outputs = torch.ops.quantized.mul(perturbed_outputs, sign_output, self.quant_dict[8]['scale'], self.quant_dict[8]['zero_point']) + out = torch.ops.quantized.add(outputs, perturbed_outputs, self.quant_dict[9]['scale'], self.quant_dict[9]['zero_point']) + # out = out.dequantize() + + else: + if x.dtype!=torch.quint8: + x = torch.quantize_per_tensor(x, default_scale, default_zero_point, torch.quint8) + + outputs = torch.nn.quantized.functional.conv2d(x, self.quantized_mu_weight, bias, self.stride, self.padding, + 
self.dilation, self.groups, scale=default_scale, zero_point=default_zero_point) # input: quint8, weight: qint8, bias: fp32 + + # sampling perturbation signs + sign_input = torch.zeros(x.shape).uniform_(-1, 1).sign() + sign_output = torch.zeros(outputs.shape).uniform_(-1, 1).sign() + sign_input = torch.quantize_per_tensor(sign_input, default_scale, default_zero_point, torch.quint8) + sign_output = torch.quantize_per_tensor(sign_output, default_scale, default_zero_point, torch.quint8) + + # getting perturbation weights + eps_kernel = torch.quantize_per_tensor(self.eps_kernel.data.normal_(), normal_scale, 0, torch.qint8) + new_scale = (self.quantized_sigma_weight.q_scale())*(eps_kernel.q_scale()) + delta_kernel = torch.ops.quantized.mul(self.quantized_sigma_weight, eps_kernel, new_scale, 0) + + bias = None + if self.bias: + eps_bias = self.eps_bias.data.normal_() + bias = (self.quantized_sigma_bias * eps_bias) + + # perturbed feedforward + x = torch.ops.quantized.mul(x, sign_input, default_scale, default_zero_point) + + perturbed_outputs = torch.nn.quantized.functional.conv2d(x, + weight=delta_kernel, bias=bias, stride=self.stride, padding=self.padding, + dilation=self.dilation, groups=self.groups, scale=default_scale, zero_point=default_zero_point) + perturbed_outputs = torch.ops.quantized.mul(perturbed_outputs, sign_output, default_scale, default_zero_point) + out = torch.ops.quantized.add(outputs, perturbed_outputs, default_scale, default_zero_point) + + if return_kl: + return out, 0 + + return out + + +class QuantizedConv3dFlipout(Conv3dFlipout): + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=False): + """ + """ + super(QuantizedConv3dFlipout).__init__( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + groups, + bias=bias) + + + # for conv bn fusion + self.bn_weight = None + self.bn_bias = None + self.bn_running_mean = None + self.bn_running_var = None + self.bn_eps = None + + self.is_dequant = False + + def get_scale_and_zero_point(self, x, upper_bound=100, target_range=255): + """ An implementation for symmetric quantization + + Parameters + ---------- + x: tensor + Input tensor. + upper_bound: int, optional + Restrict the maximum value of the original tensor (select 100 empirically). + target_range: int, optional + The range of target data type (255 for int8) + + Returns + ---------- + scale: float + + zero_point: int + + """ + # + scale = torch.zeros(1).to(x.device) # initialize + zero_point = torch.zeros(1).to(x.device) # zero point is zero since we only consider symmetric quantization + xmax = torch.clamp(x.abs().max(), 0, upper_bound) # determine and restrict the maximum value (minimum value should be 0 since the absolute value is always non-negative) + scale = xmax*2/target_range # original range divided by target range + return scale, zero_point + + def get_quantized_tensor(self, x, default_scale=0.1): + """ Quantize tensors + + Parameters + ---------- + x: tensors + Input tensor. + + default_scale: float, optional + Default scale for the case that the computed scale is zero. 
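When calibrated parameters are available (`quant_dict`), the 2d layer also avoids per-forward RNG cost by pre-sampling one large pool of ±1 signs and slicing a random window from it on every call. A sketch of that trick with the sampling written out explicitly (names are illustrative):

```python
import random
import torch

def make_sign_pool(numel, pool_factor=2):
    """Pre-sample a pool of +/-1 values; forward passes slice random windows from it."""
    pool = torch.randint(0, 2, (numel * pool_factor,)).float()
    pool[pool == 0] = -1
    return pool

def draw_signs(pool, shape):
    numel = int(torch.prod(torch.tensor(shape)))
    start = random.randint(0, pool.numel() - numel)   # random window into the pool
    return pool[start:start + numel].reshape(shape)

pool = make_sign_pool(4 * 3 * 8 * 8)
signs = draw_signs(pool, (4, 3, 8, 8))
print(signs.unique())   # tensor([-1., 1.])
```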
+ + + Returns + ---------- + quantized_x: tensors + + + """ + scale, zero_point = self.get_scale_and_zero_point(x) + if scale == 0: + scale = torch.tensor([default_scale]) # avoid zero scale + quantized_x = torch.quantize_per_tensor(x, scale, zero_point, torch.qint8) + + return quantized_x + + def get_dequantized_tensor(self, x): + dequantized_x = x.dequantize() + + return dequantized_x + + + def quantize(self): + if self.bn_weight is None: # has batchnorm layer, no bn fusion + self.quantized_mu_weight = Parameter(self.get_quantized_tensor(self.mu_kernel), requires_grad=False).cpu() + self.quantized_sigma_weight = Parameter(self.get_quantized_tensor(torch.log1p(torch.exp(self.rho_kernel))), requires_grad=False).cpu() + else: # fuse conv and bn + bn_coef = self.bn_weight/torch.sqrt(self.bn_running_var+self.bn_eps) + self.quantized_mu_weight = Parameter(self.get_quantized_tensor(self.mu_kernel*(bn_coef.view(-1,1,1,1).expand(self.mu_kernel.shape))), requires_grad=False).cpu() + self.quantized_sigma_weight = Parameter(self.get_quantized_tensor(torch.log1p(torch.exp(self.rho_kernel))*(bn_coef.view(-1,1,1,1).expand(self.rho_kernel.shape))), requires_grad=False).cpu() + delattr(self, "mu_kernel") + delattr(self, "rho_kernel") + + + ## DO NOT QUANTIZE BIAS!!!! Bias should be in fp32 format. + ## Variable names may be confusing. We don't quantize them. + ## TODO: rename variables + if self.bias: # if has bias + if self.bn_weight is None: # if no bn fusion + self.quantized_mu_bias = Parameter(self.mu_bias, requires_grad=False).cpu() + self.quantized_sigma_bias = Parameter(torch.log1p(torch.exp(self.rho_bias)), requires_grad=False).cpu() + else: # if apply bn fusion + bn_coef = self.bn_weight/torch.sqrt(self.bn_running_var+self.bn_eps) + self.quantized_mu_bias = Parameter((self.mu_bias-self.bn_running_mean)*bn_coef+self.bn_bias, requires_grad=False).cpu() + self.quantized_sigma_bias = Parameter(torch.log1p(torch.exp(self.rho_bias))*bn_coef, requires_grad=False).cpu() + delattr(self, "mu_bias") + delattr(self, "rho_bias") + else: + if self.bn_weight is not None: # if no bias but apply bn fusion + self.bias = True + bn_coef = self.bn_weight/torch.sqrt(self.bn_running_var+self.bn_eps)*(-self.bn_running_mean)+self.bn_bias + self.quantized_mu_bias = Parameter(bn_coef, requires_grad=False).cpu() + self.quantized_sigma_bias = None + + delattr(self, "bn_weight") + delattr(self, "bn_bias") + delattr(self, "bn_running_mean") + delattr(self, "bn_running_var") + delattr(self, "bn_eps") + + def dequantize(self): + self.mu_kernel = self.get_dequantized_tensor(self.quantized_mu_weight) + self.sigma_weight = self.get_dequantized_tensor(self.quantized_sigma_weight) + + if self.bias: + self.mu_bias = self.get_dequantized_tensor(self.quantized_mu_bias) + self.sigma_bias = self.get_dequantized_tensor(self.quantized_sigma_bias) + + return + + def forward(self, x, normal_scale=6/255, default_scale=0.1, default_zero_point=128, return_kl=True): + """ Forward pass + + Parameters + ---------- + x: tensors + Input tensor. + + normal_scale: float, optional + Scale for quantized tensor sampled from normal distribution. + since 99.7% values will lie within 3 standard deviations, the original range is set as 6. + + default_scale: float, optional + Default scale for quantized input tensor and quantized output tensor. + Set to 0.1 by grid search. + + default_zero_point: int, optional + Default zero point for quantized input tensor and quantized output tensor. + Set to 128 for quint8 tensor. 
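The mean and perturbation convolutions above go through `torch.nn.quantized.functional.conv*`, which expects a quint8 activation, a qint8 weight, an fp32 bias, and the scale/zero point of the output. A standalone 2d example (illustrative values; assumes a CPU quantized engine such as fbgemm is available):

```python
import torch
import torch.nn.quantized.functional as qF

x = torch.quantize_per_tensor(torch.randn(1, 3, 8, 8), 0.1, 128, torch.quint8)  # activations: quint8
w = torch.quantize_per_tensor(torch.randn(6, 3, 3, 3), 0.05, 0, torch.qint8)    # weights: qint8, symmetric
bias = torch.randn(6)                                                           # bias stays fp32

y = qF.conv2d(x, w, bias, stride=1, padding=1, scale=0.1, zero_point=128)
print(y.dtype, y.shape)   # torch.quint8, torch.Size([1, 6, 8, 8])
```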
+ + + + Returns + ---------- + out: tensors + Output tensor. + KL: float + set to 0 since we diable KL divergence computation in quantized layers. + + + """ + + if self.dnn_to_bnn_flag: + return_kl = False + + if x.dtype!=torch.quint8: + x = torch.quantize_per_tensor(x, default_scale, default_zero_point, torch.quint8) + + bias = None + if self.bias: + bias = self.quantized_mu_bias + + outputs = torch.nn.quantized.functional.conv3d(x, self.quantized_mu_weight, bias, self.stride, self.padding, + self.dilation, self.groups, scale=default_scale, zero_point=default_zero_point) # input: quint8, weight: qint8, bias: fp32 + + # sampling perturbation signs + sign_input = torch.zeros(x.shape).uniform_(-1, 1).sign() + sign_output = torch.zeros(outputs.shape).uniform_(-1, 1).sign() + sign_input = torch.quantize_per_tensor(sign_input, default_scale, default_zero_point, torch.quint8) + sign_output = torch.quantize_per_tensor(sign_output, default_scale, default_zero_point, torch.quint8) + + # getting perturbation weights + eps_kernel = torch.quantize_per_tensor(self.eps_kernel.data.normal_(), normal_scale, 0, torch.qint8) + new_scale = (self.quantized_sigma_weight.q_scale())*(eps_kernel.q_scale()) + delta_kernel = torch.ops.quantized.mul(self.quantized_sigma_weight, eps_kernel, new_scale, 0) + + bias = None + if self.bias: + eps_bias = self.eps_bias.data.normal_() + bias = (self.quantized_sigma_bias * eps_bias) + + # perturbed feedforward + x = torch.ops.quantized.mul(x, sign_input, default_scale, default_zero_point) + + perturbed_outputs = torch.nn.quantized.functional.conv3d(x, + weight=delta_kernel, bias=bias, stride=self.stride, padding=self.padding, + dilation=self.dilation, groups=self.groups, scale=default_scale, zero_point=default_zero_point) + perturbed_outputs = torch.ops.quantized.mul(perturbed_outputs, sign_output, default_scale, default_zero_point) + out = torch.ops.quantized.add(outputs, perturbed_outputs, default_scale, default_zero_point) + + if return_kl: + return out, 0 + + return out + +class QuantizedConvTranspose1dFlipout(ConvTranspose1dFlipout): + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=False): + """ + """ + super(QuantizedConvTranspose1dFlipout).__init__( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + groups, + bias=bias) + + # for conv bn fusion + self.bn_weight = None + self.bn_bias = None + self.bn_running_mean = None + self.bn_running_var = None + self.bn_eps = None + + self.is_dequant = False + + if not hasattr(self, "output_padding"): + self.output_padding = 0 + + def get_scale_and_zero_point(self, x, upper_bound=100, target_range=255): + """ An implementation for symmetric quantization + + Parameters + ---------- + x: tensor + Input tensor. + upper_bound: int, optional + Restrict the maximum value of the original tensor (select 100 empirically). 
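+            Clamping keeps a few large outliers from inflating the scale and washing out precision for typical values.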
+ target_range: int, optional + The range of target data type (255 for int8) + + Returns + ---------- + scale: float + + zero_point: int + + """ + scale = torch.zeros(1).to(x.device) # initialize + zero_point = torch.zeros(1).to(x.device) # zero point is zero since we only consider symmetric quantization + xmax = torch.clamp(x.abs().max(), 0, upper_bound) # determine and restrict the maximum value (minimum value should be 0 since the absolute value is always non-negative) + scale = xmax*2/target_range # original range divided by target range + return scale, zero_point + + def get_quantized_tensor(self, x, default_scale=0.1): + """ Quantize tensors + + Parameters + ---------- + x: tensors + Input tensor. + + default_scale: float, optional + Default scale for the case that the computed scale is zero. + + + Returns + ---------- + quantized_x: tensors + + + """ + scale, zero_point = self.get_scale_and_zero_point(x) + if scale == 0: + scale = torch.tensor([default_scale]) # avoid zero scale + quantized_x = torch.quantize_per_tensor(x, scale, zero_point, torch.qint8) + + return quantized_x + + def get_dequantized_tensor(self, x): + dequantized_x = x.dequantize() + + return dequantized_x + + + def quantize(self): + if self.bn_weight is None: # has batchnorm layer, no bn fusion + self.quantized_mu_weight = Parameter(self.get_quantized_tensor(self.mu_kernel), requires_grad=False).cpu() + self.quantized_sigma_weight = Parameter(self.get_quantized_tensor(torch.log1p(torch.exp(self.rho_kernel))), requires_grad=False).cpu() + else: # fuse conv and bn + bn_coef = self.bn_weight/torch.sqrt(self.bn_running_var+self.bn_eps) + self.quantized_mu_weight = Parameter(self.get_quantized_tensor(self.mu_kernel*(bn_coef.view(-1,1,1,1).expand(self.mu_kernel.shape))), requires_grad=False).cpu() + self.quantized_sigma_weight = Parameter(self.get_quantized_tensor(torch.log1p(torch.exp(self.rho_kernel))*(bn_coef.view(-1,1,1,1).expand(self.rho_kernel.shape))), requires_grad=False).cpu() + delattr(self, "mu_kernel") + delattr(self, "rho_kernel") + + + ## DO NOT QUANTIZE BIAS!!!! Bias should be in fp32 format. + ## Variable names may be confusing. We don't quantize them. + ## TODO: rename variables + if self.bias: # if has bias + if self.bn_weight is None: # if no bn fusion + self.quantized_mu_bias = Parameter(self.mu_bias, requires_grad=False).cpu() + self.quantized_sigma_bias = Parameter(torch.log1p(torch.exp(self.rho_bias)), requires_grad=False).cpu() + else: # if apply bn fusion + bn_coef = self.bn_weight/torch.sqrt(self.bn_running_var+self.bn_eps) + self.quantized_mu_bias = Parameter((self.mu_bias-self.bn_running_mean)*bn_coef+self.bn_bias, requires_grad=False).cpu() + self.quantized_sigma_bias = Parameter(torch.log1p(torch.exp(self.rho_bias))*bn_coef, requires_grad=False).cpu() + delattr(self, "mu_bias") + delattr(self, "rho_bias") + else: + if self.bn_weight is not None: # if no bias but apply bn fusion + self.bias = True + bn_coef = self.bn_weight/torch.sqrt(self.bn_running_var+self.bn_eps)*(-self.bn_running_mean)+self.bn_bias + self.quantized_mu_bias = Parameter(bn_coef, requires_grad=False).cpu() + self.quantized_sigma_bias = None + + delattr(self, "bn_weight") + delattr(self, "bn_bias") + delattr(self, "bn_running_mean") + delattr(self, "bn_running_var") + delattr(self, "bn_eps") + + def forward(self, x, normal_scale=6/255, default_scale=0.1, default_zero_point=128, return_kl=True): + """ Forward pass + + Parameters + ---------- + x: tensors + Input tensor. 
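+            Quantized to quint8 on entry if it is not already a quantized tensor.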
+ + normal_scale: float, optional + Scale for quantized tensor sampled from normal distribution. + since 99.7% values will lie within 3 standard deviations, the original range is set as 6. + + default_scale: float, optional + Default scale for quantized input tensor and quantized output tensor. + Set to 0.1 by grid search. + + default_zero_point: int, optional + Default zero point for quantized input tensor and quantized output tensor. + Set to 128 for quint8 tensor. + + + + Returns + ---------- + out: tensors + Output tensor. + KL: float + set to 0 since we diable KL divergence computation in quantized layers. + + + """ + + if self.dnn_to_bnn_flag: + return_kl = False + + if x.dtype!=torch.quint8: + x = torch.quantize_per_tensor(x, default_scale, default_zero_point, torch.quint8) + + bias = None + if self.bias: + bias = self.quantized_mu_bias + + self._packed_params = torch.ops.quantized.conv_transpose1d_prepack(self.quantized_mu_weight, bias, self.stride, + self.padding, self.output_padding, + self.dilation, self.groups) + + outputs = torch.ops.quantized.conv_transpose1d(x, self._packed_params, scale=default_scale, zero_point=default_zero_point) + + # sampling perturbation signs + sign_input = torch.zeros(x.shape).uniform_(-1, 1).sign() + sign_output = torch.zeros(outputs.shape).uniform_(-1, 1).sign() + sign_input = torch.quantize_per_tensor(sign_input, default_scale, default_zero_point, torch.quint8) + sign_output = torch.quantize_per_tensor(sign_output, default_scale, default_zero_point, torch.quint8) + + # getting perturbation weights + eps_kernel = torch.quantize_per_tensor(self.eps_kernel.data.normal_(), normal_scale, 0, torch.qint8) + new_scale = (self.quantized_sigma_weight.q_scale())*(eps_kernel.q_scale()) + delta_kernel = torch.ops.quantized.mul(self.quantized_sigma_weight, eps_kernel, new_scale, 0) + + bias = None + if self.bias: + eps_bias = self.eps_bias.data.normal_() + bias = (self.quantized_sigma_bias * eps_bias) + + # perturbed feedforward + x = torch.ops.quantized.mul(x, sign_input, default_scale, default_zero_point) + + self._packed_params = torch.ops.quantized.conv_transpose1d_prepack(delta_kernel, bias, self.stride, + self.padding, self.output_padding, + self.dilation, self.groups) + perturbed_outputs = torch.ops.quantized.conv_transpose1d(x, self._packed_params, scale=default_scale, zero_point=default_zero_point) + + perturbed_outputs = torch.ops.quantized.mul(perturbed_outputs, sign_output, default_scale, default_zero_point) + out = torch.ops.quantized.add(outputs, perturbed_outputs, default_scale, default_zero_point) + + if return_kl: + return out, 0 + + return out + +class QuantizedConvTranspose2dFlipout(ConvTranspose2dFlipout): + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=False): + """ + """ + super(QuantizedConvTranspose2dFlipout).__init__( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + groups, + bias=bias) + + # for conv bn fusion + self.bn_weight = None + self.bn_bias = None + self.bn_running_mean = None + self.bn_running_var = None + self.bn_eps = None + + self.is_dequant = False + + if not hasattr(self, "output_padding"): + self.output_padding = 0 + + def get_scale_and_zero_point(self, x, upper_bound=100, target_range=255): + """ An implementation for symmetric quantization + + Parameters + ---------- + x: tensor + Input tensor. + upper_bound: int, optional + Restrict the maximum value of the original tensor (select 100 empirically). 
+ target_range: int, optional + The range of target data type (255 for int8) + + Returns + ---------- + scale: float + + zero_point: int + + """ + scale = torch.zeros(1).to(x.device) # initialize + zero_point = torch.zeros(1).to(x.device) # zero point is zero since we only consider symmetric quantization + xmax = torch.clamp(x.abs().max(), 0, upper_bound) # determine and restrict the maximum value (minimum value should be 0 since the absolute value is always non-negative) + scale = xmax*2/target_range # original range divided by target range + return scale, zero_point + + def get_quantized_tensor(self, x, default_scale=0.1): + """ Quantize tensors + + Parameters + ---------- + x: tensors + Input tensor. + + default_scale: float, optional + Default scale for the case that the computed scale is zero. + + + Returns + ---------- + quantized_x: tensors + + + """ + scale, zero_point = self.get_scale_and_zero_point(x) + if scale == 0: + scale = torch.tensor([default_scale]) # avoid zero scale + quantized_x = torch.quantize_per_tensor(x, scale, zero_point, torch.qint8) + + return quantized_x + + def get_dequantized_tensor(self, x): + dequantized_x = x.dequantize() + + return dequantized_x + + + def quantize(self): + if self.bn_weight is None: # has batchnorm layer, no bn fusion + self.quantized_mu_weight = Parameter(self.get_quantized_tensor(self.mu_kernel), requires_grad=False).cpu() + self.quantized_sigma_weight = Parameter(self.get_quantized_tensor(torch.log1p(torch.exp(self.rho_kernel))), requires_grad=False).cpu() + else: # fuse conv and bn + bn_coef = self.bn_weight/torch.sqrt(self.bn_running_var+self.bn_eps) + self.quantized_mu_weight = Parameter(self.get_quantized_tensor(self.mu_kernel*(bn_coef.view(-1,1,1,1).expand(self.mu_kernel.shape))), requires_grad=False).cpu() + self.quantized_sigma_weight = Parameter(self.get_quantized_tensor(torch.log1p(torch.exp(self.rho_kernel))*(bn_coef.view(-1,1,1,1).expand(self.rho_kernel.shape))), requires_grad=False).cpu() + delattr(self, "mu_kernel") + delattr(self, "rho_kernel") + + + ## DO NOT QUANTIZE BIAS!!!! Bias should be in fp32 format. + ## Variable names may be confusing. We don't quantize them. + ## TODO: rename variables + if self.bias: # if has bias + if self.bn_weight is None: # if no bn fusion + self.quantized_mu_bias = Parameter(self.mu_bias, requires_grad=False).cpu() + self.quantized_sigma_bias = Parameter(torch.log1p(torch.exp(self.rho_bias)), requires_grad=False).cpu() + else: # if apply bn fusion + bn_coef = self.bn_weight/torch.sqrt(self.bn_running_var+self.bn_eps) + self.quantized_mu_bias = Parameter((self.mu_bias-self.bn_running_mean)*bn_coef+self.bn_bias, requires_grad=False).cpu() + self.quantized_sigma_bias = Parameter(torch.log1p(torch.exp(self.rho_bias))*bn_coef, requires_grad=False).cpu() + delattr(self, "mu_bias") + delattr(self, "rho_bias") + else: + if self.bn_weight is not None: # if no bias but apply bn fusion + self.bias = True + bn_coef = self.bn_weight/torch.sqrt(self.bn_running_var+self.bn_eps)*(-self.bn_running_mean)+self.bn_bias + self.quantized_mu_bias = Parameter(bn_coef, requires_grad=False).cpu() + self.quantized_sigma_bias = None + + delattr(self, "bn_weight") + delattr(self, "bn_bias") + delattr(self, "bn_running_mean") + delattr(self, "bn_running_var") + delattr(self, "bn_eps") + + def forward(self, x, normal_scale=6/255, default_scale=0.1, default_zero_point=128, return_kl=True): + """ Forward pass + + Parameters + ---------- + x: tensors + Input tensor. 
+ + normal_scale: float, optional + Scale for quantized tensor sampled from normal distribution. + since 99.7% values will lie within 3 standard deviations, the original range is set as 6. + + default_scale: float, optional + Default scale for quantized input tensor and quantized output tensor. + Set to 0.1 by grid search. + + default_zero_point: int, optional + Default zero point for quantized input tensor and quantized output tensor. + Set to 128 for quint8 tensor. + + + + Returns + ---------- + out: tensors + Output tensor. + KL: float + set to 0 since we diable KL divergence computation in quantized layers. + + + """ + + if self.dnn_to_bnn_flag: + return_kl = False + + if x.dtype!=torch.quint8: + x = torch.quantize_per_tensor(x, default_scale, default_zero_point, torch.quint8) + + bias = None + if self.bias: + bias = self.quantized_mu_bias + + self._packed_params = torch.ops.quantized.conv_transpose2d_prepack(self.quantized_mu_weight, bias, self.stride, + self.padding, self.output_padding, + self.dilation, self.groups) + + outputs = torch.ops.quantized.conv_transpose2d(x, self._packed_params, scale=default_scale, zero_point=default_zero_point) + + # sampling perturbation signs + sign_input = torch.zeros(x.shape).uniform_(-1, 1).sign() + sign_output = torch.zeros(outputs.shape).uniform_(-1, 1).sign() + sign_input = torch.quantize_per_tensor(sign_input, default_scale, default_zero_point, torch.quint8) + sign_output = torch.quantize_per_tensor(sign_output, default_scale, default_zero_point, torch.quint8) + + # getting perturbation weights + eps_kernel = torch.quantize_per_tensor(self.eps_kernel.data.normal_(), normal_scale, 0, torch.qint8) + new_scale = (self.quantized_sigma_weight.q_scale())*(eps_kernel.q_scale()) + delta_kernel = torch.ops.quantized.mul(self.quantized_sigma_weight, eps_kernel, new_scale, 0) + + bias = None + if self.bias: + eps_bias = self.eps_bias.data.normal_() + bias = (self.quantized_sigma_bias * eps_bias) + + # perturbed feedforward + x = torch.ops.quantized.mul(x, sign_input, default_scale, default_zero_point) + + self._packed_params = torch.ops.quantized.conv_transpose2d_prepack(delta_kernel, bias, self.stride, + self.padding, self.output_padding, + self.dilation, self.groups) + perturbed_outputs = torch.ops.quantized.conv_transpose2d(x, self._packed_params, scale=default_scale, zero_point=default_zero_point) + + perturbed_outputs = torch.ops.quantized.mul(perturbed_outputs, sign_output, default_scale, default_zero_point) + out = torch.ops.quantized.add(outputs, perturbed_outputs, default_scale, default_zero_point) + + if return_kl: + return out, 0 + + return out + +class QuantizedConvTranspose3dFlipout(ConvTranspose3dFlipout): + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=False): + """ + """ + super(QuantizedConvTranspose3dFlipout).__init__( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + groups, + bias=bias) + + # for conv bn fusion + self.bn_weight = None + self.bn_bias = None + self.bn_running_mean = None + self.bn_running_var = None + self.bn_eps = None + + self.is_dequant = False + + if not hasattr(self, "output_padding"): + self.output_padding = 0 + + def get_scale_and_zero_point(self, x, upper_bound=100, target_range=255): + """ An implementation for symmetric quantization + + Parameters + ---------- + x: tensor + Input tensor. + upper_bound: int, optional + Restrict the maximum value of the original tensor (select 100 empirically). 
+ target_range: int, optional + The range of target data type (255 for int8) + + Returns + ---------- + scale: float + + zero_point: int + + """ + scale = torch.zeros(1).to(x.device) # initialize + zero_point = torch.zeros(1).to(x.device) # zero point is zero since we only consider symmetric quantization + xmax = torch.clamp(x.abs().max(), 0, upper_bound) # determine and restrict the maximum value (minimum value should be 0 since the absolute value is always non-negative) + scale = xmax*2/target_range # original range divided by target range + return scale, zero_point + + def get_quantized_tensor(self, x, default_scale=0.1): + """ Quantize tensors + + Parameters + ---------- + x: tensors + Input tensor. + + default_scale: float, optional + Default scale for the case that the computed scale is zero. + + + Returns + ---------- + quantized_x: tensors + + + """ + scale, zero_point = self.get_scale_and_zero_point(x) + if scale == 0: + scale = torch.tensor([default_scale]) # avoid zero scale + quantized_x = torch.quantize_per_tensor(x, scale, zero_point, torch.qint8) + + return quantized_x + + def get_dequantized_tensor(self, x): + dequantized_x = x.dequantize() + + return dequantized_x + + + def quantize(self): + if self.bn_weight is None: # has batchnorm layer, no bn fusion + self.quantized_mu_weight = Parameter(self.get_quantized_tensor(self.mu_kernel), requires_grad=False).cpu() + self.quantized_sigma_weight = Parameter(self.get_quantized_tensor(torch.log1p(torch.exp(self.rho_kernel))), requires_grad=False).cpu() + else: # fuse conv and bn + bn_coef = self.bn_weight/torch.sqrt(self.bn_running_var+self.bn_eps) + self.quantized_mu_weight = Parameter(self.get_quantized_tensor(self.mu_kernel*(bn_coef.view(-1,1,1,1).expand(self.mu_kernel.shape))), requires_grad=False).cpu() + self.quantized_sigma_weight = Parameter(self.get_quantized_tensor(torch.log1p(torch.exp(self.rho_kernel))*(bn_coef.view(-1,1,1,1).expand(self.rho_kernel.shape))), requires_grad=False).cpu() + delattr(self, "mu_kernel") + delattr(self, "rho_kernel") + + + ## DO NOT QUANTIZE BIAS!!!! Bias should be in fp32 format. + ## Variable names may be confusing. We don't quantize them. + ## TODO: rename variables + if self.bias: # if has bias + if self.bn_weight is None: # if no bn fusion + self.quantized_mu_bias = Parameter(self.mu_bias, requires_grad=False).cpu() + self.quantized_sigma_bias = Parameter(torch.log1p(torch.exp(self.rho_bias)), requires_grad=False).cpu() + else: # if apply bn fusion + bn_coef = self.bn_weight/torch.sqrt(self.bn_running_var+self.bn_eps) + self.quantized_mu_bias = Parameter((self.mu_bias-self.bn_running_mean)*bn_coef+self.bn_bias, requires_grad=False).cpu() + self.quantized_sigma_bias = Parameter(torch.log1p(torch.exp(self.rho_bias))*bn_coef, requires_grad=False).cpu() + delattr(self, "mu_bias") + delattr(self, "rho_bias") + else: + if self.bn_weight is not None: # if no bias but apply bn fusion + self.bias = True + bn_coef = self.bn_weight/torch.sqrt(self.bn_running_var+self.bn_eps)*(-self.bn_running_mean)+self.bn_bias + self.quantized_mu_bias = Parameter(bn_coef, requires_grad=False).cpu() + self.quantized_sigma_bias = None + + delattr(self, "bn_weight") + delattr(self, "bn_bias") + delattr(self, "bn_running_mean") + delattr(self, "bn_running_var") + delattr(self, "bn_eps") + + def forward(self, x, normal_scale=6/255, default_scale=0.1, default_zero_point=128, return_kl=True): + """ Forward pass + + Parameters + ---------- + x: tensors + Input tensor. 
+ + normal_scale: float, optional + Scale for quantized tensor sampled from normal distribution. + since 99.7% values will lie within 3 standard deviations, the original range is set as 6. + + default_scale: float, optional + Default scale for quantized input tensor and quantized output tensor. + Set to 0.1 by grid search. + + default_zero_point: int, optional + Default zero point for quantized input tensor and quantized output tensor. + Set to 128 for quint8 tensor. + + + + Returns + ---------- + out: tensors + Output tensor. + KL: float + set to 0 since we diable KL divergence computation in quantized layers. + + + """ + + if self.dnn_to_bnn_flag: + return_kl = False + + if x.dtype!=torch.quint8: + x = torch.quantize_per_tensor(x, default_scale, default_zero_point, torch.quint8) + + bias = None + if self.bias: + bias = self.quantized_mu_bias + + self._packed_params = torch.ops.quantized.conv_transpose3d_prepack(self.quantized_mu_weight, bias, self.stride, + self.padding, self.output_padding, + self.dilation, self.groups) + + outputs = torch.ops.quantized.conv_transpose3d(x, self._packed_params, scale=default_scale, zero_point=default_zero_point) + + # sampling perturbation signs + sign_input = torch.zeros(x.shape).uniform_(-1, 1).sign() + sign_output = torch.zeros(outputs.shape).uniform_(-1, 1).sign() + sign_input = torch.quantize_per_tensor(sign_input, default_scale, default_zero_point, torch.quint8) + sign_output = torch.quantize_per_tensor(sign_output, default_scale, default_zero_point, torch.quint8) + + # getting perturbation weights + eps_kernel = torch.quantize_per_tensor(self.eps_kernel.data.normal_(), normal_scale, 0, torch.qint8) + new_scale = (self.quantized_sigma_weight.q_scale())*(eps_kernel.q_scale()) + delta_kernel = torch.ops.quantized.mul(self.quantized_sigma_weight, eps_kernel, new_scale, 0) + + bias = None + if self.bias: + eps_bias = self.eps_bias.data.normal_() + bias = (self.quantized_sigma_bias * eps_bias) + + # perturbed feedforward + x = torch.ops.quantized.mul(x, sign_input, default_scale, default_zero_point) + + self._packed_params = torch.ops.quantized.conv_transpose3d_prepack(delta_kernel, bias, self.stride, + self.padding, self.output_padding, + self.dilation, self.groups) + perturbed_outputs = torch.ops.quantized.conv_transpose3d(x, self._packed_params, scale=default_scale, zero_point=default_zero_point) + + perturbed_outputs = torch.ops.quantized.mul(perturbed_outputs, sign_output, default_scale, default_zero_point) + out = torch.ops.quantized.add(outputs, perturbed_outputs, default_scale, default_zero_point) + + if return_kl: + return out, 0 + + return out \ No newline at end of file diff --git a/bayesian_torch/layers/flipout_layers/quantized_linear_flipout.py b/bayesian_torch/layers/flipout_layers/quantized_linear_flipout.py new file mode 100644 index 0000000..3cce873 --- /dev/null +++ b/bayesian_torch/layers/flipout_layers/quantized_linear_flipout.py @@ -0,0 +1,261 @@ +# Copyright (C) 2021 Intel Labs +# +# BSD-3-Clause License +# +# Redistribution and use in source and binary forms, with or without modification, +# are permitted provided that the following conditions are met: +# 1. Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# 3. 
Neither the name of the copyright holder nor the names of its contributors +# may be used to endorse or promote products derived from this software +# without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, +# THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS +# BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, +# OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT +# OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; +# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE +# OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, +# EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +# +# Linear Flipout Layers with flipout weight estimator to perform +# variational inference in Bayesian neural networks. Variational layers +# enables Monte Carlo approximation of the distribution over the weights +# +# @authors: Jun-Liang Lin +# +# ====================================================================================== +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import Module, Parameter +from torch.distributions.normal import Normal +from torch.distributions.uniform import Uniform +import random + +from .linear_flipout import LinearFlipout + +__all__ = ["QuantizedLinearFlipout"] + +class QuantizedLinearFlipout(LinearFlipout): + def __init__(self, + in_features, + out_features): + + super(QuantizedLinearFlipout, self).__init__( + in_features, + out_features) + + self.is_dequant = False + self.quant_dict = None + self.presampled_input_perturb = None + self.presampled_output_perturb = None + + def get_scale_and_zero_point(self, x, upper_bound=100, target_range=255): + """ An implementation for symmetric quantization + + Parameters + ---------- + x: tensor + Input tensor. + upper_bound: int, optional + Restrict the maximum value of the original tensor (select 100 empirically). + target_range: int, optional + The range of target data type (255 for int8) + + Returns + ---------- + scale: float + + zero_point: int + + """ + scale = torch.zeros(1).to(x.device) # initialize + zero_point = torch.zeros(1).to(x.device) # zero point is zero since we only consider symmetric quantization + xmax = torch.clamp(x.abs().max(), 0, upper_bound) # determine and restrict the maximum value (minimum value should be 0 since the absolute value is always non-negative) + scale = xmax*2/target_range # original range divided by target range + return scale, zero_point + + def get_quantized_tensor(self, x, default_scale=0.1): + """ Quantize tensors + + Parameters + ---------- + x: tensors + Input tensor. + + default_scale: float, optional + Default scale for the case that the computed scale is zero. 
+ + + Returns + ---------- + quantized_x: tensors + + + """ + scale, zero_point = self.get_scale_and_zero_point(x) + if scale == 0: + scale = torch.tensor([default_scale]) # avoid zero scale + quantized_x = torch.quantize_per_tensor(x, scale, zero_point, torch.qint8) + + return quantized_x + + def get_dequantized_tensor(self, x): + + dequantized_x = x.dequantize() + + return dequantized_x + + + def quantize(self): + self.quantized_mu_weight = Parameter(self.get_quantized_tensor(self.mu_weight), requires_grad=False) + self.quantized_sigma_weight = Parameter(self.get_quantized_tensor(torch.log1p(torch.exp(self.rho_weight))), requires_grad=False) + delattr(self, "mu_weight") + delattr(self, "rho_weight") + + self.quantized_mu_bias = self.mu_bias#Parameter(self.get_quantized_tensor(self.mu_bias), requires_grad=False) + self.quantized_sigma_bias = Parameter(torch.log1p(torch.exp(self.rho_bias)), requires_grad=False)#Parameter(self.get_quantized_tensor(torch.log1p(torch.exp(self.rho_bias))), requires_grad=False) + delattr(self, "mu_bias") + delattr(self, "rho_bias") + + def dequantize(self): + self.mu_weight = self.get_dequantized_tensor(self.quantized_mu_weight) + self.sigma_weight = self.get_dequantized_tensor(self.quantized_sigma_weight) + + self.mu_bias = self.get_dequantized_tensor(self.quantized_mu_bias) + self.sigma_bias = self.get_dequantized_tensor(self.quantized_sigma_bias) + return + + def forward(self, x, normal_scale=6/255, default_scale=0.1, default_zero_point=128, return_kl=True): + """ Forward pass + + Parameters + ---------- + x: tensors + Input tensor. + + normal_scale: float, optional + Scale for quantized tensor sampled from normal distribution. + since 99.7% values will lie within 3 standard deviations, the original range is set as 6. + + default_scale: float, optional + Default scale for quantized input tensor and quantized output tensor. + Set to 0.1 by grid search. + + default_zero_point: int, optional + Default zero point for quantized input tensor and quantized output tensor. + Set to 128 for quint8 tensor. + + + + Returns + ---------- + out: tensors + Output tensor. Already dequantized. + KL: float + set to 0 since we diable KL divergence computation in quantized layers. 
+ + + """ + + if self.dnn_to_bnn_flag: + return_kl = False + + bias = None + if self.quantized_mu_bias is not None: + if not self.is_dequant: + self.dequantize() + self.is_dequant = True + bias = self.mu_bias + + if self.quant_dict is not None: + + # getting perturbation weights + eps_weight = torch.quantize_per_tensor(self.eps_weight.data.normal_(), self.quant_dict[0]['scale'], self.quant_dict[0]['zero_point'], torch.qint8) + delta_weight = torch.ops.quantized.mul(self.quantized_sigma_weight, eps_weight, self.quant_dict[1]['scale'], self.quant_dict[1]['zero_point']) + + bias = None + if self.quantized_sigma_bias is not None: + eps_bias = self.eps_bias.data.normal_() + bias = (self.sigma_bias * eps_bias) + + if x.dtype!=torch.quint8: # check if input has been quantized + x = torch.quantize_per_tensor(x, self.quant_dict[2]['scale'], self.quant_dict[2]['zero_point'], torch.quint8) # scale=0.1 by grid search; zero_point=128 for uint8 format + + outputs = torch.nn.quantized.functional.linear(x, self.quantized_mu_weight, bias, scale=self.quant_dict[3]['scale'], zero_point=self.quant_dict[3]['zero_point']) # input: quint8, weight: qint8, bias: fp32 + + # sampling perturbation signs + # sampling perturbation signs + input_tsize = torch.prod(torch.tensor(x.shape))*1 + output_tsize = torch.prod(torch.tensor(outputs.shape))*1 + + if self.presampled_input_perturb is None: + self.presampled_input_perturb = torch.randint(0, 1, (input_tsize + torch.prod(torch.tensor(x.shape)),)).float() + self.presampled_input_perturb[self.presampled_input_perturb==0] = -1 + + if self.presampled_output_perturb is None: + self.presampled_output_perturb = torch.randint(0, 1, (output_tsize + torch.prod(torch.tensor(outputs.shape)),)).float() + self.presampled_output_perturb[self.presampled_output_perturb==0] = -1 + + st = random.randint(0, input_tsize) + sign_input = self.presampled_input_perturb[st:st+torch.prod(torch.tensor(x.shape))].reshape(x.shape) + + st = random.randint(0, output_tsize) + sign_output = self.presampled_output_perturb[st:st+torch.prod(torch.tensor(outputs.shape))].reshape(outputs.shape) + + + # sign_input = torch.zeros(x.shape).uniform_(-1, 1).sign() + # sign_output = torch.zeros(outputs.shape).uniform_(-1, 1).sign() + sign_input = torch.quantize_per_tensor(sign_input, self.quant_dict[4]['scale'], self.quant_dict[4]['zero_point'], torch.quint8) + sign_output = torch.quantize_per_tensor(sign_output, self.quant_dict[5]['scale'], self.quant_dict[5]['zero_point'], torch.quint8) + + # perturbed feedforward + x = torch.ops.quantized.mul(x, sign_input, self.quant_dict[6]['scale'], self.quant_dict[6]['zero_point']) + perturbed_outputs = torch.nn.quantized.functional.linear(x, + weight=delta_weight, bias=bias, scale=self.quant_dict[7]['scale'], zero_point=self.quant_dict[7]['zero_point']) + perturbed_outputs = torch.ops.quantized.mul(perturbed_outputs, sign_output, self.quant_dict[8]['scale'], self.quant_dict[8]['zero_point']) + out = torch.ops.quantized.add(outputs, perturbed_outputs, self.quant_dict[9]['scale'], self.quant_dict[9]['zero_point']) + out = out.dequantize() + + else: + + outputs = torch.nn.quantized.functional.linear(x, self.quantized_mu_weight, bias, scale=default_scale, zero_point=default_zero_point) # input: quint8, weight: qint8, bias: fp32 + + # sampling perturbation signs + sign_input = torch.zeros(x.shape).uniform_(-1, 1).sign() + sign_output = torch.zeros(outputs.shape).uniform_(-1, 1).sign() + sign_input = torch.quantize_per_tensor(sign_input, default_scale, default_zero_point, 
torch.quint8) + sign_output = torch.quantize_per_tensor(sign_output, default_scale, default_zero_point, torch.quint8) + + # getting perturbation weights + eps_weight = torch.quantize_per_tensor(self.eps_weight.data.normal_(), normal_scale, 0, torch.qint8) + new_scale = (self.quantized_sigma_weight.q_scale())*(eps_weight.q_scale()) + delta_weight = torch.ops.quantized.mul(self.quantized_sigma_weight, eps_weight, new_scale, 0) + + bias = None + if self.quantized_sigma_bias is not None: + eps_bias = self.eps_bias.data.normal_() + bias = (self.sigma_bias * eps_bias) + + # perturbed feedforward + x = torch.ops.quantized.mul(x, sign_input, default_scale, default_zero_point) + + perturbed_outputs = torch.nn.quantized.functional.linear(x, + weight=delta_weight, bias=bias, scale=default_scale, zero_point=default_zero_point) + perturbed_outputs = torch.ops.quantized.mul(perturbed_outputs, sign_output, default_scale, default_zero_point) + out = torch.ops.quantized.add(outputs, perturbed_outputs, default_scale, default_zero_point) + out = out.dequantize() + + if return_kl: + return out, 0 + + return out diff --git a/bayesian_torch/layers/variational_layers/__init__.py b/bayesian_torch/layers/variational_layers/__init__.py index 1c083e3..6fae454 100644 --- a/bayesian_torch/layers/variational_layers/__init__.py +++ b/bayesian_torch/layers/variational_layers/__init__.py @@ -1,3 +1,6 @@ from .linear_variational import * from .conv_variational import * from .rnn_variational import * +from .quantize_linear_variational import * +from .quantize_conv_variational import * +# from .quantize_rnn_variational import * \ No newline at end of file diff --git a/bayesian_torch/layers/variational_layers/conv_variational.py b/bayesian_torch/layers/variational_layers/conv_variational.py index 0d2ebfd..403651a 100644 --- a/bayesian_torch/layers/variational_layers/conv_variational.py +++ b/bayesian_torch/layers/variational_layers/conv_variational.py @@ -48,6 +48,8 @@ from torch.nn import Parameter from ..base_variational_layer import BaseVariationalLayer_, get_kernel_size import math +from torch.quantization.observer import HistogramObserver, PerChannelMinMaxObserver, MinMaxObserver +from torch.quantization.qconfig import QConfig __all__ = [ 'Conv1dReparameterization', @@ -295,6 +297,15 @@ def __init__(self, self.register_buffer('prior_bias_sigma', None, persistent=False) self.init_parameters() + self.quant_prepare=False + + def prepare(self): + self.qint_quant = nn.ModuleList([torch.quantization.QuantStub( + QConfig(weight=MinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric), activation=MinMaxObserver.with_args(dtype=torch.qint8,qscheme=torch.per_tensor_symmetric))) for _ in range(5)]) + self.quint_quant = nn.ModuleList([torch.quantization.QuantStub( + QConfig(weight=MinMaxObserver.with_args(dtype=torch.quint8), activation=MinMaxObserver.with_args(dtype=torch.quint8))) for _ in range(2)]) + self.dequant = torch.quantization.DeQuantStub() + self.quant_prepare=True def init_parameters(self): self.prior_weight_mu.fill_(self.prior_mean) @@ -325,7 +336,8 @@ def forward(self, input, return_kl=True): sigma_weight = torch.log1p(torch.exp(self.rho_kernel)) eps_kernel = self.eps_kernel.data.normal_() - weight = self.mu_kernel + (sigma_weight * eps_kernel) + tmp_result = sigma_weight * eps_kernel + weight = self.mu_kernel + tmp_result if return_kl: kl_weight = self.kl_div(self.mu_kernel, sigma_weight, @@ -342,13 +354,27 @@ def forward(self, input, return_kl=True): out = F.conv2d(input, weight, bias, 
self.stride, self.padding, self.dilation, self.groups) + + if self.quant_prepare: + # quint8 quantstub + input = self.quint_quant[0](input) # input + out = self.quint_quant[1](out) # output + + # qint8 quantstub + sigma_weight = self.qint_quant[0](sigma_weight) # weight + mu_kernel = self.qint_quant[1](self.mu_kernel) # weight + eps_kernel = self.qint_quant[2](eps_kernel) # random variable + tmp_result =self.qint_quant[3](tmp_result) # multiply activation + weight = self.qint_quant[4](weight) # add activatation + + if return_kl: if self.bias: kl = kl_weight + kl_bias else: kl = kl_weight return out, kl - + return out @@ -946,3 +972,4 @@ def forward(self, input, return_kl=True): return out, kl return out + diff --git a/bayesian_torch/layers/variational_layers/linear_variational.py b/bayesian_torch/layers/variational_layers/linear_variational.py index 7efb667..4cb1adb 100644 --- a/bayesian_torch/layers/variational_layers/linear_variational.py +++ b/bayesian_torch/layers/variational_layers/linear_variational.py @@ -47,6 +47,8 @@ from torch.nn import Module, Parameter from ..base_variational_layer import BaseVariationalLayer_ import math +from torch.quantization.observer import HistogramObserver, PerChannelMinMaxObserver, MinMaxObserver +from torch.quantization.qconfig import QConfig class LinearReparameterization(BaseVariationalLayer_): @@ -116,6 +118,15 @@ def __init__(self, self.register_buffer('eps_bias', None, persistent=False) self.init_parameters() + self.quant_prepare=False + + def prepare(self): + self.qint_quant = nn.ModuleList([torch.quantization.QuantStub( + QConfig(weight=MinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric), activation=MinMaxObserver.with_args(dtype=torch.qint8,qscheme=torch.per_tensor_symmetric))) for _ in range(5)]) + self.quint_quant = nn.ModuleList([torch.quantization.QuantStub( + QConfig(weight=MinMaxObserver.with_args(dtype=torch.quint8), activation=MinMaxObserver.with_args(dtype=torch.quint8))) for _ in range(2)]) + self.dequant = torch.quantization.DeQuantStub() + self.quant_prepare=True def init_parameters(self): self.prior_weight_mu.fill_(self.prior_mean) @@ -147,8 +158,11 @@ def forward(self, input, return_kl=True): if self.dnn_to_bnn_flag: return_kl = False sigma_weight = torch.log1p(torch.exp(self.rho_weight)) - weight = self.mu_weight + \ - (sigma_weight * self.eps_weight.data.normal_()) + eps_weight = self.eps_weight.data.normal_() + tmp_result = sigma_weight * eps_weight + weight = self.mu_weight + tmp_result + + if return_kl: kl_weight = self.kl_div(self.mu_weight, sigma_weight, self.prior_weight_mu, self.prior_weight_sigma) @@ -162,6 +176,20 @@ def forward(self, input, return_kl=True): self.prior_bias_sigma) out = F.linear(input, weight, bias) + + if self.quant_prepare: + # quint8 quantstub + input = self.quint_quant[0](input) # input + out = self.quint_quant[1](out) # output + + # qint8 quantstub + sigma_weight = self.qint_quant[0](sigma_weight) # weight + mu_weight = self.qint_quant[1](self.mu_weight) # weight + eps_weight = self.qint_quant[2](eps_weight) # random variable + tmp_result =self.qint_quant[3](tmp_result) # multiply activation + weight = self.qint_quant[4](weight) # add activatation + + if return_kl: if self.mu_bias is not None: kl = kl_weight + kl_bias diff --git a/bayesian_torch/layers/variational_layers/quantize_conv_variational.py b/bayesian_torch/layers/variational_layers/quantize_conv_variational.py new file mode 100644 index 0000000..31ed9e7 --- /dev/null +++ 
b/bayesian_torch/layers/variational_layers/quantize_conv_variational.py @@ -0,0 +1,1492 @@ +# Copyright (C) 2021 Intel Labs +# +# BSD-3-Clause License +# +# Redistribution and use in source and binary forms, with or without modification, +# are permitted provided that the following conditions are met: +# 1. Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# 3. Neither the name of the copyright holder nor the names of its contributors +# may be used to endorse or promote products derived from this software +# without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, +# THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS +# BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, +# OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT +# OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; +# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE +# OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, +# EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +# @authors: Jun-Liang Lin +# +# ====================================================================================== + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import Parameter +from ..base_variational_layer import BaseVariationalLayer_ +from .conv_variational import * +import math + +__all__ = [ + 'QuantizedConv1dReparameterization', + 'QuantizedConv2dReparameterization', + 'QuantizedConv3dReparameterization', + 'QuantizedConvTranspose1dReparameterization', + 'QuantizedConvTranspose2dReparameterization', + 'QuantizedConvTranspose3dReparameterization', +] + + +class QuantizedConv1dReparameterization(Conv1dReparameterization): + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=False): + """ + """ + + super(QuantizedConv1dReparameterization, self).__init__( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + groups, + bias=bias + ) + + ## redundant ## + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + self.dilation = dilation + self.groups = groups + ## redundant ## + + # for conv bn fusion + self.bn_weight = None + self.bn_bias = None + self.bn_running_mean = None + self.bn_running_var = None + self.bn_eps = None + + self.is_dequant = False + self.quant_dict = None + + def get_scale_and_zero_point(self, x, upper_bound=100, target_range=255): + """ An implementation for symmetric quantization + + Parameters + ---------- + x: tensor + Input tensor. + upper_bound: int, optional + Restrict the maximum value of the original tensor (select 100 empirically). 
+ target_range: int, optional + The range of target data type (255 for int8) + + Returns + ---------- + scale: float + + zero_point: int + + """ + # + scale = torch.zeros(1).to(x.device) # initialize + zero_point = torch.zeros(1).to(x.device) # zero point is zero since we only consider symmetric quantization + xmax = torch.clamp(x.abs().max(), 0, upper_bound) # determine and restrict the maximum value (minimum value should be 0 since the absolute value is always non-negative) + scale = xmax*2/target_range # original range divided by target range + return scale, zero_point + + def get_quantized_tensor(self, x, default_scale=0.1): + """ Quantize tensors + + Parameters + ---------- + x: tensors + Input tensor. + + default_scale: float, optional + Default scale for the case that the computed scale is zero. + + + Returns + ---------- + quantized_x: tensors + + + """ + scale, zero_point = self.get_scale_and_zero_point(x) + if scale == 0: + scale = torch.tensor([default_scale]) # avoid zero scale + quantized_x = torch.quantize_per_tensor(x, scale, zero_point, torch.qint8) + + return quantized_x + + def get_dequantized_tensor(self, x): + dequantized_x = x.dequantize() + + return dequantized_x + + + def quantize(self): + if self.bn_weight is None: # has batchnorm layer, no bn fusion + self.quantized_mu_weight = Parameter(self.get_quantized_tensor(self.mu_kernel), requires_grad=False).cpu() + self.quantized_sigma_weight = Parameter(self.get_quantized_tensor(torch.log1p(torch.exp(self.rho_kernel))), requires_grad=False).cpu() + else: # fuse conv and bn + bn_coef = self.bn_weight/torch.sqrt(self.bn_running_var+self.bn_eps) + self.quantized_mu_weight = Parameter(self.get_quantized_tensor(self.mu_kernel*(bn_coef.view(-1,1,1,1).expand(self.mu_kernel.shape))), requires_grad=False).cpu() + self.quantized_sigma_weight = Parameter(self.get_quantized_tensor(torch.log1p(torch.exp(self.rho_kernel))*(bn_coef.view(-1,1,1,1).expand(self.rho_kernel.shape))), requires_grad=False).cpu() + delattr(self, "mu_kernel") + delattr(self, "rho_kernel") + + + ## DO NOT QUANTIZE BIAS!!!! Bias should be in fp32 format + ## Variable names may be confusing. We don't quantize them. + ## TODO: rename variables + if self.bias: # if has bias + if self.bn_weight is None: # if no bn fusion + self.quantized_mu_bias = Parameter(self.mu_bias, requires_grad=False).cpu() + self.quantized_sigma_bias = Parameter(torch.log1p(torch.exp(self.rho_bias)), requires_grad=False).cpu() + else: # if apply bn fusion + bn_coef = self.bn_weight/torch.sqrt(self.bn_running_var+self.bn_eps) + self.quantized_mu_bias = Parameter((self.mu_bias-self.bn_running_mean)*bn_coef+self.bn_bias, requires_grad=False).cpu() + self.quantized_sigma_bias = Parameter(torch.log1p(torch.exp(self.rho_bias))*bn_coef, requires_grad=False).cpu() + delattr(self, "mu_bias") + delattr(self, "rho_bias") + else: + if self.bn_weight is not None: # if no bias but apply bn fusion + self.bias = True + bn_coef = self.bn_weight/torch.sqrt(self.bn_running_var+self.bn_eps)*(-self.bn_running_mean)+self.bn_bias + self.quantized_mu_bias = Parameter(bn_coef, requires_grad=False).cpu() + self.quantized_sigma_bias = None + + delattr(self, "bn_weight") + delattr(self, "bn_bias") + delattr(self, "bn_running_mean") + delattr(self, "bn_running_var") + delattr(self, "bn_eps") + + def dequantize(self): # Deprecated. Only for forward mode #1. 
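+        # Restores fp32 copies of the quantized parameters so the deprecated fp32 path in forward() can use them.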
+ self.mu_kernel = self.get_dequantized_tensor(self.quantized_mu_weight) + self.sigma_weight = self.get_dequantized_tensor(self.quantized_sigma_weight) + + if self.bias: + self.mu_bias = self.get_dequantized_tensor(self.quantized_mu_bias) + self.sigma_bias = self.get_dequantized_tensor(self.quantized_sigma_bias) + + return + + def forward(self, input, enable_int8_compute=True, normal_scale=6/255, default_scale=0.1, default_zero_point=128, return_kl=True): + """ Forward pass + + Parameters + ---------- + input: tensors + Input tensor. + + enable_int8_compute: bool, optional + Whether to enable int8 computation. + + normal_scale: float, optional + Scale for quantized tensor sampled from normal distribution. + since 99.7% values will lie within 3 standard deviations, the original range is set as 6. + + default_scale: float, optional + Default scale for quantized input tensor and quantized output tensor. + Set to 0.1 by grid search. + + default_zero_point: int, optional + Default zero point for quantized input tensor and quantized output tensor. + Set to 128 for quint8 tensor. + + + + Returns + ---------- + out: tensors + Output tensor. + KL: float + set to 0 since we diable KL divergence computation in quantized layers. + + + """ + + if self.dnn_to_bnn_flag: + return_kl = False + + if self.quant_dict is not None: + eps_kernel = torch.quantize_per_tensor(self.eps_kernel.data.normal_(), self.quant_dict[0]['scale'], self.quant_dict[0]['zero_point'], torch.qint8) # Quantize a tensor from normal distribution. 99.7% values will lie within 3 standard deviations, so the original range is set as 6. + weight = torch.ops.quantized.mul(self.quantized_sigma_weight, eps_kernel, self.quant_dict[1]['scale'], self.quant_dict[1]['zero_point']) + weight = torch.ops.quantized.add(weight, self.quantized_mu_weight, self.quant_dict[2]['scale'], self.quant_dict[2]['zero_point']) + bias = None + + ## DO NOT QUANTIZE BIAS!!! + if self.bias: + if self.quantized_sigma_bias is None: # the case that bias comes from bn fusion + bias = self.quantized_mu_bias + else: # original case + bias = self.quantized_mu_bias + (self.quantized_sigma_bias * self.eps_bias.data.normal_()) + + if input.dtype!=torch.quint8: # check if input has been quantized + input = torch.quantize_per_tensor(input, self.quant_dict[3]['scale'], self.quant_dict[3]['zero_point'], torch.quint8) # scale=0.1 by grid search; zero_point=128 for uint8 format + + out = torch.nn.quantized.functional.conv1d(input, weight, bias, self.stride, self.padding, + self.dilation, self.groups, scale=self.quant_dict[4]['scale'], zero_point=self.quant_dict[4]['zero_point']) # input: quint8, weight: qint8, bias: fp32 + + elif not enable_int8_compute: # Deprecated. Use this method for reducing model size only. + if not self.is_dequant: + self.dequantize() + self.is_dequant = True + + weight = self.mu_kernel + (self.sigma_weight * self.eps_kernel.data.normal_()) + bias = None + + if self.bias: + bias = self.mu_bias + (self.sigma_bias * self.eps_bias.data.normal_()) + + out = F.conv1d(input, weight, bias, self.stride, self.padding, + self.dilation, self.groups) + + else: + eps_kernel = torch.quantize_per_tensor(self.eps_kernel.data.normal_(), normal_scale, 0, torch.qint8) # Quantize a tensor from normal distribution. + new_scale = (self.quantized_sigma_weight.q_scale())*(eps_kernel.q_scale()) # Calculate the new scale after multiplying two quantized tensors. 
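+            # With symmetric quantization (both zero points are 0), (q1*s1)*(q2*s2) = (q1*q2)*(s1*s2), so the product's scale is simply the product of the two input scales.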
+ weight = torch.ops.quantized.mul(self.quantized_sigma_weight, eps_kernel, new_scale, 0) + new_scale = max(new_scale, self.quantized_mu_weight.q_scale()) # Calculate the new scale after adding two quantized tensors. + weight = torch.ops.quantized.add(weight, self.quantized_mu_weight, new_scale, 0) + bias = None + + + ## DO NOT QUANTIZE BIAS!!! + if self.bias: + if self.quantized_sigma_bias is None: # the case that bias comes from bn fusion + bias = self.quantized_mu_bias + else: # original case + bias = self.quantized_mu_bias + (self.quantized_sigma_bias * self.eps_bias.data.normal_()) + + if input.dtype!=torch.quint8: # check if input has been quantized + input = torch.quantize_per_tensor(input, default_scale, default_zero_point, torch.quint8) # scale=0.1 by grid search; zero_point=128 for uint8 format + + out = torch.nn.quantized.functional.conv1d(input, weight, bias, self.stride, self.padding, + self.dilation, self.groups, scale=default_scale, zero_point=default_zero_point) # input: quint8, weight: qint8, bias: fp32 + + if return_kl: + return out, 0 # disable kl divergence computing + + return out + + + +class QuantizedConv2dReparameterization(Conv2dReparameterization): + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=False): + """ + + """ + + super(QuantizedConv2dReparameterization, self).__init__( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + groups, + bias=bias + ) + + ## redundant ## + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + self.dilation = dilation + self.groups = groups + ## redundant ## + + # for conv bn fusion + self.bn_weight = None + self.bn_bias = None + self.bn_running_mean = None + self.bn_running_var = None + self.bn_eps = None + + self.is_dequant = False + self.quant_dict = None + + def get_scale_and_zero_point(self, x, upper_bound=100, target_range=255): + """ An implementation for symmetric quantization + + Parameters + ---------- + x: tensor + Input tensor. + upper_bound: int, optional + Restrict the maximum value of the original tensor (select 100 empirically). + target_range: int, optional + The range of target data type (255 for int8) + + Returns + ---------- + scale: float + + zero_point: int + + """ + # + scale = torch.zeros(1).to(x.device) # initialize + zero_point = torch.zeros(1).to(x.device) # zero point is zero since we only consider symmetric quantization + xmax = torch.clamp(x.abs().max(), 0, upper_bound) # determine and restrict the maximum value (minimum value should be 0 since the absolute value is always non-negative) + scale = xmax*2/target_range # original range divided by target range + return scale, zero_point + + def get_quantized_tensor(self, x, default_scale=0.1): + """ Quantize tensors + + Parameters + ---------- + x: tensors + Input tensor. + + default_scale: float, optional + Default scale for the case that the computed scale is zero. 
+ + + Returns + ---------- + quantized_x: tensors + + + """ + scale, zero_point = self.get_scale_and_zero_point(x) + if scale == 0: + scale = torch.tensor([default_scale]) # avoid zero scale + quantized_x = torch.quantize_per_tensor(x, scale, zero_point, torch.qint8) + + return quantized_x + + def get_dequantized_tensor(self, x): + dequantized_x = x.dequantize() + + return dequantized_x + + + def quantize(self): + if self.bn_weight is None: # has batchnorm layer, no bn fusion + self.quantized_mu_weight = Parameter(self.get_quantized_tensor(self.mu_kernel), requires_grad=False).cpu() + self.quantized_sigma_weight = Parameter(self.get_quantized_tensor(torch.log1p(torch.exp(self.rho_kernel))), requires_grad=False).cpu() + else: # fuse conv and bn + bn_coef = self.bn_weight/torch.sqrt(self.bn_running_var+self.bn_eps) + self.quantized_mu_weight = Parameter(self.get_quantized_tensor(self.mu_kernel*(bn_coef.view(-1,1,1,1).expand(self.mu_kernel.shape))), requires_grad=False).cpu() + self.quantized_sigma_weight = Parameter(self.get_quantized_tensor(torch.log1p(torch.exp(self.rho_kernel))*(bn_coef.view(-1,1,1,1).expand(self.rho_kernel.shape))), requires_grad=False).cpu() + delattr(self, "mu_kernel") + delattr(self, "rho_kernel") + + + ## DO NOT QUANTIZE BIAS!!!! Bias should be in fp32 format + ## Variable names may be confusing. We don't quantize them. + ## TODO: rename variables + if self.bias: + if self.bn_weight is None: # if no bn fusion + self.quantized_mu_bias = Parameter(self.mu_bias, requires_grad=False).cpu() + self.quantized_sigma_bias = Parameter(torch.log1p(torch.exp(self.rho_bias)), requires_grad=False).cpu() + else: # if apply bn fusion + bn_coef = self.bn_weight/torch.sqrt(self.bn_running_var+self.bn_eps) + self.quantized_mu_bias = Parameter((self.mu_bias-self.bn_running_mean)*bn_coef+self.bn_bias, requires_grad=False).cpu() + self.quantized_sigma_bias = Parameter(torch.log1p(torch.exp(self.rho_bias))*bn_coef, requires_grad=False).cpu() + delattr(self, "mu_bias") + delattr(self, "rho_bias") + else: + if self.bn_weight is not None: # if no bias but apply bn fusion + self.bias = True + bn_coef = self.bn_weight/torch.sqrt(self.bn_running_var+self.bn_eps)*(-self.bn_running_mean)+self.bn_bias + self.quantized_mu_bias = Parameter(bn_coef, requires_grad=False).cpu() + self.quantized_sigma_bias = None + + delattr(self, "bn_weight") + delattr(self, "bn_bias") + delattr(self, "bn_running_mean") + delattr(self, "bn_running_var") + delattr(self, "bn_eps") + + delattr(self, "qint_quant") + delattr(self, "quint_quant") + delattr(self, "dequant") + + def dequantize(self): # Deprecated. Only for forward mode #1. + self.mu_kernel = self.get_dequantized_tensor(self.quantized_mu_weight) + self.sigma_weight = self.get_dequantized_tensor(self.quantized_sigma_weight) + + if self.bias: + self.mu_bias = self.get_dequantized_tensor(self.quantized_mu_bias) + self.sigma_bias = self.get_dequantized_tensor(self.quantized_sigma_bias) + + return + + def forward(self, input, enable_int8_compute=True, normal_scale=6/255, default_scale=0.1, default_zero_point=128, return_kl=True): + """ Forward pass + + Parameters + ---------- + input: tensors + Input tensor. + + enable_int8_compute: bool, optional + Whether to enable int8 computation. + + normal_scale: float, optional + Scale for quantized tensor sampled from normal distribution. + since 99.7% values will lie within 3 standard deviations, the original range is set as 6. 
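+            With 255 quantization levels, this gives the default of 6/255 for a standard normal sample.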
+ + default_scale: float, optional + Default scale for quantized input tensor and quantized output tensor. + Set to 0.1 by grid search. + + default_zero_point: int, optional + Default zero point for quantized input tensor and quantized output tensor. + Set to 128 for quint8 tensor. + + + + Returns + ---------- + out: tensors + Output tensor. + KL: float + set to 0 since we diable KL divergence computation in quantized layers. + + + """ + if self.dnn_to_bnn_flag: + return_kl = False + + if self.quant_dict is not None: + eps_kernel = torch.quantize_per_tensor(self.eps_kernel.data.normal_(), self.quant_dict[0]['scale'], self.quant_dict[0]['zero_point'], torch.qint8) # Quantize a tensor from normal distribution. 99.7% values will lie within 3 standard deviations, so the original range is set as 6. + weight = torch.ops.quantized.mul(self.quantized_sigma_weight, eps_kernel, self.quant_dict[1]['scale'], self.quant_dict[1]['zero_point']) + weight = torch.ops.quantized.add(weight, self.quantized_mu_weight, self.quant_dict[2]['scale'], self.quant_dict[2]['zero_point']) + bias = None + + ## DO NOT QUANTIZE BIAS!!! + if self.bias: + if self.quantized_sigma_bias is None: # the case that bias comes from bn fusion + bias = self.quantized_mu_bias + else: # original case + bias = self.quantized_mu_bias + (self.quantized_sigma_bias * self.eps_bias.data.normal_()) + + if input.dtype!=torch.quint8: # check if input has been quantized + input = torch.quantize_per_tensor(input, self.quant_dict[3]['scale'], self.quant_dict[3]['zero_point'], torch.quint8) # scale=0.1 by grid search; zero_point=128 for uint8 format + + out = torch.nn.quantized.functional.conv2d(input, weight, bias, self.stride, self.padding, + self.dilation, self.groups, scale=self.quant_dict[4]['scale'], zero_point=self.quant_dict[4]['zero_point']) # input: quint8, weight: qint8, bias: fp32 + + elif not enable_int8_compute: # Deprecated. Use this method for reducing model size only. + if not self.is_dequant: + self.dequantize() + self.is_dequant = True + + weight = self.mu_kernel + (self.sigma_weight * self.eps_kernel.data.normal_()) + bias = None + + if self.bias: + bias = self.mu_bias + (self.sigma_bias * self.eps_bias.data.normal_()) + + out = F.conv2d(input, weight, bias, self.stride, self.padding, + self.dilation, self.groups) + + else: + eps_kernel = torch.quantize_per_tensor(self.eps_kernel.data.normal_(), normal_scale, 0, torch.qint8) # Quantize a tensor from normal distribution. 99.7% values will lie within 3 standard deviations, so the original range is set as 6. + new_scale = (self.quantized_sigma_weight.q_scale())*(eps_kernel.q_scale()) # Calculate the new scale after multiplying two quantized tensors. + weight = torch.ops.quantized.mul(self.quantized_sigma_weight, eps_kernel, new_scale, 0) + new_scale = max(new_scale, self.quantized_mu_weight.q_scale()) # Calculate the new scale after adding two quantized tensors. + weight = torch.ops.quantized.add(weight, self.quantized_mu_weight, new_scale, 0) + bias = None + + + ## DO NOT QUANTIZE BIAS!!! 
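+            ## The quantized conv op takes its bias in fp32, so only the weights are quantized here.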
+ if self.bias: + if self.quantized_sigma_bias is None: # the case that bias comes from bn fusion + bias = self.quantized_mu_bias + else: # original case + bias = self.quantized_mu_bias + (self.quantized_sigma_bias * self.eps_bias.data.normal_()) + + if input.dtype!=torch.quint8: # check if input has been quantized + input = torch.quantize_per_tensor(input, default_scale, default_zero_point, torch.quint8) # scale=0.1 by grid search; zero_point=128 for uint8 format + + out = torch.nn.quantized.functional.conv2d(input, weight, bias, self.stride, self.padding, + self.dilation, self.groups, scale=default_scale, zero_point=default_zero_point) # input: quint8, weight: qint8, bias: fp32 + + if return_kl: + return out, 0 # disable kl divergence computing + + return out + + +class QuantizedConv3dReparameterization(Conv3dReparameterization): + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=False): + """ + """ + + super(QuantizedConv3dReparameterization, self).__init__( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + groups, + bias=bias + ) + + ## redundant ## + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + self.dilation = dilation + self.groups = groups + ## redundant ## + + # for conv bn fusion + self.bn_weight = None + self.bn_bias = None + self.bn_running_mean = None + self.bn_running_var = None + self.bn_eps = None + + self.is_dequant = False + self.quant_dict = None + + def get_scale_and_zero_point(self, x, upper_bound=100, target_range=255): + """ An implementation for symmetric quantization + + Parameters + ---------- + x: tensor + Input tensor. + upper_bound: int, optional + Restrict the maximum value of the original tensor (select 100 empirically). + target_range: int, optional + The range of target data type (255 for int8) + + Returns + ---------- + scale: float + + zero_point: int + + """ + # + scale = torch.zeros(1).to(x.device) # initialize + zero_point = torch.zeros(1).to(x.device) # zero point is zero since we only consider symmetric quantization + xmax = torch.clamp(x.abs().max(), 0, upper_bound) # determine and restrict the maximum value (minimum value should be 0 since the absolute value is always non-negative) + scale = xmax*2/target_range # original range divided by target range + return scale, zero_point + + def get_quantized_tensor(self, x, default_scale=0.1): + """ Quantize tensors + + Parameters + ---------- + x: tensors + Input tensor. + + default_scale: float, optional + Default scale for the case that the computed scale is zero. 
+ + + Returns + ---------- + quantized_x: tensors + + + """ + scale, zero_point = self.get_scale_and_zero_point(x) + if scale == 0: + scale = torch.tensor([default_scale]) # avoid zero scale + quantized_x = torch.quantize_per_tensor(x, scale, zero_point, torch.qint8) + + return quantized_x + + def get_dequantized_tensor(self, x): + dequantized_x = x.dequantize() + + return dequantized_x + + + def quantize(self): + if self.bn_weight is None: # has batchnorm layer, no bn fusion + self.quantized_mu_weight = Parameter(self.get_quantized_tensor(self.mu_kernel), requires_grad=False).cpu() + self.quantized_sigma_weight = Parameter(self.get_quantized_tensor(torch.log1p(torch.exp(self.rho_kernel))), requires_grad=False).cpu() + else: # fuse conv and bn + bn_coef = self.bn_weight/torch.sqrt(self.bn_running_var+self.bn_eps) + self.quantized_mu_weight = Parameter(self.get_quantized_tensor(self.mu_kernel*(bn_coef.view(-1,1,1,1).expand(self.mu_kernel.shape))), requires_grad=False).cpu() + self.quantized_sigma_weight = Parameter(self.get_quantized_tensor(torch.log1p(torch.exp(self.rho_kernel))*(bn_coef.view(-1,1,1,1).expand(self.rho_kernel.shape))), requires_grad=False).cpu() + delattr(self, "mu_kernel") + delattr(self, "rho_kernel") + + + ## DO NOT QUANTIZE BIAS!!!! Bias should be in fp32 format + ## Variable names may be confusing. We don't quantize them. + ## TODO: rename variables + if self.bias: # if has bias + if self.bn_weight is None: # if no bn fusion + self.quantized_mu_bias = Parameter(self.mu_bias, requires_grad=False).cpu() + self.quantized_sigma_bias = Parameter(torch.log1p(torch.exp(self.rho_bias)), requires_grad=False).cpu() + else: # if apply bn fusion + bn_coef = self.bn_weight/torch.sqrt(self.bn_running_var+self.bn_eps) + self.quantized_mu_bias = Parameter((self.mu_bias-self.bn_running_mean)*bn_coef+self.bn_bias, requires_grad=False).cpu() + self.quantized_sigma_bias = Parameter(torch.log1p(torch.exp(self.rho_bias))*bn_coef, requires_grad=False).cpu() + delattr(self, "mu_bias") + delattr(self, "rho_bias") + else: + if self.bn_weight is not None: # if no bias but apply bn fusion + self.bias = True + bn_coef = self.bn_weight/torch.sqrt(self.bn_running_var+self.bn_eps)*(-self.bn_running_mean)+self.bn_bias + self.quantized_mu_bias = Parameter(bn_coef, requires_grad=False).cpu() + self.quantized_sigma_bias = None + + delattr(self, "bn_weight") + delattr(self, "bn_bias") + delattr(self, "bn_running_mean") + delattr(self, "bn_running_var") + delattr(self, "bn_eps") + + def dequantize(self): # Deprecated. Only for forward mode #1. + self.mu_kernel = self.get_dequantized_tensor(self.quantized_mu_weight) + self.sigma_weight = self.get_dequantized_tensor(self.quantized_sigma_weight) + + if self.bias: + self.mu_bias = self.get_dequantized_tensor(self.quantized_mu_bias) + self.sigma_bias = self.get_dequantized_tensor(self.quantized_sigma_bias) + + return + + def forward(self, input, enable_int8_compute=True, normal_scale=6/255, default_scale=0.1, default_zero_point=128, return_kl=True): + """ Forward pass + + Parameters + ---------- + input: tensors + Input tensor. + + enable_int8_compute: bool, optional + Whether to enable int8 computation. + + normal_scale: float, optional + Scale for quantized tensor sampled from normal distribution. + since 99.7% values will lie within 3 standard deviations, the original range is set as 6. + + default_scale: float, optional + Default scale for quantized input tensor and quantized output tensor. + Set to 0.1 by grid search. 
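# The BN-fusion branch in quantize() above folds a trailing BatchNorm into the weight
# means/sigmas before they are quantized. A small fp32 sketch of that algebra (names are
# illustrative, not the module's attributes): w' = w * gamma / sqrt(var + eps) per output
# channel, and a conv without bias absorbs beta - running_mean * gamma / sqrt(var + eps).
import torch

w = torch.randn(8, 4, 3, 3)                       # conv weight, one BN channel per output channel
gamma, beta = torch.rand(8), torch.rand(8)
running_mean, running_var, eps = torch.zeros(8), torch.ones(8), 1e-5

bn_coef = gamma / torch.sqrt(running_var + eps)
w_fused = w * bn_coef.view(-1, 1, 1, 1)           # scale every output channel
b_fused = beta - running_mean * bn_coef           # bias picked up from the BN shift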
+ + default_zero_point: int, optional + Default zero point for quantized input tensor and quantized output tensor. + Set to 128 for quint8 tensor. + + + + Returns + ---------- + out: tensors + Output tensor. + KL: float + set to 0 since we diable KL divergence computation in quantized layers. + + + """ + if self.dnn_to_bnn_flag: + return_kl = False + + if self.quant_dict is not None: + eps_kernel = torch.quantize_per_tensor(self.eps_kernel.data.normal_(), self.quant_dict[0]['scale'], self.quant_dict[0]['zero_point'], torch.qint8) # Quantize a tensor from normal distribution. 99.7% values will lie within 3 standard deviations, so the original range is set as 6. + weight = torch.ops.quantized.mul(self.quantized_sigma_weight, eps_kernel, self.quant_dict[1]['scale'], self.quant_dict[1]['zero_point']) + weight = torch.ops.quantized.add(weight, self.quantized_mu_weight, self.quant_dict[2]['scale'], self.quant_dict[2]['zero_point']) + bias = None + + ## DO NOT QUANTIZE BIAS!!! + if self.bias: + if self.quantized_sigma_bias is None: # the case that bias comes from bn fusion + bias = self.quantized_mu_bias + else: # original case + bias = self.quantized_mu_bias + (self.quantized_sigma_bias * self.eps_bias.data.normal_()) + + if input.dtype!=torch.quint8: # check if input has been quantized + input = torch.quantize_per_tensor(input, self.quant_dict[3]['scale'], self.quant_dict[3]['zero_point'], torch.quint8) # scale=0.1 by grid search; zero_point=128 for uint8 format + + out = torch.nn.quantized.functional.conv3d(input, weight, bias, self.stride, self.padding, + self.dilation, self.groups, scale=self.quant_dict[4]['scale'], zero_point=self.quant_dict[4]['zero_point']) # input: quint8, weight: qint8, bias: fp32 + + elif not enable_int8_compute: # Deprecated. Use this method for reducing model size only. + if not self.is_dequant: + self.dequantize() + self.is_dequant = True + + weight = self.mu_kernel + (self.sigma_weight * self.eps_kernel.data.normal_()) + bias = None + + if self.bias: + bias = self.mu_bias + (self.sigma_bias * self.eps_bias.data.normal_()) + + out = F.conv3d(input, weight, bias, self.stride, self.padding, + self.dilation, self.groups) + + else: + eps_kernel = torch.quantize_per_tensor(self.eps_kernel.data.normal_(), normal_scale, 0, torch.qint8) # Quantize a tensor from normal distribution. 99.7% values will lie within 3 standard deviations, so the original range is set as 6. + new_scale = (self.quantized_sigma_weight.q_scale())*(eps_kernel.q_scale()) # Calculate the new scale after multiplying two quantized tensors. + weight = torch.ops.quantized.mul(self.quantized_sigma_weight, eps_kernel, new_scale, 0) + new_scale = max(new_scale, self.quantized_mu_weight.q_scale()) # Calculate the new scale after adding two quantized tensors. + weight = torch.ops.quantized.add(weight, self.quantized_mu_weight, new_scale, 0) + bias = None + + + ## DO NOT QUANTIZE BIAS!!! 
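# Sketch of the calibration record consumed by the quant_dict branch above: five
# scale/zero-point pairs, indexed in the order they are used -- [0] quantized epsilon,
# [1] sigma*eps product, [2] mu + sigma*eps weight sample, [3] quint8 input, [4] quint8
# output. The numbers below are placeholders; in this patch the real values are gathered
# from observer stubs during calibration (see bnn_to_qbnn.py further down).
quant_dict = [
    {'scale': 6 / 255, 'zero_point': 0},    # [0] eps_kernel (qint8, symmetric)
    {'scale': 1e-3,    'zero_point': 0},    # [1] sigma * eps
    {'scale': 2e-2,    'zero_point': 0},    # [2] mu + sigma * eps
    {'scale': 0.1,     'zero_point': 128},  # [3] input activation (quint8)
    {'scale': 0.1,     'zero_point': 128},  # [4] conv output (quint8)
]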
+ if self.bias: + if self.quantized_sigma_bias is None: # the case that bias comes from bn fusion + bias = self.quantized_mu_bias + else: # original case + bias = self.quantized_mu_bias + (self.quantized_sigma_bias * self.eps_bias.data.normal_()) + + if input.dtype!=torch.quint8: # check if input has been quantized + input = torch.quantize_per_tensor(input, default_scale, default_zero_point, torch.quint8) # scale=0.1 by grid search; zero_point=128 for uint8 format + + out = torch.nn.quantized.functional.conv3d(input, weight, bias, self.stride, self.padding, + self.dilation, self.groups, scale=default_scale, zero_point=default_zero_point) # input: quint8, weight: qint8, bias: fp32 + + if return_kl: + return out, 0 # disable kl divergence computing + + return out + +class QuantizedConvTranspose1dReparameterization(ConvTranspose1dReparameterization): + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=False): + """ + """ + + super(ConvTranspose1dReparameterization, self).__init__( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + groups, + bias=bias + ) + + ## redundant ## + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + self.dilation = dilation + self.groups = groups + ## redundant ## + + # for conv bn fusion + self.bn_weight = None + self.bn_bias = None + self.bn_running_mean = None + self.bn_running_var = None + self.bn_eps = None + + self.is_dequant = False + + def get_scale_and_zero_point(self, x, upper_bound=100, target_range=255): + """ An implementation for symmetric quantization + + Parameters + ---------- + x: tensor + Input tensor. + upper_bound: int, optional + Restrict the maximum value of the original tensor (select 100 empirically). + target_range: int, optional + The range of target data type (255 for int8) + + Returns + ---------- + scale: float + + zero_point: int + + """ + # + scale = torch.zeros(1).to(x.device) # initialize + zero_point = torch.zeros(1).to(x.device) # zero point is zero since we only consider symmetric quantization + xmax = torch.clamp(x.abs().max(), 0, upper_bound) # determine and restrict the maximum value (minimum value should be 0 since the absolute value is always non-negative) + scale = xmax*2/target_range # original range divided by target range + return scale, zero_point + + def get_quantized_tensor(self, x, default_scale=0.1): + """ Quantize tensors + + Parameters + ---------- + x: tensors + Input tensor. + + default_scale: float, optional + Default scale for the case that the computed scale is zero. 
+ + + Returns + ---------- + quantized_x: tensors + + + """ + scale, zero_point = self.get_scale_and_zero_point(x) + if scale == 0: + scale = torch.tensor([default_scale]) # avoid zero scale + quantized_x = torch.quantize_per_tensor(x, scale, zero_point, torch.qint8) + + return quantized_x + + def get_dequantized_tensor(self, x): + dequantized_x = x.dequantize() + + return dequantized_x + + + def quantize(self): + if self.bn_weight is None: # has batchnorm layer, no bn fusion + self.quantized_mu_weight = Parameter(self.get_quantized_tensor(self.mu_kernel), requires_grad=False).cpu() + self.quantized_sigma_weight = Parameter(self.get_quantized_tensor(torch.log1p(torch.exp(self.rho_kernel))), requires_grad=False).cpu() + else: # fuse conv and bn + bn_coef = self.bn_weight/torch.sqrt(self.bn_running_var+self.bn_eps) + self.quantized_mu_weight = Parameter(self.get_quantized_tensor(self.mu_kernel*(bn_coef.view(-1,1,1,1).expand(self.mu_kernel.shape))), requires_grad=False).cpu() + self.quantized_sigma_weight = Parameter(self.get_quantized_tensor(torch.log1p(torch.exp(self.rho_kernel))*(bn_coef.view(-1,1,1,1).expand(self.rho_kernel.shape))), requires_grad=False).cpu() + delattr(self, "mu_kernel") + delattr(self, "rho_kernel") + + + ## DO NOT QUANTIZE BIAS!!!! Bias should be in fp32 format + ## Variable names may be confusing. We don't quantize them. + ## TODO: rename variables + if self.bias: # if has bias + if self.bn_weight is None: # if no bn fusion + self.quantized_mu_bias = Parameter(self.mu_bias, requires_grad=False).cpu() + self.quantized_sigma_bias = Parameter(torch.log1p(torch.exp(self.rho_bias)), requires_grad=False).cpu() + else: # if apply bn fusion + bn_coef = self.bn_weight/torch.sqrt(self.bn_running_var+self.bn_eps) + self.quantized_mu_bias = Parameter((self.mu_bias-self.bn_running_mean)*bn_coef+self.bn_bias, requires_grad=False).cpu() + self.quantized_sigma_bias = Parameter(torch.log1p(torch.exp(self.rho_bias))*bn_coef, requires_grad=False).cpu() + delattr(self, "mu_bias") + delattr(self, "rho_bias") + else: + if self.bn_weight is not None: # if no bias but apply bn fusion + self.bias = True + bn_coef = self.bn_weight/torch.sqrt(self.bn_running_var+self.bn_eps)*(-self.bn_running_mean)+self.bn_bias + self.quantized_mu_bias = Parameter(bn_coef, requires_grad=False).cpu() + self.quantized_sigma_bias = None + + delattr(self, "bn_weight") + delattr(self, "bn_bias") + delattr(self, "bn_running_mean") + delattr(self, "bn_running_var") + delattr(self, "bn_eps") + + def dequantize(self): # Deprecated. Only for forward mode #1. + self.mu_kernel = self.get_dequantized_tensor(self.quantized_mu_weight) + self.sigma_weight = self.get_dequantized_tensor(self.quantized_sigma_weight) + + if self.bias: + self.mu_bias = self.get_dequantized_tensor(self.quantized_mu_bias) + self.sigma_bias = self.get_dequantized_tensor(self.quantized_sigma_bias) + + return + + def forward(self, input, enable_int8_compute=True, normal_scale=6/255, default_scale=0.1, default_zero_point=128, return_kl=True): + """ Forward pass + + Parameters + ---------- + input: tensors + Input tensor. + + enable_int8_compute: bool, optional + Whether to enable int8 computation. + + normal_scale: float, optional + Scale for quantized tensor sampled from normal distribution. + since 99.7% values will lie within 3 standard deviations, the original range is set as 6. + + default_scale: float, optional + Default scale for quantized input tensor and quantized output tensor. + Set to 0.1 by grid search. 
+ + default_zero_point: int, optional + Default zero point for quantized input tensor and quantized output tensor. + Set to 128 for quint8 tensor. + + + + Returns + ---------- + out: tensors + Output tensor. + KL: float + set to 0 since we diable KL divergence computation in quantized layers. + + + """ + if self.dnn_to_bnn_flag: + return_kl = False + + if not enable_int8_compute: # Deprecated. Use this method for reducing model size only. + if not self.is_dequant: + self.dequantize() + self.is_dequant = True + + weight = self.mu_kernel + (self.sigma_weight * self.eps_kernel.data.normal_()) + bias = None + + if self.bias: + bias = self.mu_bias + (self.sigma_bias * self.eps_bias.data.normal_()) + + out = F.conv_transpose1d(input, weight, bias, self.stride, + self.padding, self.output_padding, + self.dilation, self.groups) + + else: + eps_kernel = torch.quantize_per_tensor(self.eps_kernel.data.normal_(), normal_scale, 0, torch.qint8) # Quantize a tensor from normal distribution. 99.7% values will lie within 3 standard deviations, so the original range is set as 6. + new_scale = (self.quantized_sigma_weight.q_scale())*(eps_kernel.q_scale()) # Calculate the new scale after multiplying two quantized tensors. + weight = torch.ops.quantized.mul(self.quantized_sigma_weight, eps_kernel, new_scale, 0) + new_scale = max(new_scale, self.quantized_mu_weight.q_scale()) # Calculate the new scale after adding two quantized tensors. + weight = torch.ops.quantized.add(weight, self.quantized_mu_weight, new_scale, 0) + bias = None + + + ## DO NOT QUANTIZE BIAS!!! + if self.bias: + if self.quantized_sigma_bias is None: # the case that bias comes from bn fusion + bias = self.quantized_mu_bias + else: # original case + bias = self.quantized_mu_bias + (self.quantized_sigma_bias * self.eps_bias.data.normal_()) + + if input.dtype!=torch.quint8: # check if input has been quantized + input = torch.quantize_per_tensor(input, default_scale, default_zero_point, torch.quint8) # scale=0.1 by grid search; zero_point=128 for uint8 format + + self._packed_params = torch.ops.quantized.conv_transpose1d_prepack(weight, bias, self.stride, + self.padding, self.output_padding, + self.dilation, self.groups) + + out = torch.ops.quantized.conv_transpose1d(input, self._packed_params, scale=default_scale, zero_point=default_zero_point) + + + if return_kl: + return out, 0 # disable kl divergence computing + + return out + +class QuantizedConvTranspose2dReparameterization(ConvTranspose2dReparameterization): + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=False): + """ + """ + + super(ConvTranspose2dReparameterization, self).__init__( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + groups, + bias=bias + ) + + ## redundant ## + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + self.dilation = dilation + self.groups = groups + ## redundant ## + + # for conv bn fusion + self.bn_weight = None + self.bn_bias = None + self.bn_running_mean = None + self.bn_running_var = None + self.bn_eps = None + + self.is_dequant = False + + def get_scale_and_zero_point(self, x, upper_bound=100, target_range=255): + """ An implementation for symmetric quantization + + Parameters + ---------- + x: tensor + Input tensor. + upper_bound: int, optional + Restrict the maximum value of the original tensor (select 100 empirically). 
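# Hedged sketch (not patch content) of the prepack-then-run pattern used for the quantized
# transposed convolution above: the qint8 weight and fp32 bias are packed once, and the
# packed handle is passed to the quantized op together with the desired output
# scale/zero-point. Arguments are passed positionally (stride, padding, output_padding,
# dilation, groups); all numbers are placeholders.
import torch

w = torch.quantize_per_tensor(torch.randn(4, 8, 3), 0.02, 0, torch.qint8)     # (in, out, k) layout
x = torch.quantize_per_tensor(torch.randn(1, 4, 16), 0.1, 128, torch.quint8)

packed = torch.ops.quantized.conv_transpose1d_prepack(w, None, [1], [0], [0], [1], 1)
y = torch.ops.quantized.conv_transpose1d(x, packed, 0.1, 128)
print(y.shape)   # (1, 8, 18) for kernel_size=3, stride=1, no padding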
+ target_range: int, optional + The range of target data type (255 for int8) + + Returns + ---------- + scale: float + + zero_point: int + + """ + # + scale = torch.zeros(1).to(x.device) # initialize + zero_point = torch.zeros(1).to(x.device) # zero point is zero since we only consider symmetric quantization + xmax = torch.clamp(x.abs().max(), 0, upper_bound) # determine and restrict the maximum value (minimum value should be 0 since the absolute value is always non-negative) + scale = xmax*2/target_range # original range divided by target range + return scale, zero_point + + def get_quantized_tensor(self, x, default_scale=0.1): + """ Quantize tensors + + Parameters + ---------- + x: tensors + Input tensor. + + default_scale: float, optional + Default scale for the case that the computed scale is zero. + + + Returns + ---------- + quantized_x: tensors + + + """ + scale, zero_point = self.get_scale_and_zero_point(x) + if scale == 0: + scale = torch.tensor([default_scale]) # avoid zero scale + quantized_x = torch.quantize_per_tensor(x, scale, zero_point, torch.qint8) + + return quantized_x + + def get_dequantized_tensor(self, x): + dequantized_x = x.dequantize() + + return dequantized_x + + + def quantize(self): + if self.bn_weight is None: # has batchnorm layer, no bn fusion + self.quantized_mu_weight = Parameter(self.get_quantized_tensor(self.mu_kernel), requires_grad=False).cpu() + self.quantized_sigma_weight = Parameter(self.get_quantized_tensor(torch.log1p(torch.exp(self.rho_kernel))), requires_grad=False).cpu() + else: # fuse conv and bn + bn_coef = self.bn_weight/torch.sqrt(self.bn_running_var+self.bn_eps) + self.quantized_mu_weight = Parameter(self.get_quantized_tensor(self.mu_kernel*(bn_coef.view(-1,1,1,1).expand(self.mu_kernel.shape))), requires_grad=False).cpu() + self.quantized_sigma_weight = Parameter(self.get_quantized_tensor(torch.log1p(torch.exp(self.rho_kernel))*(bn_coef.view(-1,1,1,1).expand(self.rho_kernel.shape))), requires_grad=False).cpu() + delattr(self, "mu_kernel") + delattr(self, "rho_kernel") + + + ## DO NOT QUANTIZE BIAS!!!! Bias should be in fp32 format + ## Variable names may be confusing. We don't quantize them. + ## TODO: rename variables + if self.bias: # if has bias + if self.bn_weight is None: # if no bn fusion + self.quantized_mu_bias = Parameter(self.mu_bias, requires_grad=False).cpu() + self.quantized_sigma_bias = Parameter(torch.log1p(torch.exp(self.rho_bias)), requires_grad=False).cpu() + else: # if apply bn fusion + bn_coef = self.bn_weight/torch.sqrt(self.bn_running_var+self.bn_eps) + self.quantized_mu_bias = Parameter((self.mu_bias-self.bn_running_mean)*bn_coef+self.bn_bias, requires_grad=False).cpu() + self.quantized_sigma_bias = Parameter(torch.log1p(torch.exp(self.rho_bias))*bn_coef, requires_grad=False).cpu() + delattr(self, "mu_bias") + delattr(self, "rho_bias") + else: + if self.bn_weight is not None: # if no bias but apply bn fusion + self.bias = True + bn_coef = self.bn_weight/torch.sqrt(self.bn_running_var+self.bn_eps)*(-self.bn_running_mean)+self.bn_bias + self.quantized_mu_bias = Parameter(bn_coef, requires_grad=False).cpu() + self.quantized_sigma_bias = None + + delattr(self, "bn_weight") + delattr(self, "bn_bias") + delattr(self, "bn_running_mean") + delattr(self, "bn_running_var") + delattr(self, "bn_eps") + + def dequantize(self): # Deprecated. Only for forward mode #1. 
+ self.mu_kernel = self.get_dequantized_tensor(self.quantized_mu_weight) + self.sigma_weight = self.get_dequantized_tensor(self.quantized_sigma_weight) + + if self.bias: + self.mu_bias = self.get_dequantized_tensor(self.quantized_mu_bias) + self.sigma_bias = self.get_dequantized_tensor(self.quantized_sigma_bias) + + return + + def forward(self, input, enable_int8_compute=True, normal_scale=6/255, default_scale=0.1, default_zero_point=128, return_kl=True): + """ Forward pass + + Parameters + ---------- + input: tensors + Input tensor. + + enable_int8_compute: bool, optional + Whether to enable int8 computation. + + normal_scale: float, optional + Scale for quantized tensor sampled from normal distribution. + since 99.7% values will lie within 3 standard deviations, the original range is set as 6. + + default_scale: float, optional + Default scale for quantized input tensor and quantized output tensor. + Set to 0.1 by grid search. + + default_zero_point: int, optional + Default zero point for quantized input tensor and quantized output tensor. + Set to 128 for quint8 tensor. + + + + Returns + ---------- + out: tensors + Output tensor. + KL: float + set to 0 since we diable KL divergence computation in quantized layers. + + + """ + if self.dnn_to_bnn_flag: + return_kl = False + + if not enable_int8_compute: # Deprecated. Use this method for reducing model size only. + if not self.is_dequant: + self.dequantize() + self.is_dequant = True + + weight = self.mu_kernel + (self.sigma_weight * self.eps_kernel.data.normal_()) + bias = None + + if self.bias: + bias = self.mu_bias + (self.sigma_bias * self.eps_bias.data.normal_()) + + out = F.conv_transpose2d(input, weight, bias, self.stride, + self.padding, self.output_padding, + self.dilation, self.groups) + + else: + eps_kernel = torch.quantize_per_tensor(self.eps_kernel.data.normal_(), normal_scale, 0, torch.qint8) # Quantize a tensor from normal distribution. 99.7% values will lie within 3 standard deviations, so the original range is set as 6. + new_scale = (self.quantized_sigma_weight.q_scale())*(eps_kernel.q_scale()) # Calculate the new scale after multiplying two quantized tensors. + weight = torch.ops.quantized.mul(self.quantized_sigma_weight, eps_kernel, new_scale, 0) + new_scale = max(new_scale, self.quantized_mu_weight.q_scale()) # Calculate the new scale after adding two quantized tensors. + weight = torch.ops.quantized.add(weight, self.quantized_mu_weight, new_scale, 0) + bias = None + + + ## DO NOT QUANTIZE BIAS!!! 
+ if self.bias: + if self.quantized_sigma_bias is None: # the case that bias comes from bn fusion + bias = self.quantized_mu_bias + else: # original case + bias = self.quantized_mu_bias + (self.quantized_sigma_bias * self.eps_bias.data.normal_()) + + if input.dtype!=torch.quint8: # check if input has been quantized + input = torch.quantize_per_tensor(input, default_scale, default_zero_point, torch.quint8) # scale=0.1 by grid search; zero_point=128 for uint8 format + + self._packed_params = torch.ops.quantized.conv_transpose2d_prepack(weight, bias, self.stride, + self.padding, self.output_padding, + self.dilation, self.groups) + + out = torch.ops.quantized.conv_transpose2d(input, self._packed_params, scale=default_scale, zero_point=default_zero_point) + + + if return_kl: + return out, 0 # disable kl divergence computing + + return out + +class QuantizedConvTranspose3dReparameterization(ConvTranspose3dReparameterization): + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=False): + """ + """ + + super(ConvTranspose3dReparameterization, self).__init__( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + groups, + bias=bias + ) + + ## redundant ## + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + self.dilation = dilation + self.groups = groups + ## redundant ## + + # for conv bn fusion + self.bn_weight = None + self.bn_bias = None + self.bn_running_mean = None + self.bn_running_var = None + self.bn_eps = None + + self.is_dequant = False + + def get_scale_and_zero_point(self, x, upper_bound=100, target_range=255): + """ An implementation for symmetric quantization + + Parameters + ---------- + x: tensor + Input tensor. + upper_bound: int, optional + Restrict the maximum value of the original tensor (select 100 empirically). + target_range: int, optional + The range of target data type (255 for int8) + + Returns + ---------- + scale: float + + zero_point: int + + """ + # + scale = torch.zeros(1).to(x.device) # initialize + zero_point = torch.zeros(1).to(x.device) # zero point is zero since we only consider symmetric quantization + xmax = torch.clamp(x.abs().max(), 0, upper_bound) # determine and restrict the maximum value (minimum value should be 0 since the absolute value is always non-negative) + scale = xmax*2/target_range # original range divided by target range + return scale, zero_point + + def get_quantized_tensor(self, x, default_scale=0.1): + """ Quantize tensors + + Parameters + ---------- + x: tensors + Input tensor. + + default_scale: float, optional + Default scale for the case that the computed scale is zero. 
+ + + Returns + ---------- + quantized_x: tensors + + + """ + scale, zero_point = self.get_scale_and_zero_point(x) + if scale == 0: + scale = torch.tensor([default_scale]) # avoid zero scale + quantized_x = torch.quantize_per_tensor(x, scale, zero_point, torch.qint8) + + return quantized_x + + def get_dequantized_tensor(self, x): + dequantized_x = x.dequantize() + + return dequantized_x + + + def quantize(self): + if self.bn_weight is None: # has batchnorm layer, no bn fusion + self.quantized_mu_weight = Parameter(self.get_quantized_tensor(self.mu_kernel), requires_grad=False).cpu() + self.quantized_sigma_weight = Parameter(self.get_quantized_tensor(torch.log1p(torch.exp(self.rho_kernel))), requires_grad=False).cpu() + else: # fuse conv and bn + bn_coef = self.bn_weight/torch.sqrt(self.bn_running_var+self.bn_eps) + self.quantized_mu_weight = Parameter(self.get_quantized_tensor(self.mu_kernel*(bn_coef.view(-1,1,1,1).expand(self.mu_kernel.shape))), requires_grad=False).cpu() + self.quantized_sigma_weight = Parameter(self.get_quantized_tensor(torch.log1p(torch.exp(self.rho_kernel))*(bn_coef.view(-1,1,1,1).expand(self.rho_kernel.shape))), requires_grad=False).cpu() + delattr(self, "mu_kernel") + delattr(self, "rho_kernel") + + + ## DO NOT QUANTIZE BIAS!!!! Bias should be in fp32 format + ## Variable names may be confusing. We don't quantize them. + ## TODO: rename variables + if self.bias: # if has bias + if self.bn_weight is None: # if no bn fusion + self.quantized_mu_bias = Parameter(self.mu_bias, requires_grad=False).cpu() + self.quantized_sigma_bias = Parameter(torch.log1p(torch.exp(self.rho_bias)), requires_grad=False).cpu() + else: # if apply bn fusion + bn_coef = self.bn_weight/torch.sqrt(self.bn_running_var+self.bn_eps) + self.quantized_mu_bias = Parameter((self.mu_bias-self.bn_running_mean)*bn_coef+self.bn_bias, requires_grad=False).cpu() + self.quantized_sigma_bias = Parameter(torch.log1p(torch.exp(self.rho_bias))*bn_coef, requires_grad=False).cpu() + delattr(self, "mu_bias") + delattr(self, "rho_bias") + else: + if self.bn_weight is not None: # if no bias but apply bn fusion + self.bias = True + bn_coef = self.bn_weight/torch.sqrt(self.bn_running_var+self.bn_eps)*(-self.bn_running_mean)+self.bn_bias + self.quantized_mu_bias = Parameter(bn_coef, requires_grad=False).cpu() + self.quantized_sigma_bias = None + + delattr(self, "bn_weight") + delattr(self, "bn_bias") + delattr(self, "bn_running_mean") + delattr(self, "bn_running_var") + delattr(self, "bn_eps") + + def dequantize(self): # Deprecated. Only for forward mode #1. + self.mu_kernel = self.get_dequantized_tensor(self.quantized_mu_weight) + self.sigma_weight = self.get_dequantized_tensor(self.quantized_sigma_weight) + + if self.bias: + self.mu_bias = self.get_dequantized_tensor(self.quantized_mu_bias) + self.sigma_bias = self.get_dequantized_tensor(self.quantized_sigma_bias) + + return + + def forward(self, input, enable_int8_compute=True, normal_scale=6/255, default_scale=0.1, default_zero_point=128, return_kl=True): + """ Forward pass + + Parameters + ---------- + input: tensors + Input tensor. + + enable_int8_compute: bool, optional + Whether to enable int8 computation. + + normal_scale: float, optional + Scale for quantized tensor sampled from normal distribution. + since 99.7% values will lie within 3 standard deviations, the original range is set as 6. + + default_scale: float, optional + Default scale for quantized input tensor and quantized output tensor. + Set to 0.1 by grid search. 
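# Sketch of the deprecated enable_int8_compute=False mode referenced above: quantization is
# used only to shrink the stored parameters; at inference the tensors are dequantized once
# and the reparameterization sample is applied in fp32. Shapes/values are placeholders.
import torch
import torch.nn.functional as F

mu_q    = torch.quantize_per_tensor(torch.randn(8, 4, 3, 3), 0.02, 0, torch.qint8)
sigma_q = torch.quantize_per_tensor(torch.rand(8, 4, 3, 3) * 0.1, 1e-3, 0, torch.qint8)

mu, sigma = mu_q.dequantize(), sigma_q.dequantize()     # back to fp32, done once and cached
weight = mu + sigma * torch.randn_like(mu)              # fp32 weight sample
y = F.conv2d(torch.randn(1, 4, 16, 16), weight, None, stride=1, padding=1)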
+ + default_zero_point: int, optional + Default zero point for quantized input tensor and quantized output tensor. + Set to 128 for quint8 tensor. + + + + Returns + ---------- + out: tensors + Output tensor. + KL: float + set to 0 since we diable KL divergence computation in quantized layers. + + + """ + if self.dnn_to_bnn_flag: + return_kl = False + + if not enable_int8_compute: # Deprecated. Use this method for reducing model size only. + if not self.is_dequant: + self.dequantize() + self.is_dequant = True + + weight = self.mu_kernel + (self.sigma_weight * self.eps_kernel.data.normal_()) + bias = None + + if self.bias: + bias = self.mu_bias + (self.sigma_bias * self.eps_bias.data.normal_()) + + out = F.conv_transpose3d(input, weight, bias, self.stride, + self.padding, self.output_padding, + self.dilation, self.groups) + + else: + eps_kernel = torch.quantize_per_tensor(self.eps_kernel.data.normal_(), normal_scale, 0, torch.qint8) # Quantize a tensor from normal distribution. 99.7% values will lie within 3 standard deviations, so the original range is set as 6. + new_scale = (self.quantized_sigma_weight.q_scale())*(eps_kernel.q_scale()) # Calculate the new scale after multiplying two quantized tensors. + weight = torch.ops.quantized.mul(self.quantized_sigma_weight, eps_kernel, new_scale, 0) + new_scale = max(new_scale, self.quantized_mu_weight.q_scale()) # Calculate the new scale after adding two quantized tensors. + weight = torch.ops.quantized.add(weight, self.quantized_mu_weight, new_scale, 0) + bias = None + + + ## DO NOT QUANTIZE BIAS!!! + if self.bias: + if self.quantized_sigma_bias is None: # the case that bias comes from bn fusion + bias = self.quantized_mu_bias + else: # original case + bias = self.quantized_mu_bias + (self.quantized_sigma_bias * self.eps_bias.data.normal_()) + + if input.dtype!=torch.quint8: # check if input has been quantized + input = torch.quantize_per_tensor(input, default_scale, default_zero_point, torch.quint8) # scale=0.1 by grid search; zero_point=128 for uint8 format + + self._packed_params = torch.ops.quantized.conv_transpose3d_prepack(weight, bias, self.stride, + self.padding, self.output_padding, + self.dilation, self.groups) + + out = torch.ops.quantized.conv_transpose3d(input, self._packed_params, scale=default_scale, zero_point=default_zero_point) + + + if return_kl: + return out, 0 # disable kl divergence computing + + return out \ No newline at end of file diff --git a/bayesian_torch/layers/variational_layers/quantize_linear_variational.py b/bayesian_torch/layers/variational_layers/quantize_linear_variational.py new file mode 100644 index 0000000..a12a569 --- /dev/null +++ b/bayesian_torch/layers/variational_layers/quantize_linear_variational.py @@ -0,0 +1,224 @@ +# Copyright (C) 2021 Intel Labs +# +# BSD-3-Clause License +# +# Redistribution and use in source and binary forms, with or without modification, +# are permitted provided that the following conditions are met: +# 1. Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# 3. Neither the name of the copyright holder nor the names of its contributors +# may be used to endorse or promote products derived from this software +# without specific prior written permission. 
+# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, +# THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS +# BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, +# OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT +# OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; +# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE +# OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, +# EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +# ====================================================================================== + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import Module, Parameter +from ..base_variational_layer import BaseVariationalLayer_ +import math +from .linear_variational import LinearReparameterization + + + +class QuantizedLinearReparameterization(LinearReparameterization): + def __init__(self, + in_features, + out_features): + """ + + """ + super(QuantizedLinearReparameterization, self).__init__( + in_features, + out_features) + + self.is_dequant = False + self.quant_dict = None + + def get_scale_and_zero_point(self, x, upper_bound=100, target_range=255): + """ An implementation for symmetric quantization + + Parameters + ---------- + x: tensor + Input tensor. + upper_bound: int, optional + Restrict the maximum value of the original tensor (select 100 empirically). + target_range: int, optional + The range of target data type (255 for int8) + + Returns + ---------- + scale: float + + zero_point: int + + """ + # + scale = torch.zeros(1).to(x.device) # initialize + zero_point = torch.zeros(1).to(x.device) # zero point is zero since we only consider symmetric quantization + xmax = torch.clamp(x.abs().max(), 0, upper_bound) # determine and restrict the maximum value (minimum value should be 0 since the absolute value is always non-negative) + scale = xmax*2/target_range # original range divided by target range + return scale, zero_point + + def get_quantized_tensor(self, x, default_scale=0.1): + """ Quantize tensors + + Parameters + ---------- + x: tensors + Input tensor. + + default_scale: float, optional + Default scale for the case that the computed scale is zero. 
+ + + Returns + ---------- + quantized_x: tensors + + + """ + scale, zero_point = self.get_scale_and_zero_point(x) + if scale == 0: + scale = torch.tensor([default_scale]) # avoid zero scale + quantized_x = torch.quantize_per_tensor(x, scale, zero_point, torch.qint8) + + return quantized_x + + def get_dequantized_tensor(self, x): + dequantized_x = x.dequantize() + + return dequantized_x + + + def quantize(self): + self.quantized_mu_weight = Parameter(self.get_quantized_tensor(self.mu_weight), requires_grad=False) + self.quantized_sigma_weight = Parameter(self.get_quantized_tensor(torch.log1p(torch.exp(self.rho_weight))), requires_grad=False) + delattr(self, "mu_weight") + delattr(self, "rho_weight") + + self.quantized_mu_bias = self.mu_bias#Parameter(self.get_quantized_tensor(self.mu_bias), requires_grad=False) + self.quantized_sigma_bias = Parameter(torch.log1p(torch.exp(self.rho_bias)), requires_grad=False)#Parameter(self.get_quantized_tensor(torch.log1p(torch.exp(self.rho_bias))), requires_grad=False) + delattr(self, "mu_bias") + delattr(self, "rho_bias") + + def dequantize(self): # Deprecated + self.mu_weight = self.get_dequantized_tensor(self.quantized_mu_weight) + self.sigma_weight = self.get_dequantized_tensor(self.quantized_sigma_weight) + + self.mu_bias = self.get_dequantized_tensor(self.quantized_mu_bias) + self.sigma_bias = self.get_dequantized_tensor(self.quantized_sigma_bias) + return + + def forward(self, input, enable_int8_compute=True, normal_scale=6/255, default_scale=0.2, default_zero_point=128, return_kl=True): + """ Forward pass + + Parameters + ---------- + input: tensors + Input tensor. + + enable_int8_compute: bool, optional + Whether to enable int8 computation. + + normal_scale: float, optional + Scale for quantized tensor sampled from normal distribution. + since 99.7% values will lie within 3 standard deviations, the original range is set as 6. + + default_scale: float, optional + Default scale for quantized input tensor and quantized output tensor. + Set to 0.1 by grid search. + + default_zero_point: int, optional + Default zero point for quantized input tensor and quantized output tensor. + Set to 128 for quint8 tensor. + + + + Returns + ---------- + out: tensors + Output tensor. + + KL: float + set to 0 since we diable KL divergence computation in quantized layers. + + + """ + if self.dnn_to_bnn_flag: + return_kl = False + + if self.quant_dict is not None: + eps_weight = torch.quantize_per_tensor(self.eps_weight.data.normal_(), self.quant_dict[0]['scale'], self.quant_dict[0]['zero_point'], torch.qint8) # Quantize a tensor from normal distribution. 99.7% values will lie within 3 standard deviations, so the original range is set as 6. + weight = torch.ops.quantized.mul(self.quantized_sigma_weight, eps_weight, self.quant_dict[1]['scale'], self.quant_dict[1]['zero_point']) + weight = torch.ops.quantized.add(weight, self.quantized_mu_weight, self.quant_dict[2]['scale'], self.quant_dict[2]['zero_point']) + bias = None + + ## DO NOT QUANTIZE BIAS!!! 
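# Minimal sketch of the int8 linear call this forward pass builds up to: the weight sample
# is qint8, the activation quint8, the bias stays fp32, and the output is produced at the
# requested scale/zero-point and then dequantized for the next (non-quantized) op. All
# scales/zero-points are placeholders.
import torch

w = torch.quantize_per_tensor(torch.randn(10, 20), 0.02, 0, torch.qint8)
x = torch.quantize_per_tensor(torch.randn(2, 20), 0.2, 128, torch.quint8)
b = torch.randn(10)                                    # fp32 bias, never quantized

y_q = torch.nn.quantized.functional.linear(x, w, b, scale=0.2, zero_point=128)
y = y_q.dequantize()                                   # fp32 result, shape (2, 10)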
+ if self.bias: + if self.quantized_sigma_bias is None: # the case that bias comes from bn fusion + bias = self.quantized_mu_bias + else: # original case + bias = self.quantized_mu_bias + (self.quantized_sigma_bias * self.eps_bias.data.normal_()) + + if input.dtype!=torch.quint8: # check if input has been quantized + input = torch.quantize_per_tensor(input, self.quant_dict[3]['scale'], self.quant_dict[3]['zero_point'], torch.quint8) # scale=0.1 by grid search; zero_point=128 for uint8 format + + out = torch.nn.quantized.functional.linear(input, weight, bias, scale=self.quant_dict[4]['scale'], zero_point=self.quant_dict[4]['zero_point']) # input: quint8, weight: qint8, bias: fp32 + out = out.dequantize() + + elif not enable_int8_compute: # Deprecated. Use this method for reducing model size only. + if not self.is_dequant: + self.dequantize() + self.is_dequant = True + weight = self.mu_weight + (self.sigma_weight * self.eps_weight.data.normal_()) + bias = None + if self.sigma_bias is not None: + bias = self.mu_bias + (self.sigma_bias * self.eps_bias.data.normal_()) + + out = F.linear(input, weight, bias) + + else: + eps_weight = torch.quantize_per_tensor(self.eps_weight.data.normal_(), normal_scale, 0, torch.qint8) + new_scale = (self.quantized_sigma_weight.q_scale())*(eps_weight.q_scale()) + weight = torch.ops.quantized.mul(self.quantized_sigma_weight, eps_weight, new_scale, 0) + new_scale = max(new_scale, self.quantized_mu_weight.q_scale()) + weight = torch.ops.quantized.add(weight, self.quantized_mu_weight, new_scale, 0) + bias = None + + if self.quantized_sigma_bias is not None: + if not self.is_dequant: + self.dequantize() + self.is_dequant = True + bias = self.mu_bias + (self.sigma_bias * self.eps_bias.data.normal_()) + if input.dtype!=torch.quint8: + input = torch.quantize_per_tensor(input, default_scale, default_zero_point, torch.quint8) + + out = torch.nn.quantized.functional.linear(input, weight, bias, scale=default_scale, zero_point=default_zero_point) # input: quint8, weight: qint8, bias: fp32 + out = out.dequantize() + + if return_kl: + return out, 0 # disable kl divergence computing + + return out diff --git a/bayesian_torch/models/bayesian/quantized_resnet_flipout_large.py b/bayesian_torch/models/bayesian/quantized_resnet_flipout_large.py new file mode 100644 index 0000000..61c0dd0 --- /dev/null +++ b/bayesian_torch/models/bayesian/quantized_resnet_flipout_large.py @@ -0,0 +1,277 @@ +''' +Bayesian ResNet for CIFAR10. + +ResNet architecture ref: +[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun + Deep Residual Learning for Image Recognition. 
arXiv:1512.03385 +''' + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.nn.init as init +from bayesian_torch.layers import QuantizedConv2dFlipout +from bayesian_torch.layers import QuantizedLinearFlipout +from torch.nn.quantized import BatchNorm2d as QuantizedBatchNorm2d +from torch.nn import Identity + +__all__ = [ + 'QResNet', 'qresnet18', 'qresnet34', 'qresnet50', 'qresnet101', 'qresnet152' +] + +def _weights_init(m): + classname = m.__class__.__name__ + if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d): + init.kaiming_normal_(m.weight) + + +class LambdaLayer(nn.Module): + def __init__(self, lambd): + super(LambdaLayer, self).__init__() + self.lambd = lambd + + def forward(self, x): + return self.lambd(x) + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, in_planes, planes, stride=1, option='A', bias=False): + super(BasicBlock, self).__init__() + self.conv1 = QuantizedConv2dFlipout( + in_channels=in_planes, + out_channels=planes, + kernel_size=3, + stride=stride, + padding=1, + bias=bias) + self.bn1 = QuantizedBatchNorm2d(planes) + self.conv2 = QuantizedConv2dFlipout( + in_channels=planes, + out_channels=planes, + kernel_size=3, + stride=1, + padding=1, + bias=bias) + self.bn2 = QuantizedBatchNorm2d(planes) + + self.shortcut = nn.Sequential() + if stride != 1 or in_planes != planes: + if option == 'A': + """ + For CIFAR10 ResNet paper uses option A. + """ + self.shortcut = LambdaLayer(lambda x: F.pad( + x[:, :, ::2, ::2], + (0, 0, 0, 0, planes // 4, planes // 4), "constant", 0)) + elif option == 'B': + self.shortcut = nn.Sequential( + QuantizedConv2dFlipout( + in_channels=in_planes, + out_channels=self.expansion * planes, + kernel_size=1, + stride=stride, + bias=bias), QuantizedBatchNorm2d(self.expansion * planes)) + + def forward(self, x): + out = self.conv1(x) + out = self.bn1(out) + out = F.relu(out) + out = self.conv2(out) + out = self.bn2(out) + sh = self.shortcut(x.contiguous()).contiguous() + new_scale = max(out.q_scale(), sh.q_scale()) + out = torch.ops.quantized.add(out, sh, new_scale, 0) + # out += self.shortcut(x) + out = F.relu(out) + return out + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None, bias=False): + super(Bottleneck, self).__init__() + self.conv1 = QuantizedConv2dFlipout( + in_channels=inplanes, + out_channels=planes, + kernel_size=1, + bias=bias) + self.bn1 =QuantizedBatchNorm2d(planes) + self.conv2 = QuantizedConv2dFlipout( + in_channels=planes, + out_channels=planes, + kernel_size=3, + stride=stride, + padding=1, + bias=bias) + self.bn2 = QuantizedBatchNorm2d(planes) + self.conv3 = QuantizedConv2dFlipout( + in_channels=planes, + out_channels=planes * 4, + kernel_size=1, + bias=bias) + self.bn3 = QuantizedBatchNorm2d(planes * 4) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + # out += residual + new_scale = max(out.q_scale(), residual.q_scale()) + out = torch.ops.quantized.add(out, residual, new_scale, 0) + out = self.relu(out) + + return out + +class QResNet(nn.Module): + def __init__(self, block, layers, num_classes=1000, bias=False): + super(QResNet, self).__init__() + self.inplanes = 64 + 
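# Sketch of the requantized residual add used in BasicBlock/Bottleneck above: the two
# branches generally carry different scales, so the sum is emitted at the larger of the two
# to avoid clipping the wider-range input (zero point 0, matching the symmetric scheme).
import torch

a = torch.quantize_per_tensor(torch.randn(1, 8, 4, 4), 0.05, 0, torch.qint8)
b = torch.quantize_per_tensor(torch.randn(1, 8, 4, 4), 0.02, 0, torch.qint8)

out = torch.ops.quantized.add(a, b, max(a.q_scale(), b.q_scale()), 0)
print(out.q_scale())   # 0.05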
self.conv1 = QuantizedConv2dFlipout( + in_channels=3, + out_channels=64, + kernel_size=7, + stride=2, + padding=3, + bias=bias) + self.bn1 = QuantizedBatchNorm2d(64) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0], bias=bias) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2, bias=bias) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2, bias=bias) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2, bias=bias) + self.avgpool = nn.AvgPool2d(7, stride=1) + self.fc = QuantizedLinearFlipout( + in_features=512 * block.expansion, + out_features=num_classes, + ) + + self.apply(_weights_init) + + def _make_layer(self, block, planes, blocks, stride=1, bias=False): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + QuantizedConv2dFlipout(in_channels=self.inplanes, + out_channels=planes * block.expansion, + kernel_size=1, + stride=stride, + bias=bias), + QuantizedBatchNorm2d(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample, bias=bias)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes, bias=bias)) + + return nn.Sequential(*layers) + + def quant_then_dequant(self, m, fuse_conv_bn=False): ## quantize only; need to rename this function + for name, value in list(m._modules.items()): + if m._modules[name]._modules: + self.quant_then_dequant(m._modules[name], fuse_conv_bn=fuse_conv_bn) + + if "QuantizedConv" in m._modules[name].__class__.__name__: + m._modules[name].quantize() + m._modules[name].quantized_sigma_bias = None ### work around + m._modules[name].dnn_to_bnn_flag = True ## since we don't compute kl in quantized models, this flag will be removed after refactoring + + if "QuantizedLinear" in m._modules[name].__class__.__name__: + m._modules[name].quantize() + m._modules[name].dnn_to_bnn_flag = True ## since we don't compute kl in quantized models, this flag will be removed after refactoring + + if fuse_conv_bn and "BatchNorm2d" in m._modules[name].__class__.__name__: # quite confusing, should be quantizedbatchnorm2d + setattr(m, name, Identity()) + + def forward(self, x): + + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + for layer in self.layer1: + x = layer(x) + + for layer in self.layer2: + x = layer(x) + + for layer in self.layer3: + x = layer(x) + + for layer in self.layer4: + x = layer(x) + + x = self.avgpool(x) + x = x.view(x.size(0), -1) + x = self.fc(x) + + return x + + +def qresnet18(pretrained=False, **kwargs): + model = QResNet(BasicBlock, [2, 2, 2, 2], **kwargs) + return model + + +def qresnet34(pretrained=False, **kwargs): + model = QResNet(BasicBlock, [3, 4, 6, 3], **kwargs) + return model + + +def qresnet50(pretrained=False, **kwargs): + model = QResNet(Bottleneck, [3, 4, 6, 3], **kwargs) + return model + + +def qresnet101(pretrained=False, **kwargs): + model = QResNet(Bottleneck, [3, 4, 23, 3], **kwargs) + return model + + +def qresnet152(pretrained=False, **kwargs): + model = QResNet(Bottleneck, [3, 8, 36, 3], **kwargs) + return model + + + +def test(net): + import numpy as np + total_params = 0 + + for x in filter(lambda p: p.requires_grad, net.parameters()): + total_params += np.prod(x.data.numpy().shape) + print("Total number of params", total_params) + print( + "Total layers", + len( + list( + filter(lambda 
p: p.requires_grad and len(p.data.size()) > 1, + net.parameters())))) + + +if __name__ == "__main__": + for net_name in __all__: + if net_name.startswith('qresnet'): + print(net_name) + test(globals()[net_name]()) + print() diff --git a/bayesian_torch/models/bayesian/quantized_resnet_variational_large.py b/bayesian_torch/models/bayesian/quantized_resnet_variational_large.py new file mode 100644 index 0000000..6d0a57e --- /dev/null +++ b/bayesian_torch/models/bayesian/quantized_resnet_variational_large.py @@ -0,0 +1,277 @@ +''' +Bayesian ResNet for CIFAR10. + +ResNet architecture ref: +[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun + Deep Residual Learning for Image Recognition. arXiv:1512.03385 +''' + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.nn.init as init +from bayesian_torch.layers import QuantizedConv2dReparameterization +from bayesian_torch.layers import QuantizedLinearReparameterization +from torch.nn.quantized import BatchNorm2d as QuantizedBatchNorm2d +from torch.nn import Identity + +__all__ = [ + 'QResNet', 'qresnet18', 'qresnet34', 'qresnet50', 'qresnet101', 'qresnet152' +] + +def _weights_init(m): + classname = m.__class__.__name__ + if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d): + init.kaiming_normal_(m.weight) + + +class LambdaLayer(nn.Module): + def __init__(self, lambd): + super(LambdaLayer, self).__init__() + self.lambd = lambd + + def forward(self, x): + return self.lambd(x) + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, in_planes, planes, stride=1, option='A', bias=False): + super(BasicBlock, self).__init__() + self.conv1 = QuantizedConv2dReparameterization( + in_channels=in_planes, + out_channels=planes, + kernel_size=3, + stride=stride, + padding=1, + bias=bias) + self.bn1 = QuantizedBatchNorm2d(planes) + self.conv2 = QuantizedConv2dReparameterization( + in_channels=planes, + out_channels=planes, + kernel_size=3, + stride=1, + padding=1, + bias=bias) + self.bn2 = QuantizedBatchNorm2d(planes) + + self.shortcut = nn.Sequential() + if stride != 1 or in_planes != planes: + if option == 'A': + """ + For CIFAR10 ResNet paper uses option A. 
+ """ + self.shortcut = LambdaLayer(lambda x: F.pad( + x[:, :, ::2, ::2], + (0, 0, 0, 0, planes // 4, planes // 4), "constant", 0)) + elif option == 'B': + self.shortcut = nn.Sequential( + QuantizedConv2dReparameterization( + in_channels=in_planes, + out_channels=self.expansion * planes, + kernel_size=1, + stride=stride, + bias=bias), QuantizedBatchNorm2d(self.expansion * planes)) + + def forward(self, x): + out = self.conv1(x) + out = self.bn1(out) + out = F.relu(out) + out = self.conv2(out) + out = self.bn2(out) + sh = self.shortcut(x.contiguous()).contiguous() + new_scale = max(out.q_scale(), sh.q_scale()) + out = torch.ops.quantized.add(out, sh, new_scale, 0) + # out += self.shortcut(x) + out = F.relu(out) + return out + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None, bias=False): + super(Bottleneck, self).__init__() + self.conv1 = QuantizedConv2dReparameterization( + in_channels=inplanes, + out_channels=planes, + kernel_size=1, + bias=bias) + self.bn1 =QuantizedBatchNorm2d(planes) + self.conv2 = QuantizedConv2dReparameterization( + in_channels=planes, + out_channels=planes, + kernel_size=3, + stride=stride, + padding=1, + bias=bias) + self.bn2 = QuantizedBatchNorm2d(planes) + self.conv3 = QuantizedConv2dReparameterization( + in_channels=planes, + out_channels=planes * 4, + kernel_size=1, + bias=bias) + self.bn3 = QuantizedBatchNorm2d(planes * 4) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + # out += residual + new_scale = max(out.q_scale(), residual.q_scale()) + out = torch.ops.quantized.add(out, residual, new_scale, 0) + out = self.relu(out) + + return out + +class QResNet(nn.Module): + def __init__(self, block, layers, num_classes=1000, bias=False): + super(QResNet, self).__init__() + self.inplanes = 64 + self.conv1 = QuantizedConv2dReparameterization( + in_channels=3, + out_channels=64, + kernel_size=7, + stride=2, + padding=3, + bias=bias) + self.bn1 = QuantizedBatchNorm2d(64) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0], bias=bias) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2, bias=bias) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2, bias=bias) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2, bias=bias) + self.avgpool = nn.AvgPool2d(7, stride=1) + self.fc = QuantizedLinearReparameterization( + in_features=512 * block.expansion, + out_features=num_classes, + ) + + self.apply(_weights_init) + + def _make_layer(self, block, planes, blocks, stride=1, bias=False): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + QuantizedConv2dReparameterization(in_channels=self.inplanes, + out_channels=planes * block.expansion, + kernel_size=1, + stride=stride, + bias=bias), + QuantizedBatchNorm2d(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample, bias=bias)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes, bias=bias)) + + return 
nn.Sequential(*layers) + + def quant_then_dequant(self, m, fuse_conv_bn=False): ## quantize only; need to rename this function + for name, value in list(m._modules.items()): + if m._modules[name]._modules: + self.quant_then_dequant(m._modules[name], fuse_conv_bn=fuse_conv_bn) + + if "QuantizedConv" in m._modules[name].__class__.__name__: + m._modules[name].quantize() + m._modules[name].quantized_sigma_bias = None ### work around + m._modules[name].dnn_to_bnn_flag = True ## since we don't compute kl in quantized models, this flag will be removed after refactoring + + if "QuantizedLinear" in m._modules[name].__class__.__name__: + m._modules[name].quantize() + m._modules[name].dnn_to_bnn_flag = True ## since we don't compute kl in quantized models, this flag will be removed after refactoring + + if fuse_conv_bn and "BatchNorm2d" in m._modules[name].__class__.__name__: # quite confusing, should be quantizedbatchnorm2d + setattr(m, name, Identity()) + + def forward(self, x): + + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + for layer in self.layer1: + x = layer(x) + + for layer in self.layer2: + x = layer(x) + + for layer in self.layer3: + x = layer(x) + + for layer in self.layer4: + x = layer(x) + + x = self.avgpool(x) + x = x.view(x.size(0), -1) + x = self.fc(x) + + return x + + +def qresnet18(pretrained=False, **kwargs): + model = QResNet(BasicBlock, [2, 2, 2, 2], **kwargs) + return model + + +def qresnet34(pretrained=False, **kwargs): + model = QResNet(BasicBlock, [3, 4, 6, 3], **kwargs) + return model + + +def qresnet50(pretrained=False, **kwargs): + model = QResNet(Bottleneck, [3, 4, 6, 3], **kwargs) + return model + + +def qresnet101(pretrained=False, **kwargs): + model = QResNet(Bottleneck, [3, 4, 23, 3], **kwargs) + return model + + +def qresnet152(pretrained=False, **kwargs): + model = QResNet(Bottleneck, [3, 8, 36, 3], **kwargs) + return model + + + +def test(net): + import numpy as np + total_params = 0 + + for x in filter(lambda p: p.requires_grad, net.parameters()): + total_params += np.prod(x.data.numpy().shape) + print("Total number of params", total_params) + print( + "Total layers", + len( + list( + filter(lambda p: p.requires_grad and len(p.data.size()) > 1, + net.parameters())))) + + +if __name__ == "__main__": + for net_name in __all__: + if net_name.startswith('qresnet'): + print(net_name) + test(globals()[net_name]()) + print() diff --git a/bayesian_torch/models/bayesian/resnet_variational_large.py b/bayesian_torch/models/bayesian/resnet_variational_large.py index bc641d6..e5fb9fd 100644 --- a/bayesian_torch/models/bayesian/resnet_variational_large.py +++ b/bayesian_torch/models/bayesian/resnet_variational_large.py @@ -14,7 +14,7 @@ from bayesian_torch.layers import BatchNorm2dLayer __all__ = [ - 'ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152' + 'ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152', 'BasicBlock', 'Bottleneck' ] prior_mu = 0.0 diff --git a/bayesian_torch/models/bnn_to_qbnn.py b/bayesian_torch/models/bnn_to_qbnn.py new file mode 100644 index 0000000..85953cf --- /dev/null +++ b/bayesian_torch/models/bnn_to_qbnn.py @@ -0,0 +1,259 @@ +# Copyright (C) 2021 Intel Labs +# +# BSD-3-Clause License +# +# Redistribution and use in source and binary forms, with or without modification, +# are permitted provided that the following conditions are met: +# 1. Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# 2. 
Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# 3. Neither the name of the copyright holder nor the names of its contributors +# may be used to endorse or promote products derived from this software +# without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, +# THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS +# BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, +# OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT +# OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; +# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE +# OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, +# EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +# Functions related to BNN to QBNN model conversion. +# +# @authors: Jun-Liang Lin +# +# =============================================================================================== + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +import bayesian_torch.layers as bayesian_layers +import torch +import torch.nn as nn +from torch.nn import Identity +from torch.nn.quantized import BatchNorm2d as QBatchNorm2d +from torch.nn import Module, Parameter + + +def get_scale_and_zero_point(self, x, upper_bound=100, target_range=255): + """ An implementation for symmetric quantization + + Parameters + ---------- + x: tensor + Input tensor. + upper_bound: int, optional + Restrict the maximum value of the original tensor (select 100 empirically). + target_range: int, optional + The range of target data type (255 for int8) + + Returns + ---------- + scale: float + + zero_point: int + + """ + # + scale = torch.zeros(1).to(x.device) # initialize + zero_point = torch.zeros(1).to(x.device) # zero point is zero since we only consider symmetric quantization + xmax = torch.clamp(x.abs().max(), 0, upper_bound) # determine and restrict the maximum value (minimum value should be 0 since the absolute value is always non-negative) + scale = xmax*2/target_range # original range divided by target range + return scale, zero_point + +def get_quantized_tensor(self, x, default_scale=0.1): + """ Quantize tensors + + Parameters + ---------- + x: tensors + Input tensor. + + default_scale: float, optional + Default scale for the case that the computed scale is zero. 
+ + + Returns + ---------- + quantized_x: tensors + + + """ + scale, zero_point = self.get_scale_and_zero_point(x) + if scale == 0: + scale = torch.tensor([default_scale]) # avoid zero scale + quantized_x = torch.quantize_per_tensor(x, scale, zero_point, torch.qint8) + + return quantized_x + +def qbnn_linear_layer(d): + layer_type = "Quantized" + d.__class__.__name__ + layer_fn = getattr(bayesian_layers, layer_type) # Get QBNN layer + qbnn_layer = layer_fn( + in_features=d.in_features, + out_features=d.out_features, + ) + qbnn_layer.__dict__.update(d.__dict__) + + if d.quant_prepare: + qbnn_layer.quant_dict = nn.ModuleList() + for qstub in d.qint_quant: + qbnn_layer.quant_dict.append(nn.ParameterDict({'scale': torch.nn.Parameter(qstub.scale.float()), 'zero_point': torch.nn.Parameter(qstub.zero_point.float())})) + qbnn_layer.quant_dict = qbnn_layer.quant_dict[2:] + for qstub in d.quint_quant: + qbnn_layer.quant_dict.append(nn.ParameterDict({'scale': torch.nn.Parameter(qstub.scale.float()), 'zero_point': torch.nn.Parameter(qstub.zero_point.float())})) + + qbnn_layer.quantize() + if d.dnn_to_bnn_flag: + qbnn_layer.dnn_to_bnn_flag = True + return qbnn_layer + +def qbnn_conv_layer(d): + layer_type = "Quantized" + d.__class__.__name__ + layer_fn = getattr(bayesian_layers, layer_type) # Get QBNN layer + qbnn_layer = layer_fn( + in_channels=d.in_channels, + out_channels=d.out_channels, + kernel_size=d.kernel_size, + stride=d.stride, + padding=d.padding, + dilation=d.dilation, + groups=d.groups, + ) + qbnn_layer.__dict__.update(d.__dict__) + + if d.quant_prepare: + qbnn_layer.quant_dict = nn.ModuleList() + for qstub in d.qint_quant: + qbnn_layer.quant_dict.append(nn.ParameterDict({'scale': torch.nn.Parameter(qstub.scale.float()), 'zero_point': torch.nn.Parameter(qstub.zero_point.float())})) + qbnn_layer.quant_dict = qbnn_layer.quant_dict[2:] + for qstub in d.quint_quant: + qbnn_layer.quant_dict.append(nn.ParameterDict({'scale': torch.nn.Parameter(qstub.scale.float()), 'zero_point': torch.nn.Parameter(qstub.zero_point.float())})) + + qbnn_layer.quantize() + if d.dnn_to_bnn_flag: + qbnn_layer.dnn_to_bnn_flag = True + return qbnn_layer + +def qbnn_lstm_layer(d): + layer_type = "Quantized" + d.__class__.__name__ + layer_fn = getattr(bayesian_layers, layer_type) # Get QBNN layer + qbnn_layer = layer_fn( + in_features=d.input_size, + out_features=d.hidden_size, + ) + qbnn_layer.__dict__.update(d.__dict__) + qbnn_layer.quantize() + if d.dnn_to_bnn_flag: + qbnn_layer.dnn_to_bnn_flag = True + return qbnn_layer + +def qbnn_batchnorm2d_layer(d): + layer_fn = QBatchNorm2d # Get QBNN layer + qbnn_layer = layer_fn( + num_features=d.num_features + ) + qbnn_layer.__dict__.update(d.__dict__) + # qbnn_layer.weight = Parameter(get_quantized_tensor(d.weight), requires_grad=False) + # qbnn_layer.bias = Parameter(get_quantized_tensor(d.bias), requires_grad=False) + # qbnn_layer.running_mean = Parameter(get_quantized_tensor(d.running_mean), requires_grad=False) + # qbnn_layer.running_var = Parameter(get_quantized_tensor(d.running_var), requires_grad=False) + # qbnn_layer.scale = Parameter(torch.tensor([0.1]), requires_grad=False) + # qbnn_layer.zero_point = Parameter(torch.tensor([128]), requires_grad=False) + return qbnn_layer + + +# batch norm folding +def batch_norm_folding(conv, bn): + layer_type = "Quantized" + conv.__class__.__name__ + layer_fn = getattr(bayesian_layers, layer_type) # Get QBNN layer + qbnn_layer = layer_fn( + in_channels=conv.in_channels, + out_channels=conv.out_channels, + 
kernel_size=conv.kernel_size, + stride=conv.stride, + padding=conv.padding, + dilation=conv.dilation, + groups=conv.groups, + ) + qbnn_layer.__dict__.update(conv.__dict__) + qbnn_layer.bn_weight = bn.weight + qbnn_layer.bn_bias = bn.bias + qbnn_layer.bn_running_mean = bn.running_mean + qbnn_layer.bn_running_var = bn.running_var + qbnn_layer.bn_eps = bn.eps + qbnn_layer.quantize() + if conv.dnn_to_bnn_flag: + qbnn_layer.dnn_to_bnn_flag = True + return qbnn_layer + +# replaces linear and conv layers +def bnn_to_qbnn(m, fuse_conv_bn=False): + for name, value in list(m._modules.items()): + if m._modules[name]._modules: + if "Conv" in m._modules[name].__class__.__name__: + setattr(m, name, qbnn_conv_layer(m._modules[name])) + elif "Linear" in m._modules[name].__class__.__name__: + setattr(m, name, qbnn_linear_layer(m._modules[name])) + else: + bnn_to_qbnn(m._modules[name], fuse_conv_bn=fuse_conv_bn) + elif "Linear" in m._modules[name].__class__.__name__: + setattr(m, name, qbnn_linear_layer(m._modules[name])) + elif "LSTM" in m._modules[name].__class__.__name__: + setattr(m, name, qbnn_lstm_layer(m._modules[name])) + else: + if fuse_conv_bn: + if 'conv1' in m._modules.keys() and 'bn1' in m._modules.keys(): + if 'Identity' not in m._modules['bn1'].__class__.__name__: + setattr(m, 'conv1', batch_norm_folding(m._modules['conv1'], m._modules['bn1'])) + setattr(m, 'bn1', Identity()) + if 'conv2' in m._modules.keys() and 'bn2' in m._modules.keys(): + if 'Identity' not in m._modules['bn2'].__class__.__name__: + setattr(m, 'conv2', batch_norm_folding(m._modules['conv2'], m._modules['bn2'])) + setattr(m, 'bn2', Identity()) + if 'conv3' in m._modules.keys() and 'bn3' in m._modules.keys(): + if 'Identity' not in m._modules['bn3'].__class__.__name__: + setattr(m, 'conv3', batch_norm_folding(m._modules['conv3'], m._modules['bn3'])) + setattr(m, 'bn3', Identity()) + if 'downsample' in m._modules.keys(): + if m._modules['downsample'].__class__.__name__=='Sequential' and len(m._modules['downsample'])==2: + if 'Identity' not in m._modules['downsample'][1].__class__.__name__: + m._modules['downsample'][0]=batch_norm_folding(m._modules['downsample'][0], m._modules['downsample'][1]) + m._modules['downsample'][1]=Identity() + else: + if "Conv" in m._modules[name].__class__.__name__: + setattr(m, name, qbnn_conv_layer(m._modules[name])) + + elif "Batch" in m._modules[name].__class__.__name__: + setattr(m, name, qbnn_batchnorm2d_layer(m._modules[name])) + + return + +if __name__ == "__main__": + class FusionTest(nn.Module): + def __init__(self): + super(FusionTest, self).__init__() + self.conv1 = bayesian_layers.Conv2dReparameterization(1,3,2,bias=False) + self.bn1 = nn.BatchNorm2d(3) + def forward(self, x): + x = self.conv1(x)[0] + x = self.bn1(x) + return x + m = FusionTest() + m.conv1.rho_kernel = Parameter(torch.zeros(m.conv1.rho_kernel.shape)-100) + m.eval() + print(m) + input = torch.randn(1,1,3,3) + print(m(input)) + bnn_to_qbnn(m) + print(m) + if input.dtype!=torch.quint8: + input = torch.quantize_per_tensor(input, 0.1, 128, torch.quint8) + print(m(input)) \ No newline at end of file diff --git a/bayesian_torch/quantization/__init__.py b/bayesian_torch/quantization/__init__.py new file mode 100644 index 0000000..91a6e8b --- /dev/null +++ b/bayesian_torch/quantization/__init__.py @@ -0,0 +1,3 @@ +from .quantize import * + +# __all__ = ['prepare', 'convert'] \ No newline at end of file diff --git a/bayesian_torch/quantization/quantize.py b/bayesian_torch/quantization/quantize.py new file mode 100644 index 
0000000..967f79a --- /dev/null +++ b/bayesian_torch/quantization/quantize.py @@ -0,0 +1,2 @@ +from bayesian_torch.ao.quantization.quantize import prepare +from bayesian_torch.ao.quantization.quantize import convert \ No newline at end of file
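
---

Note on the residual additions in `BasicBlock.forward` and `Bottleneck.forward` in the quantized ResNet above: quantized tensors cannot be summed with a plain `+=`, so both branches are re-quantized to a shared output scale (the larger of the two input scales) via `torch.ops.quantized.add`. A minimal, self-contained sketch of that pattern (shapes and the 0.10/0.05 scales are illustrative, not taken from the patch):

```python
import torch

# Two per-tensor-quantized activations with different scales, standing in for
# the main branch and the shortcut branch of a residual block.
a = torch.quantize_per_tensor(torch.randn(1, 8, 4, 4), scale=0.10, zero_point=0, dtype=torch.quint8)
b = torch.quantize_per_tensor(torch.randn(1, 8, 4, 4), scale=0.05, zero_point=0, dtype=torch.quint8)

# Same rule as in the forward passes above: requantize the sum to the larger input scale.
new_scale = max(a.q_scale(), b.q_scale())
out = torch.ops.quantized.add(a, b, new_scale, 0)
print(out.q_scale(), out.dtype)  # 0.10, torch.quint8
```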
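For reference, a worked example of the symmetric scheme implemented by `get_scale_and_zero_point` in `bnn_to_qbnn.py`: with `target_range=255` and the zero point fixed at 0, a tensor whose largest magnitude is 1.0 gets `scale = 2 * 1.0 / 255 ≈ 0.00784`, so the `[-1, 1]` range maps onto roughly ±127 integer levels. A small sketch under those assumptions (the 0.3 spread of the sample tensor is illustrative):

```python
import torch

x = torch.randn(64) * 0.3                   # example weights/activations
xmax = torch.clamp(x.abs().max(), 0, 100)   # upper_bound=100, as in the helper above
scale = xmax * 2 / 255                      # original range divided by the int8 target range
q = torch.quantize_per_tensor(x, float(scale), 0, torch.qint8)
print(float(scale), int(q.int_repr().abs().max()))  # largest integer level lands near 127
```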
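The `fuse_conv_bn` path in `bnn_to_qbnn` does not rewrite weights itself: `batch_norm_folding` stores the BatchNorm statistics on the quantized layer (`bn_weight`, `bn_bias`, `bn_running_mean`, `bn_running_var`, `bn_eps`), replaces the BatchNorm module with `Identity`, and leaves the folding to the layer's `quantize()` call, which reads those fields. For reference, the standard conv+BN folding arithmetic those fields support is `W_fold = W * gamma / sqrt(var + eps)` and `b_fold = beta - gamma * mean / sqrt(var + eps)`. A hedged sketch of that arithmetic (`fold_conv_bn` is a hypothetical helper, not the library's `quantize()` code):

```python
import torch

def fold_conv_bn(weight, gamma, beta, running_mean, running_var, eps=1e-5):
    """Fold BatchNorm statistics into a bias-free conv weight; returns (W_fold, b_fold)."""
    std = torch.sqrt(running_var + eps)
    w_fold = weight * (gamma / std).reshape(-1, 1, 1, 1)  # rescale each output channel
    b_fold = beta - gamma * running_mean / std            # absorbed BN shift
    return w_fold, b_fold

# Illustrative shapes: 8 output channels, 3 input channels, 3x3 kernels.
w, g, b = torch.randn(8, 3, 3, 3), torch.ones(8), torch.zeros(8)
mean, var = torch.zeros(8), torch.ones(8)
w_fold, b_fold = fold_conv_bn(w, g, b, mean, var)
print(w_fold.shape, b_fold.shape)
```

Note that the folding branch matches modules by name (`conv1`/`bn1`, `conv2`/`bn2`, `conv3`/`bn3`, `downsample[0]`/`downsample[1]`), so custom blocks need to follow that ResNet-style naming for fusion to trigger; `quant_then_dequant` in the `QResNet` above exposes the same `fuse_conv_bn` flag for its copy of the workflow.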
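Finally, a hedged usage sketch tying the patch together. `bayesian_torch/quantization/quantize.py` re-exports `prepare` and `convert` from `bayesian_torch.ao.quantization.quantize` as the public entry points (their calibration workflow lives outside this patch), while `bnn_to_qbnn` converts a Bayesian model in place, as the `FusionTest` block above demonstrates. The sketch below assumes a variational ResNet-50 built with the quantization-aware layer attributes the converter expects, and the input scale/zero point mirror the `FusionTest` values rather than a real calibration:

```python
import torch
from bayesian_torch.models.bayesian.resnet_variational_large import resnet50
from bayesian_torch.models.bnn_to_qbnn import bnn_to_qbnn

bnn = resnet50()                     # variational ResNet-50 from this repo
bnn.eval()
bnn_to_qbnn(bnn, fuse_conv_bn=True)  # in place: Conv/Linear -> Quantized*, BN folded away
print(bnn.conv1.__class__.__name__)  # expected: QuantizedConv2dReparameterization

# Quantized layers consume quantized activations (same pattern as FusionTest above).
x = torch.quantize_per_tensor(torch.randn(1, 3, 224, 224), 0.1, 128, torch.quint8)
```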