This repository has been archived by the owner on Dec 3, 2024. It is now read-only.

Commit

Merge branch 'main' into quantization
junliang-lin authored Apr 26, 2023
2 parents 86adb6d + aa8e198 commit d9a26a5
Showing 4 changed files with 5 additions and 10 deletions.
2 changes: 2 additions & 0 deletions bayesian_torch/layers/flipout_layers/conv_flipout.py
@@ -40,6 +40,7 @@
from torch.quantization.observer import HistogramObserver, PerChannelMinMaxObserver, MinMaxObserver
from torch.quantization.qconfig import QConfig


from torch.distributions.normal import Normal
from torch.distributions.uniform import Uniform

@@ -419,6 +420,7 @@ def forward(self, x, return_kl=True):
return out



class Conv3dFlipout(BaseVariationalLayer_):
def __init__(self,
in_channels,
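The observer and QConfig imports added to conv_flipout.py hint at a custom quantization configuration for the flipout layers. As a hedged sketch only (the pairing below is an assumption mirroring the fbgemm defaults, not necessarily what this branch actually wires up), these classes compose like so:

    import torch
    from torch.quantization.observer import HistogramObserver, PerChannelMinMaxObserver
    from torch.quantization.qconfig import QConfig

    # assumed pairing: histogram-based activation observer plus a
    # symmetric per-channel weight observer, as in the fbgemm defaults
    custom_qconfig = QConfig(
        activation=HistogramObserver.with_args(reduce_range=True),
        weight=PerChannelMinMaxObserver.with_args(
            dtype=torch.qint8, qscheme=torch.per_channel_symmetric),
    )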
1 change: 1 addition & 0 deletions bayesian_torch/layers/flipout_layers/linear_flipout.py
@@ -195,3 +195,4 @@ def forward(self, x, return_kl=True):
if return_kl:
return out, kl
return out

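For context, the flipout forward passes return the layer output together with the KL term when return_kl=True. A minimal usage sketch (the feature sizes are illustrative, and LinearFlipout is assumed to be re-exported from bayesian_torch.layers as in the released package):

    import torch
    import torch.nn.functional as F
    from bayesian_torch.layers import LinearFlipout

    layer = LinearFlipout(in_features=16, out_features=8)  # illustrative sizes
    x = torch.randn(4, 16)

    out, kl = layer(x)                    # default return_kl=True
    out_only = layer(x, return_kl=False)  # output alone, e.g. at inference time

    # during training the KL term is typically scaled and added to the task loss
    loss = F.mse_loss(out, torch.zeros_like(out)) + kl / x.size(0)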
11 changes: 1 addition & 10 deletions bayesian_torch/layers/variational_layers/conv_variational.py
@@ -374,7 +374,7 @@ def forward(self, input, return_kl=True):
else:
kl = kl_weight
return out, kl

return out


@@ -973,12 +973,3 @@ def forward(self, input, return_kl=True):

return out

if __name__ == "__main__":
    m = Conv2dReparameterization(3, 3, 3)
    m.eval()
    m.prepare()
    m.qconfig = torch.quantization.get_default_qconfig("fbgemm")
    mp = torch.quantization.prepare(m)
    input = torch.randn(3, 3, 4, 4)
    mp(input)
    mq = torch.quantization.convert(mp)
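The block removed above was an inline smoke test; note it assigned m.qconfig only after calling m.prepare(). For reference, a minimal sketch of the standard PyTorch eager-mode post-training quantization order, using a stock nn.Module stand-in rather than this repo's layers (whose prepare() method is specific to this branch):

    import torch
    import torch.nn as nn

    # stand-in module; any nn.Module with quantizable submodules works here
    model = nn.Sequential(nn.Conv2d(3, 3, 3), nn.ReLU())
    model.eval()

    # attach the qconfig before prepare() so observers get inserted
    model.qconfig = torch.quantization.get_default_qconfig("fbgemm")
    prepared = torch.quantization.prepare(model)

    # calibration pass: observers record activation ranges
    with torch.no_grad():
        prepared(torch.randn(3, 3, 4, 4))

    # swap observed modules for their quantized counterparts
    quantized = torch.quantization.convert(prepared)
    # note: running `quantized` on float inputs additionally needs
    # QuantStub/DeQuantStub around the model; like the removed smoke
    # test, this sketch stops at conversion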
1 change: 1 addition & 0 deletions bayesian_torch/layers/variational_layers/linear_variational.py
@@ -162,6 +162,7 @@ def forward(self, input, return_kl=True):
tmp_result = sigma_weight * eps_weight
weight = self.mu_weight + tmp_result


if return_kl:
kl_weight = self.kl_div(self.mu_weight, sigma_weight,
self.prior_weight_mu, self.prior_weight_sigma)
