Skip to content
This repository has been archived by the owner on Dec 3, 2024. It is now read-only.

Commit

Permalink
fix batchnorm
Browse files Browse the repository at this point in the history
  • Loading branch information
junliang-lin committed Apr 19, 2023
1 parent c3e47ed commit 86adb6d
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 9 deletions.
23 changes: 15 additions & 8 deletions bayesian_torch/layers/batchnorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@ def _check_input_dim(self, input):
input.dim()))

def forward(self, input):
self._check_input_dim(input[0])
exponential_average_factor = 0.0
if self.training and self.track_running_stats:
self.num_batches_tracked += 1
Expand All @@ -63,13 +62,21 @@ def forward(self, input):
else: # use exponential moving average
exponential_average_factor = self.momentum

out = F.batch_norm(input[0], self.running_mean, self.running_var,
self.weight, self.bias, self.training
or not self.track_running_stats,
exponential_average_factor, self.eps)
kl = 0
return out, kl

if len(input) == 2:
self._check_input_dim(input[0])
out = F.batch_norm(input[0], self.running_mean, self.running_var,
self.weight, self.bias, self.training
or not self.track_running_stats,
exponential_average_factor, self.eps)
kl = 0
return out, kl
else:
out = F.batch_norm(input, self.running_mean, self.running_var,
self.weight, self.bias, self.training
or not self.track_running_stats,
exponential_average_factor, self.eps)
return out


class BatchNorm1dLayer(nn.Module):
def __init__(self,
Expand Down
2 changes: 1 addition & 1 deletion bayesian_torch/models/bayesian/resnet_variational_large.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ def _make_layer(self, block, planes, blocks, stride=1):
posterior_mu_init=posterior_mu_init,
posterior_rho_init=posterior_rho_init,
bias=False),
nn.BatchNorm2d(planes * block.expansion),
BatchNorm2dLayer(planes * block.expansion),
)

layers = []
Expand Down

0 comments on commit 86adb6d

Please sign in to comment.