This repository has been archived by the owner on Dec 3, 2024. It is now read-only.

Commit 97ba16a

correcting the order of group and dilation parameters in Conv transpose layers.

Fix issue #21

Signed-off-by: Ranganath Krishnan <[email protected]>
ranganathkrishnan committed Jan 2, 2024
1 parent 1180b87 commit 97ba16a
Showing 4 changed files with 27 additions and 27 deletions.
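
Background for the fix: torch.nn.functional.conv_transpose1d/2d/3d take their trailing arguments in the order (..., output_padding, groups, dilation), whereas F.conv1d/2d/3d end with (..., dilation, groups). Passing dilation before groups positionally therefore binds each value to the wrong parameter. A minimal standalone sketch of the two bindings (not code from this repository; shapes are illustrative):

    import torch
    import torch.nn.functional as F

    x = torch.randn(1, 4, 8, 8)
    w = torch.randn(4, 2, 3, 3)  # transpose-conv weight: (in_ch, out_ch // groups, kH, kW)

    # Positional order for F.conv_transpose2d is:
    #   (input, weight, bias, stride, padding, output_padding, groups, dilation)
    intended = F.conv_transpose2d(x, w, None, 1, 0, 0, 2, 1)  # groups=2, dilation=1
    swapped = F.conv_transpose2d(x, w, None, 1, 0, 0, 1, 2)   # same values, reversed slots

    print(intended.shape)  # torch.Size([1, 4, 10, 10])
    print(swapped.shape)   # torch.Size([1, 2, 12, 12]): wrong channels and spatial size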
24 changes: 12 additions & 12 deletions bayesian_torch/layers/flipout_layers/conv_flipout.py
@@ -769,8 +769,8 @@ def forward(self, x, return_kl=True):
stride=self.stride,
padding=self.padding,
output_padding=self.output_padding,
- dilation=self.dilation,
- groups=self.groups)
+ groups=self.groups,
+ dilation=self.dilation)

# sampling perturbation signs
sign_input = x.clone().uniform_(-1, 1).sign()
@@ -803,8 +803,8 @@ def forward(self, x, return_kl=True):
stride=self.stride,
padding=self.padding,
output_padding=self.output_padding,
- dilation=self.dilation,
- groups=self.groups)
+ groups=self.groups,
+ dilation=self.dilation)
perturbed_outputs = perturbed_outputs_tmp * sign_output
out = outputs + perturbed_outputs

@@ -968,8 +968,8 @@ def forward(self, x, return_kl=True):
stride=self.stride,
padding=self.padding,
output_padding=self.output_padding,
- dilation=self.dilation,
- groups=self.groups)
+ groups=self.groups,
+ dilation=self.dilation)

# sampling perturbation signs
sign_input = x.clone().uniform_(-1, 1).sign()
@@ -1002,8 +1002,8 @@ def forward(self, x, return_kl=True):
stride=self.stride,
padding=self.padding,
output_padding=self.output_padding,
- dilation=self.dilation,
- groups=self.groups)
+ groups=self.groups,
+ dilation=self.dilation)
perturbed_outputs = perturbed_outputs_tmp * sign_output
out = outputs + perturbed_outputs

@@ -1167,8 +1167,8 @@ def forward(self, x, return_kl=True):
stride=self.stride,
padding=self.padding,
output_padding=self.output_padding,
- dilation=self.dilation,
- groups=self.groups)
+ groups=self.groups,
+ dilation=self.dilation)

# sampling perturbation signs
sign_input = x.clone().uniform_(-1, 1).sign()
@@ -1200,8 +1200,8 @@ def forward(self, x, return_kl=True):
stride=self.stride,
padding=self.padding,
output_padding=self.output_padding,
- dilation=self.dilation,
- groups=self.groups)
+ groups=self.groups,
+ dilation=self.dilation)
perturbed_outputs = perturbed_outputs_tmp * sign_output
out = outputs + perturbed_outputs

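Note that the conv_flipout.py call sites above pass groups and dilation as keyword arguments, so the reorder in this file is a consistency change rather than a behavioral fix: keyword arguments bind by name regardless of position. A quick standalone check (illustrative values, not repository code):

    import torch
    import torch.nn.functional as F

    x = torch.randn(1, 4, 8, 8)
    w = torch.randn(4, 2, 3, 3)

    a = F.conv_transpose2d(x, w, stride=1, padding=0, output_padding=0,
                           dilation=1, groups=2)
    b = F.conv_transpose2d(x, w, stride=1, padding=0, output_padding=0,
                           groups=2, dilation=1)
    assert torch.equal(a, b)  # keyword order is irrelevant to the binding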
12 changes: 6 additions & 6 deletions bayesian_torch/layers/flipout_layers/quantized_conv_flipout.py
@@ -898,7 +898,7 @@ def forward(self, x, normal_scale=6/255, default_scale=0.1, default_zero_point=1

self._packed_params = torch.ops.quantized.conv_transpose1d_prepack(self.quantized_mu_weight, bias, self.stride,
self.padding, self.output_padding,
- self.dilation, self.groups)
+ self.groups, self.dilation)

outputs = torch.ops.quantized.conv_transpose1d(x, self._packed_params, scale=default_scale, zero_point=default_zero_point)

@@ -923,7 +923,7 @@ def forward(self, x, normal_scale=6/255, default_scale=0.1, default_zero_point=1

self._packed_params = torch.ops.quantized.conv_transpose1d_prepack(delta_kernel, bias, self.stride,
self.padding, self.output_padding,
- self.dilation, self.groups)
+ self.groups, self.dilation)
perturbed_outputs = torch.ops.quantized.conv_transpose1d(x, self._packed_params, scale=default_scale, zero_point=default_zero_point)

perturbed_outputs = torch.ops.quantized.mul(perturbed_outputs, sign_output, default_scale, default_zero_point)
@@ -1106,7 +1106,7 @@ def forward(self, x, normal_scale=6/255, default_scale=0.1, default_zero_point=1

self._packed_params = torch.ops.quantized.conv_transpose2d_prepack(self.quantized_mu_weight, bias, self.stride,
self.padding, self.output_padding,
- self.dilation, self.groups)
+ self.groups, self.dilation)

outputs = torch.ops.quantized.conv_transpose2d(x, self._packed_params, scale=default_scale, zero_point=default_zero_point)

@@ -1131,7 +1131,7 @@ def forward(self, x, normal_scale=6/255, default_scale=0.1, default_zero_point=1

self._packed_params = torch.ops.quantized.conv_transpose2d_prepack(delta_kernel, bias, self.stride,
self.padding, self.output_padding,
- self.dilation, self.groups)
+ self.groups, self.dilation)
perturbed_outputs = torch.ops.quantized.conv_transpose2d(x, self._packed_params, scale=default_scale, zero_point=default_zero_point)

perturbed_outputs = torch.ops.quantized.mul(perturbed_outputs, sign_output, default_scale, default_zero_point)
@@ -1314,7 +1314,7 @@ def forward(self, x, normal_scale=6/255, default_scale=0.1, default_zero_point=1

self._packed_params = torch.ops.quantized.conv_transpose3d_prepack(self.quantized_mu_weight, bias, self.stride,
self.padding, self.output_padding,
- self.dilation, self.groups)
+ self.groups, self.dilation)

outputs = torch.ops.quantized.conv_transpose3d(x, self._packed_params, scale=default_scale, zero_point=default_zero_point)

@@ -1339,7 +1339,7 @@ def forward(self, x, normal_scale=6/255, default_scale=0.1, default_zero_point=1

self._packed_params = torch.ops.quantized.conv_transpose3d_prepack(delta_kernel, bias, self.stride,
self.padding, self.output_padding,
- self.dilation, self.groups)
+ self.groups, self.dilation)
perturbed_outputs = torch.ops.quantized.conv_transpose3d(x, self._packed_params, scale=default_scale, zero_point=default_zero_point)

perturbed_outputs = torch.ops.quantized.mul(perturbed_outputs, sign_output, default_scale, default_zero_point)
6 changes: 3 additions & 3 deletions bayesian_torch/layers/variational_layers/conv_variational.py
@@ -719,7 +719,7 @@ def forward(self, input, return_kl=True):

out = F.conv_transpose1d(input, weight, bias, self.stride,
self.padding, self.output_padding,
- self.dilation, self.groups)
+ self.groups, self.dilation)

if self.quant_prepare:
# quint8 quantstub
@@ -894,7 +894,7 @@ def forward(self, input, return_kl=True):

out = F.conv_transpose2d(input, weight, bias, self.stride,
self.padding, self.output_padding,
- self.dilation, self.groups)
+ self.groups, self.dilation)

if self.quant_prepare:
# quint8 quantstub
@@ -1070,7 +1070,7 @@ def forward(self, input, return_kl=True):

out = F.conv_transpose3d(input, weight, bias, self.stride,
self.padding, self.output_padding,
- self.dilation, self.groups)
+ self.groups, self.dilation)

if self.quant_prepare:
# quint8 quantstub
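In conv_variational.py the arguments are positional, so this is where the bug actually bit: the layer's dilation was bound to F.conv_transpose*d's groups parameter and vice versa. A standalone sketch of the failure mode (illustrative shapes, not repository code):

    import torch
    import torch.nn.functional as F

    # Pre-fix binding: (..., output_padding, self.dilation, self.groups).
    # For a layer configured with dilation=2 and groups=1, the 2 lands in
    # the groups slot. Depending on channel counts this either fails loudly,
    # as below, or silently produces the wrong output shape.
    x = torch.randn(1, 3, 8, 8)
    w = torch.randn(3, 5, 3, 3)  # in_ch=3 is not divisible by groups=2

    try:
        F.conv_transpose2d(x, w, None, 1, 0, 0, 2, 1)  # old binding: groups=2
    except RuntimeError as err:
        print("pre-fix binding fails:", err)

    out = F.conv_transpose2d(x, w, None, 1, 0, 0, 1, 2)  # fixed: groups=1, dilation=2
    print(out.shape)  # torch.Size([1, 5, 12, 12])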
12 changes: 6 additions & 6 deletions bayesian_torch/layers/variational_layers/quantized_conv_variational.py
@@ -996,7 +996,7 @@ def forward(self, input, enable_int8_compute=True, normal_scale=6/255, default_s

out = F.conv_transpose1d(input, weight, bias, self.stride,
self.padding, self.output_padding,
- self.dilation, self.groups)
+ self.groups, self.dilation)

else:
eps_kernel = torch.quantize_per_tensor(self.eps_kernel.data.normal_(), normal_scale, 0, torch.qint8) # Quantize a tensor from normal distribution. 99.7% values will lie within 3 standard deviations, so the original range is set as 6.
@@ -1019,7 +1019,7 @@ def forward(self, input, enable_int8_compute=True, normal_scale=6/255, default_s

self._packed_params = torch.ops.quantized.conv_transpose1d_prepack(weight, bias, self.stride,
self.padding, self.output_padding,
- self.dilation, self.groups)
+ self.groups, self.dilation)

out = torch.ops.quantized.conv_transpose1d(input, self._packed_params, scale=default_scale, zero_point=default_zero_point)

@@ -1227,7 +1227,7 @@ def forward(self, input, enable_int8_compute=True, normal_scale=6/255, default_s

out = F.conv_transpose2d(input, weight, bias, self.stride,
self.padding, self.output_padding,
- self.dilation, self.groups)
+ self.groups, self.dilation)

else:
eps_kernel = torch.quantize_per_tensor(self.eps_kernel.data.normal_(), normal_scale, 0, torch.qint8) # Quantize a tensor from normal distribution. 99.7% values will lie within 3 standard deviations, so the original range is set as 6.
@@ -1250,7 +1250,7 @@ def forward(self, input, enable_int8_compute=True, normal_scale=6/255, default_s

self._packed_params = torch.ops.quantized.conv_transpose2d_prepack(weight, bias, self.stride,
self.padding, self.output_padding,
- self.dilation, self.groups)
+ self.groups, self.dilation)

out = torch.ops.quantized.conv_transpose2d(input, self._packed_params, scale=default_scale, zero_point=default_zero_point)

@@ -1458,7 +1458,7 @@ def forward(self, input, enable_int8_compute=True, normal_scale=6/255, default_s

out = F.conv_transpose3d(input, weight, bias, self.stride,
self.padding, self.output_padding,
- self.dilation, self.groups)
+ self.groups, self.dilation)

else:
eps_kernel = torch.quantize_per_tensor(self.eps_kernel.data.normal_(), normal_scale, 0, torch.qint8) # Quantize a tensor from normal distribution. 99.7% values will lie within 3 standard deviations, so the original range is set as 6.
@@ -1481,7 +1481,7 @@ def forward(self, input, enable_int8_compute=True, normal_scale=6/255, default_s

self._packed_params = torch.ops.quantized.conv_transpose3d_prepack(weight, bias, self.stride,
self.padding, self.output_padding,
- self.dilation, self.groups)
+ self.groups, self.dilation)

out = torch.ops.quantized.conv_transpose3d(input, self._packed_params, scale=default_scale, zero_point=default_zero_point)

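A regression check for issue #21 could instantiate one of the affected layers with non-default groups and dilation and compare its output shape against the deterministic torch.nn reference. The sketch below assumes bayesian_torch.layers exports ConvTranspose2dReparameterization with nn.ConvTranspose2d-style constructor arguments, and that the layer returns an (output, kl) pair as the forward signatures in this diff do; adjust names if the API differs:

    import torch
    from torch import nn
    from bayesian_torch.layers import ConvTranspose2dReparameterization  # assumed export

    x = torch.randn(2, 4, 8, 8)
    layer = ConvTranspose2dReparameterization(in_channels=4, out_channels=4,
                                              kernel_size=3, groups=2, dilation=2)
    out, kl = layer(x)  # assumed (output, kl) return convention

    ref = nn.ConvTranspose2d(4, 4, kernel_size=3, groups=2, dilation=2)
    assert out.shape == ref(x).shape  # torch.Size([2, 4, 12, 12]) with correct binding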
