From 86adb6d6fa7dced490f9a24b95e54ebc1c43ea0a Mon Sep 17 00:00:00 2001 From: junliang-lin Date: Wed, 19 Apr 2023 12:25:31 -0400 Subject: [PATCH] fix batchnorm --- bayesian_torch/layers/batchnorm.py | 23 ++++++++++++------- .../bayesian/resnet_variational_large.py | 2 +- 2 files changed, 16 insertions(+), 9 deletions(-) diff --git a/bayesian_torch/layers/batchnorm.py b/bayesian_torch/layers/batchnorm.py index 145997c..25ab8f3 100644 --- a/bayesian_torch/layers/batchnorm.py +++ b/bayesian_torch/layers/batchnorm.py @@ -54,7 +54,6 @@ def _check_input_dim(self, input): input.dim())) def forward(self, input): - self._check_input_dim(input[0]) exponential_average_factor = 0.0 if self.training and self.track_running_stats: self.num_batches_tracked += 1 @@ -63,13 +62,21 @@ def forward(self, input): else: # use exponential moving average exponential_average_factor = self.momentum - out = F.batch_norm(input[0], self.running_mean, self.running_var, - self.weight, self.bias, self.training - or not self.track_running_stats, - exponential_average_factor, self.eps) - kl = 0 - return out, kl - + if len(input) == 2: + self._check_input_dim(input[0]) + out = F.batch_norm(input[0], self.running_mean, self.running_var, + self.weight, self.bias, self.training + or not self.track_running_stats, + exponential_average_factor, self.eps) + kl = 0 + return out, kl + else: + out = F.batch_norm(input, self.running_mean, self.running_var, + self.weight, self.bias, self.training + or not self.track_running_stats, + exponential_average_factor, self.eps) + return out + class BatchNorm1dLayer(nn.Module): def __init__(self, diff --git a/bayesian_torch/models/bayesian/resnet_variational_large.py b/bayesian_torch/models/bayesian/resnet_variational_large.py index 6fdf561..e5fb9fd 100644 --- a/bayesian_torch/models/bayesian/resnet_variational_large.py +++ b/bayesian_torch/models/bayesian/resnet_variational_large.py @@ -200,7 +200,7 @@ def _make_layer(self, block, planes, blocks, stride=1): posterior_mu_init=posterior_mu_init, posterior_rho_init=posterior_rho_init, bias=False), - nn.BatchNorm2d(planes * block.expansion), + BatchNorm2dLayer(planes * block.expansion), ) layers = []