Simulated BN fold module changes
* Support case where BN module has no learnable parameters
  (affine == False)
* Support conv1d and conv3d
guyjacob committed Jun 23, 2019
1 parent b60a33e commit 8424065
Showing 2 changed files with 115 additions and 40 deletions.
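For context, here is a minimal usage sketch of the wrapper this commit extends (an illustrative snippet, not part of the diff; the import path is inferred from the file location below):

import torch
import torch.nn as nn
from distiller.quantization.sim_bn_fold import SimulatedFoldedBatchNorm

conv = nn.Conv1d(3, 16, kernel_size=5)    # Conv1d support is new in this commit
bn = nn.BatchNorm1d(16, affine=False)     # affine=False support is new in this commit
folded = SimulatedFoldedBatchNorm(conv, bn, param_quantization_fn=None)
y = folded(torch.randn(8, 3, 32))         # forward pass with simulated BN folding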
86 changes: 56 additions & 30 deletions distiller/quantization/sim_bn_fold.py
@@ -8,6 +8,11 @@
FREEZE_BN_DELAY_DEFAULT = 200000


_conv_meta = {'conv1d': (1, F.conv1d),
'conv2d': (2, F.conv2d),
'conv3d': (3, F.conv3d)}
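# (editor's note, not part of the commit) _conv_meta maps each conv variant to its
# number of spatial dims and its functional op; the shared code below uses it, e.g.
# expected output dim = _conv_meta['conv3d'][0] + 2 == 5, i.e. (N, C, D, H, W).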


def _broadcast_correction_factor(c, broadcast_to_shape):
"""
Returns a view of `c` which is broadcastable with shape `broadcast_to_shape`.
@@ -22,17 +27,16 @@ def __init__(self, param_module, bn, freeze_bn_delay=FREEZE_BN_DELAY_DEFAULT, pa
"""
Wrapper for simulated folding of BatchNorm into convolution / linear layers during training
Args:
param_module (nn.Linear or nn.Conv2d): the wrapped parameter layer
bn (nn.BatchNorm1d or nn.BatchNorm2d): batch normalization
freeze_bn_delay (int): number of steps before freezing the batchnorm running stats
param_module (nn.Linear or nn.Conv1d or nn.Conv2d or nn.Conv3d): the wrapped parameter module
bn (nn.BatchNorm1d or nn.BatchNorm2d or nn.BatchNorm3d): batch normalization module
freeze_bn_delay (int): number of steps before freezing the batch-norm running stats
param_quantization_fn (function): function to be used for weight/bias quantization
Note:
The quantized version was implemented according to https://arxiv.org/pdf/1806.08342.pdf Section 3.2.2.
"""
SimulatedFoldedBatchNorm.verify_module_types(param_module, bn)
if not bn.track_running_stats or not bn.affine:
raise ValueError("Simulated BN folding is only supported for BatchNorm which tracks running stats with"
"affine weights.")
if not bn.track_running_stats:
raise ValueError("Simulated BN folding is only supported for BatchNorm which tracks running stats")
super(SimulatedFoldedBatchNorm, self).__init__()
self.param_module = param_module
self.bn = bn
@@ -43,19 +47,30 @@ def __init__(self, param_module, bn, freeze_bn_delay=FREEZE_BN_DELAY_DEFAULT, pa
if isinstance(param_module, nn.Linear):
self.param_forward_fn = self._linear_layer_forward
self.param_module_type = "fc"
else:
self.param_forward_fn = self._conv2d_layer_forward
elif isinstance(param_module, nn.Conv1d):
self.param_forward_fn = self._conv_layer_forward
self.param_module_type = "conv1d"
elif isinstance(param_module, nn.Conv2d):
self.param_forward_fn = self._conv_layer_forward
self.param_module_type = "conv2d"
else:
self.param_forward_fn = self._conv_layer_forward
self.param_module_type = "conv3d"

@staticmethod
def verify_module_types(param_module, bn):
if not isinstance(param_module, (nn.Linear, nn.Conv2d)) \
and not isinstance(bn, (nn.BatchNorm1d, nn.BatchNorm2d)):
raise TypeError("Only supporting fusing nn.BatchNorm1d/nn.BatchNorm2d into nn.Linear/nn.Conv2d.")
if isinstance(param_module, nn.Linear) and isinstance(bn, nn.BatchNorm2d):
raise TypeError("nn.Linear layer has to be followed by a nn.BatchNorm1d layer.")
if isinstance(param_module, nn.Conv2d) and isinstance(bn, nn.BatchNorm1d):
raise TypeError("nn.Con2d layer has to be followed by a nn.BatchNorm2d layer.")
foldable_seqs = [((nn.Linear, nn.Conv1d), nn.BatchNorm1d),
(nn.Conv2d, nn.BatchNorm2d),
(nn.Conv3d, nn.BatchNorm3d)]
error_msg = "Can't fold sequence of {} --> {}. ".format(param_module.__class__.__name__, bn.__class__.__name__)
for seq in foldable_seqs:
if isinstance(param_module, seq[0]):
if not isinstance(bn, seq[1]):
raise TypeError(error_msg + "{} must be followed by {}".
format(param_module.__class__.__name__, seq[1].__name__))
return
raise TypeError(error_msg + "Only Conv/Linear modules followed by BatchNorm modules allowed"
.format(param_module.__class__.__name__, bn.__class__.__name__))

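The new pairing check accepts exactly Linear/Conv1d -> BatchNorm1d, Conv2d -> BatchNorm2d, and Conv3d -> BatchNorm3d; anything else raises TypeError. Illustrative calls (not part of the diff):

SimulatedFoldedBatchNorm(nn.Linear(3, 8), nn.BatchNorm1d(8))      # OK
SimulatedFoldedBatchNorm(nn.Conv1d(3, 8, 3), nn.BatchNorm1d(8))   # OK
SimulatedFoldedBatchNorm(nn.Conv3d(3, 8, 3), nn.BatchNorm3d(8))   # OK
SimulatedFoldedBatchNorm(nn.Conv2d(3, 8, 3), nn.BatchNorm1d(8))   # raises TypeError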
def forward(self, x):
"""
@@ -71,7 +86,7 @@ def forward(self, x):
             = (x*W - E(x*W)) * gamma / std(x*W) + beta
"""
if not self.frozen:
w, b, gamma, beta = self.param_module.weight, self.param_module.bias, self.bn.weight, self.bn.bias
w, b, gamma, beta = self._get_all_parameters()
if self.training:
batch_mean, batch_var = self.batch_stats(self.param_forward_fn(x, w), b)
recip_sigma_batch = torch.rsqrt(batch_var + self.bn.eps)
@@ -104,7 +119,7 @@ def broadcast_correction(self, c: torch.Tensor):
"""
Broadcasts a correction factor to the output for elementwise operations.
"""
expected_output_dim = 2 if self.param_module_type == "fc" else 4
expected_output_dim = 2 if self.param_module_type == "fc" else _conv_meta[self.param_module_type][0] + 2
view_fillers_dim = expected_output_dim - c.dim() - 1
view_filler = (1,) * view_fillers_dim
expected_view_shape = c.shape + view_filler
@@ -116,7 +131,7 @@ def broadcast_correction_weight(self, c: torch.Tensor):
"""
if c.dim() != 1:
raise ValueError("Correction factor needs to have a single dimension")
expected_weight_dim = 2 if self.param_module_type == "fc" else 4
expected_weight_dim = 2 if self.param_module_type == "fc" else _conv_meta[self.param_module_type][0] + 2
view_fillers_dim = expected_weight_dim - c.dim()
view_filler = (1,) * view_fillers_dim
expected_view_shape = c.shape + view_filler
@@ -185,20 +200,24 @@ def batch_stats(self, x, bias=None):
def _linear_layer_forward(self, input, w, b=None):
return F.linear(input, w, b)

def _conv2d_layer_forward(self, input, w, b=None):
# We copy the code from the Conv2d forward, but plug in our weights.
conv = self.param_module # type: nn.Conv2d
if conv.__dict__.get('padding_mode', None) == 'circular': # This attribute doesn't exist yet in pytorch 1.0.1
expanded_padding = [(conv.padding[1] + 1) // 2, conv.padding[1] // 2,
(conv.padding[0] + 1) // 2, conv.padding[0] // 2]
return F.conv2d(F.pad(input, expanded_padding, mode='circular'),
w, b, conv.stride,
(0, 0), conv.dilation, conv.groups)
return F.conv2d(input, w, b, conv.stride,
conv.padding, conv.dilation, conv.groups)
def _conv_layer_forward(self, input, w, b=None):
# We implement according to Conv1/2/3d.forward(), but plug in our weights
conv = self.param_module
ndims, func = _conv_meta[self.param_module_type]

        # 'circular' padding mode doesn't exist in PyTorch versions before 1.1.0
if getattr(conv, 'padding_mode', None) == 'circular':
expanded_padding = []
for pad_idx in reversed(range(ndims)):
expanded_padding.extend([(conv.padding[pad_idx] + 1) // 2, conv.padding[pad_idx] // 2])
return func(F.pad(input, expanded_padding, mode='circular'),
w, b, conv.stride,
(0,) * ndims, conv.dilation, conv.groups)
return func(input, w, b, conv.stride,
conv.padding, conv.dilation, conv.groups)

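To see what the generalized padding expansion produces, here is a worked example (not part of the diff) for a Conv2d with padding=(1, 2):

# ndims = 2; reversed(range(ndims)) visits the width dimension first, matching
# F.pad's last-dimension-first convention:
#   pad_idx = 1 (width):  (2 + 1) // 2 = 1,  2 // 2 = 1
#   pad_idx = 0 (height): (1 + 1) // 2 = 1,  1 // 2 = 0
# expanded_padding == [1, 1, 1, 0], i.e. (w_left, w_right, h_top, h_bottom)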
def freeze(self):
w, b, gamma, beta = self.param_module.weight, self.param_module.bias, self.bn.weight, self.bn.bias
w, b, gamma, beta = self._get_all_parameters()
with torch.no_grad():
recip_sigma_running = torch.rsqrt(self.bn.running_var + self.bn.eps)
w.mul_(self.broadcast_correction_weight(gamma * recip_sigma_running))
@@ -209,3 +228,10 @@ def freeze(self):
else:
self.param_module.bias = nn.Parameter(bias_corrected)
self.frozen = True

def _get_all_parameters(self):
w, b, gamma, beta = self.param_module.weight, self.param_module.bias, self.bn.weight, self.bn.bias
if not self.bn.affine:
gamma = 1.
beta = 0.
return w, b, gamma, beta
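With affine == False, _get_all_parameters substitutes gamma = 1. and beta = 0., so the frozen fold reduces to standardization by the running statistics alone. A sketch of the arithmetic performed by freeze() under that assumption (parts of freeze() are elided from this diff):

# per output channel:
#   w_folded = w * gamma / sqrt(running_var + eps)
#   b_folded = beta + (b - running_mean) * gamma / sqrt(running_var + eps)
# with affine == False (gamma = 1., beta = 0.):
#   w_folded = w / sqrt(running_var + eps)
#   b_folded = (b - running_mean) / sqrt(running_var + eps)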
69 changes: 59 additions & 10 deletions tests/test_sim_bn_fold.py
@@ -14,6 +14,22 @@
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'


@pytest.mark.parametrize(
"m1, m2",
[
(nn.ReLU(), nn.BatchNorm1d(5)),
(nn.Conv1d(1, 2, 3), nn.ReLU()),
(nn.Conv1d(1, 2, 3), nn.BatchNorm2d(2)),
(nn.Conv2d(1, 2, 3), nn.BatchNorm3d(2)),
(nn.Conv3d(1, 2, 3), nn.BatchNorm2d(2)),
(nn.Linear(3, 5), nn.BatchNorm2d(5))
]
)
def test_simulated_bn_fold_bad_sequences(m1, m2):
with pytest.raises(TypeError):
SimulatedFoldedBatchNorm(m1, m2)


@pytest.fixture(params=[False, True], ids=['bias_off', 'bias_on'])
def has_bias(request):
return request.param
@@ -24,6 +40,11 @@ def momentum(request):
return request.param


@pytest.fixture(params=[True, False], ids=['affine_on', 'affine_off'])
def affine(request):
return request.param


@pytest.mark.parametrize(
"batch_size, input_size, output_size",
[
@@ -32,12 +53,25 @@ def momentum(request):
(256, 128, 1024)
]
)
def test_simulated_bn_fold_fc(has_bias, batch_size, input_size, output_size, momentum):
def test_simulated_bn_fold_fc(has_bias, batch_size, input_size, output_size, momentum, affine):
distiller.set_deterministic(1234)
linear = nn.Linear(input_size, output_size, bias=has_bias)
bn = nn.BatchNorm1d(output_size, momentum=momentum)
bn = nn.BatchNorm1d(output_size, momentum=momentum, affine=affine)
run_simulated_bn_fold_test(linear, bn, (batch_size, input_size), has_bias)



@pytest.mark.parametrize(
"batch_size, input_c, output_c, l, kernel_size",
[
(50, 3, 100, 80, 10),
]
)
def test_simulated_bn_fold_conv1d(has_bias, batch_size, input_c, output_c, l, kernel_size, momentum, affine):
distiller.set_deterministic(1234)
conv1d = nn.Conv1d(input_c, output_c, kernel_size, bias=has_bias)
bn = nn.BatchNorm1d(output_c, momentum=momentum, affine=affine)
run_simulated_bn_fold_test(conv1d, bn, (batch_size, input_c, l), has_bias)


@pytest.mark.parametrize(
"batch_size, input_c, output_c, h, w, kernel_size",
@@ -47,13 +81,26 @@ def test_simulated_bn_fold_fc(has_bias, batch_size, input_size, output_size, mom
(256, 3, 64, 28, 28, 7),
]
)
def test_simulated_bn_fold_conv(has_bias, batch_size, input_c, output_c, h, w, kernel_size, momentum):
def test_simulated_bn_fold_conv2d(has_bias, batch_size, input_c, output_c, h, w, kernel_size, momentum, affine):
distiller.set_deterministic(1234)
conv2d = nn.Conv2d(input_c, output_c, kernel_size, bias=has_bias)
bn = nn.BatchNorm2d(output_c, momentum=momentum)
bn = nn.BatchNorm2d(output_c, momentum=momentum, affine=affine)
run_simulated_bn_fold_test(conv2d, bn, (batch_size, input_c, h, w), has_bias)


@pytest.mark.parametrize(
"batch_size, input_c, output_c, h, w, d, kernel_size",
[
(2, 2, 3, 64, 64, 9, 3),
]
)
def test_simulated_bn_fold_conv3d(has_bias, batch_size, input_c, output_c, h, w, d, kernel_size, momentum, affine):
distiller.set_deterministic(1234)
conv3d = nn.Conv3d(input_c, output_c, kernel_size, bias=has_bias)
bn = nn.BatchNorm3d(output_c, momentum=momentum, affine=affine)
run_simulated_bn_fold_test(conv3d, bn, (batch_size, input_c, h, w, d), has_bias)


def run_simulated_bn_fold_test(param_layer, bn_layer, x_size, has_bias):
folded = SimulatedFoldedBatchNorm(deepcopy(param_layer), deepcopy(bn_layer), param_quantization_fn=None)
unfolded = nn.Sequential(param_layer, bn_layer)
@@ -82,22 +129,24 @@ def run_simulated_bn_fold_test(param_layer, bn_layer, x_size, has_bias):
loss_unfolded.backward()

# check the gradients:
assert_allclose(unfolded[0].weight.grad, folded.param_module.weight.grad)
assert_allclose(unfolded[0].weight.grad, folded.param_module.weight.grad, RTOL, ATOL)
if has_bias:
# The bias of the linear layer doesn't participate in the calculation!
# for more details - refer to `FusedLinearBatchNorm.forward`
assert folded.param_module.bias.grad is None
assert_allclose(unfolded[1].weight.grad, folded.bn.weight.grad)
assert_allclose(unfolded[1].bias.grad, folded.bn.bias.grad)
if bn_layer.affine:
assert_allclose(unfolded[1].weight.grad, folded.bn.weight.grad, RTOL, ATOL)
assert_allclose(unfolded[1].bias.grad, folded.bn.bias.grad, RTOL, ATOL)

# make a step:
optimizer_unfolded.step()
optimizer_folded.step()

# check updated weights (we skip the linear bias)
assert_allclose(unfolded[0].weight, folded.param_module.weight, RTOL, ATOL)
assert_allclose(unfolded[1].weight, folded.bn.weight, RTOL, ATOL)
assert_allclose(unfolded[1].bias, folded.bn.bias, RTOL, ATOL)
if bn_layer.affine:
assert_allclose(unfolded[1].weight, folded.bn.weight, RTOL, ATOL)
assert_allclose(unfolded[1].bias, folded.bn.bias, RTOL, ATOL)
assert_allclose(unfolded[1].running_mean, folded.bn.running_mean, RTOL, ATOL)
assert_allclose(unfolded[1].running_var, folded.bn.running_var, RTOL, ATOL)
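Since the fixture ids (bias_on/bias_off, affine_on/affine_off) appear in the generated test ids, the new cases can be exercised selectively, e.g. pytest tests/test_sim_bn_fold.py -k "conv1d or conv3d or affine_off".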
