This repository has been archived by the owner on Dec 3, 2024. It is now read-only.

Commit

Merge pull request #20 from junliang-lin/main
Add support for output padding in flipout layers
ranganathkrishnan authored Jan 17, 2023
2 parents 7bf1a2e + 33fef0a commit aa8e198
Showing 1 changed file with 12 additions and 0 deletions.
bayesian_torch/layers/flipout_layers/conv_flipout.py
@@ -557,6 +557,7 @@ def __init__(self,
                  padding=0,
                  dilation=1,
                  groups=1,
+                 output_padding=0,
                  prior_mean=0,
                  prior_variance=1,
                  posterior_mu_init=0,
@@ -588,6 +589,7 @@ def __init__(self,
         self.kernel_size = kernel_size
         self.stride = stride
         self.padding = padding
+        self.output_padding = output_padding
         self.dilation = dilation
         self.groups = groups
         self.bias = bias
@@ -669,6 +671,7 @@ def forward(self, x, return_kl=True):
                              bias=self.mu_bias,
                              stride=self.stride,
                              padding=self.padding,
+                             output_padding=self.output_padding,
                              dilation=self.dilation,
                              groups=self.groups)

@@ -702,6 +705,7 @@ def forward(self, x, return_kl=True):
                              bias=bias,
                              stride=self.stride,
                              padding=self.padding,
+                             output_padding=self.output_padding,
                              dilation=self.dilation,
                              groups=self.groups) * sign_output

@@ -719,6 +723,7 @@ def __init__(self,
                  kernel_size,
                  stride=1,
                  padding=0,
+                 output_padding=0,
                  dilation=1,
                  groups=1,
                  prior_mean=0,
@@ -752,6 +757,7 @@ def __init__(self,
         self.kernel_size = kernel_size
         self.stride = stride
         self.padding = padding
+        self.output_padding = output_padding
         self.dilation = dilation
         self.groups = groups
         self.bias = bias
@@ -837,6 +843,7 @@ def forward(self, x, return_kl=True):
                              weight=self.mu_kernel,
                              stride=self.stride,
                              padding=self.padding,
+                             output_padding=self.output_padding,
                              dilation=self.dilation,
                              groups=self.groups)

@@ -870,6 +877,7 @@ def forward(self, x, return_kl=True):
                              weight=delta_kernel,
                              stride=self.stride,
                              padding=self.padding,
+                             output_padding=self.output_padding,
                              dilation=self.dilation,
                              groups=self.groups) * sign_output

@@ -887,6 +895,7 @@ def __init__(self,
                  kernel_size,
                  stride=1,
                  padding=0,
+                 output_padding=0,
                  dilation=1,
                  groups=1,
                  prior_mean=0,
@@ -920,6 +929,7 @@ def __init__(self,
         self.kernel_size = kernel_size
         self.stride = stride
         self.padding = padding
+        self.output_padding = output_padding
         self.dilation = dilation
         self.groups = groups

@@ -1005,6 +1015,7 @@ def forward(self, x, return_kl=True):
                              bias=self.mu_bias,
                              stride=self.stride,
                              padding=self.padding,
+                             output_padding=self.output_padding,
                              dilation=self.dilation,
                              groups=self.groups)

@@ -1037,6 +1048,7 @@ def forward(self, x, return_kl=True):
                              bias=bias,
                              stride=self.stride,
                              padding=self.padding,
+                             output_padding=self.output_padding,
                              dilation=self.dilation,
                              groups=self.groups) * sign_output

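For context, output_padding resolves the output-shape ambiguity of strided transposed convolutions: several output sizes are consistent with the same input size, and the extra one-sided padding selects among them, with the same semantics as torch.nn.ConvTranspose2d. Below is a minimal usage sketch, assuming the updated layer is exported as bayesian_torch.layers.ConvTranspose2dFlipout; the shapes and hyperparameters are illustrative.

import torch
from bayesian_torch.layers import ConvTranspose2dFlipout

# Output size of a transposed conv along each spatial dim:
#   (in - 1)*stride - 2*padding + dilation*(kernel_size - 1) + output_padding + 1
# With stride=2, kernel_size=3, padding=1, a 16x16 input yields 31x31 by
# default; output_padding=1 recovers the usual 32x32 upsampling target.
layer = ConvTranspose2dFlipout(in_channels=8,
                               out_channels=4,
                               kernel_size=3,
                               stride=2,
                               padding=1,
                               output_padding=1)

x = torch.randn(1, 8, 16, 16)
out, kl = layer(x)  # flipout layers also return the KL term by default
print(out.shape)    # torch.Size([1, 4, 32, 32])
print(kl)           # scalar tensor: KL divergence of the layer's weights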
