[Feature] Support grad clip for MMGeneration #269

Open · wants to merge 5 commits into base: master
22 changes: 18 additions & 4 deletions docs/en/tutorials/customize_runtime.md
@@ -106,14 +106,28 @@ The default optimizer constructor is implemented [here](https://github.com/open-
Tricks not implemented by the optimizer should be implemented through the optimizer constructor (e.g., setting parameter-wise learning rates) or hooks. We list some common settings that could stabilize or accelerate training. Feel free to create a PR or an issue for more settings.

- __Use gradient clip to stabilize training__:
Some models need gradient clip to clip the gradients to stabilize the training process. An example is as below:
MMGeneration directly optimizes parameters in `train_step`; therefore we do not use the `OptimizerHook` in `MMCV`.
If you want to use gradient clipping to stabilize the training process, please add the config to `train_cfg`. An example is as below:

```python
optimizer_config = dict(
_delete_=True, grad_clip=dict(max_norm=35, norm_type=2))
train_cfg = dict(grad_clip=dict(max_norm=35, norm_type=2))
```

If your config inherits the base config which already sets the `optimizer_config`, you might need `_delete_=True` to override the unnecessary settings. See the [config documentation](https://mmgeneration.readthedocs.io/en/latest/config.html) for more details.
If you want to use a different `max_norm` for the generator and the discriminator, you can refer to the following example:

```python
train_cfg = dict(
grad_clip=dict(
generator=dict(max_norm=35, norm_type=2),
discriminator=dict(max_norm=10, norm_type=2)))
```

If you only want to apply gradient clipping to a specific model, you can refer to the following example:

```python
train_cfg = dict(
grad_clip=dict(discriminator=dict(max_norm=10, norm_type=2)))
```
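
For intuition (not part of the patch itself), every `grad_clip` entry above ultimately maps onto PyTorch's `torch.nn.utils.clip_grad_norm_`, applied to the parameters of the named submodule after the backward pass and before the corresponding optimizer step. A minimal, self-contained sketch of that equivalence; the toy modules and variable names below are illustrative, not MMGeneration code:

```python
import torch
import torch.nn as nn
from torch.nn.utils import clip_grad

# Toy stand-ins for a GAN's generator and discriminator (illustrative only).
generator = nn.Linear(8, 8)
discriminator = nn.Linear(8, 1)
opt_d = torch.optim.SGD(discriminator.parameters(), lr=0.01)

# One discriminator update with the equivalent of
# grad_clip=dict(discriminator=dict(max_norm=10, norm_type=2)):
loss = discriminator(generator(torch.randn(4, 8))).mean()
loss.backward()
params = [p for p in discriminator.parameters()
          if p.requires_grad and p.grad is not None]
if params:
    total_norm = clip_grad.clip_grad_norm_(params, max_norm=10, norm_type=2).item()
    print(f'grad_norm_discriminator: {total_norm:.4f}')
opt_d.step()  # gradients were clipped in place before the step
```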

- __Use momentum schedule to accelerate model convergence__:
We support a momentum scheduler to modify the model's momentum according to the learning rate, which could make the model converge faster.
24 changes: 9 additions & 15 deletions mmgen/apis/train.py
@@ -5,7 +5,7 @@
import mmcv
import torch
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import HOOKS, IterBasedRunner, OptimizerHook, build_runner
from mmcv.runner import HOOKS, IterBasedRunner, build_runner
from mmcv.runner import set_random_seed as set_random_seed_mmcv
from mmcv.utils import build_from_cfg

@@ -127,29 +127,23 @@ def train_model(model,
runner.timestamp = timestamp

# fp16 setting
fp16_cfg = cfg.get('fp16', None)
assert cfg.get('fp16', None) is None, 'FP16 is not supported yet.'

# In GANs, we can directly optimize parameters in the `train_step` function.
if cfg.get('optimizer_cfg', None) is None:
optimizer_config = None
elif fp16_cfg is not None:
raise NotImplementedError('Fp16 has not been supported.')
# optimizer_config = Fp16OptimizerHook(
# **cfg.optimizer_config, **fp16_cfg, distributed=distributed)
# default to use OptimizerHook
elif distributed and 'type' not in cfg.optimizer_config:
optimizer_config = OptimizerHook(**cfg.optimizer_config)
else:
optimizer_config = cfg.optimizer_config
# Therefore, we do not support OptimizerHook.
assert cfg.get('optimizer_config', None) is None, (
'MMGen directly optimizes parameters in the `train_step` function. If '
'you want to apply gradient clipping, please add the config to '
'`train_cfg`.')

# update `out_dir` in ckpt hook
if cfg.checkpoint_config is not None:
cfg.checkpoint_config['out_dir'] = os.path.join(
cfg.work_dir, cfg.checkpoint_config.get('out_dir', 'ckpt'))

# register hooks
runner.register_training_hooks(cfg.lr_config, optimizer_config,
cfg.checkpoint_config, cfg.log_config,
runner.register_training_hooks(cfg.lr_config, None, cfg.checkpoint_config,
cfg.log_config,
cfg.get('momentum_config', None))

# # DistSamplerSeedHook should be used with EpochBasedRunner
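
Given the new assertion, configs that still route clipping through `optimizer_config` need a small migration. A hedged before/after sketch (only the clipping-related keys are shown):

```python
# Before: clipping configured via OptimizerHook, which is no longer registered.
# optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))

# After: clipping is read from train_cfg and applied inside train_step.
train_cfg = dict(grad_clip=dict(max_norm=35, norm_type=2))
```
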
31 changes: 31 additions & 0 deletions mmgen/models/gans/base_gan.py
@@ -5,6 +5,8 @@
import torch
import torch.distributed as dist
import torch.nn as nn
from mmcv.utils import is_list_of
from torch.nn.utils import clip_grad


class BaseGAN(nn.Module, metaclass=ABCMeta):
@@ -96,6 +98,35 @@ def _get_gen_loss(self, outputs_dict):

return loss, log_var

def clip_grads(self, model, log_vars=None):
"""Apply gradient clip for the input model.
Args:
model (str): The name of the input model.
log_vars (dict, optional): The dict that contains variables to be
logged. Defaults to None.

Returns:
float: Total norm value of the model.
"""
if not hasattr(self, 'grad_clip') or self.grad_clip is None:
return None
if is_list_of(list(self.grad_clip.values()), dict):
# use a different grad clip config for each model
if model not in self.grad_clip:
return None
clip_args = self.grad_clip[model]
else:
# use the same grad clip config for all models
clip_args = self.grad_clip
params = getattr(self, model).parameters()
params = list(
filter(lambda p: p.requires_grad and p.grad is not None, params))
if len(params) > 0:
total_norm = clip_grad.clip_grad_norm_(params, **clip_args).item()
if log_vars is not None:
log_vars[f'grad_norm_{model}'] = total_norm
return total_norm

@abstractmethod
def train_step(self, data, optimizer, ddp_reducer=None):
"""The iteration step during training.
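
The branching in `clip_grads` hinges on the `is_list_of(...)` check: when every value of `grad_clip` is itself a dict, the config is treated as per-model; otherwise the same settings are shared by all submodules. A small standalone sketch of that dispatch (the helper name is illustrative and no model is involved):

```python
from mmcv.utils import is_list_of


def resolve_clip_args(grad_clip, model_name):
    """Return the clip kwargs for ``model_name``, or None to skip clipping."""
    if grad_clip is None:
        return None
    if is_list_of(list(grad_clip.values()), dict):
        # Per-model configs, e.g. dict(discriminator=dict(max_norm=10, norm_type=2)).
        return grad_clip.get(model_name, None)
    # One shared config, e.g. dict(max_norm=35, norm_type=2).
    return grad_clip


shared = dict(max_norm=35, norm_type=2)
per_model = dict(discriminator=dict(max_norm=10, norm_type=2))

assert resolve_clip_args(shared, 'generator') == shared
assert resolve_clip_args(per_model, 'generator') is None
assert resolve_clip_args(per_model, 'discriminator') == dict(max_norm=10, norm_type=2)
```
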
6 changes: 5 additions & 1 deletion mmgen/models/gans/basic_conditional_gan.py
@@ -111,6 +111,7 @@ def _parse_train_cfg(self):
if self.use_ema:
# use deepcopy to guarantee the consistency
self.generator_ema = deepcopy(self.generator)
self.grad_clip = self.train_cfg.get('grad_clip', None)

def _parse_test_cfg(self):
"""Parsing test config and set some attributes for testing."""
@@ -231,10 +232,11 @@ def train_step(self,
if (curr_iter + 1) % self.batch_accumulation_steps == 0:
if loss_scaler:
loss_scaler.unscale_(optimizer['discriminator'])
# note that we do not contain clip_grad procedure
self.clip_grads('discriminator', log_vars_disc)
loss_scaler.step(optimizer['discriminator'])
# loss_scaler.update will be called in runner.train()
else:
self.clip_grads('discriminator', log_vars_disc)
optimizer['discriminator'].step()

# skip generator training if only train discriminator for current
@@ -298,9 +300,11 @@ def train_step(self,
if loss_scaler:
loss_scaler.unscale_(optimizer['generator'])
# note that we do not contain clip_grad procedure
self.clip_grads('generator', log_vars_g)
loss_scaler.step(optimizer['generator'])
# loss_scaler.update will be called in runner.train()
else:
self.clip_grads('generator', log_vars_g)
optimizer['generator'].step()

log_vars = {}
7 changes: 5 additions & 2 deletions mmgen/models/gans/static_unconditional_gan.py
@@ -99,6 +99,7 @@ def _parse_train_cfg(self):
self.generator_ema = deepcopy(self.generator)

self.real_img_key = self.train_cfg.get('real_img_key', 'real_img')
self.grad_clip = self.train_cfg.get('grad_clip', None)

def _parse_test_cfg(self):
"""Parsing test config and set some attributes for testing."""
@@ -207,10 +208,11 @@ def train_step(self,

if loss_scaler:
loss_scaler.unscale_(optimizer['discriminator'])
# note that we do not contain clip_grad procedure
self.clip_grads('discriminator', log_vars_disc)
loss_scaler.step(optimizer['discriminator'])
# loss_scaler.update will be called in runner.train()
else:
self.clip_grads('discriminator', log_vars_disc)
optimizer['discriminator'].step()

# skip generator training if only train discriminator for current
Expand Down Expand Up @@ -264,10 +266,11 @@ def train_step(self,

if loss_scaler:
loss_scaler.unscale_(optimizer['generator'])
# note that we do not contain clip_grad procedure
self.clip_grads('generator', log_vars_g)
loss_scaler.step(optimizer['generator'])
# loss_scaler.update will be called in runner.train()
else:
self.clip_grads('generator', log_vars_g)
optimizer['generator'].step()

log_vars = {}
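
In both `train_step` implementations the call site matters: with a `loss_scaler` (AMP), gradients have to be unscaled before they are clipped, otherwise the norm is computed on scaled values; without AMP, clipping simply precedes `optimizer.step()`. A condensed sketch of that ordering following the standard PyTorch AMP recipe; the function name and defaults are illustrative, and the real methods also handle loss computation, logging and EMA:

```python
from torch.nn.utils import clip_grad


def backward_clip_step(module, optimizer, loss, loss_scaler=None,
                       max_norm=10, norm_type=2):
    """Backward, optionally unscale, clip, then step, mirroring train_step's order."""
    optimizer.zero_grad()
    if loss_scaler is not None:  # e.g. torch.cuda.amp.GradScaler
        loss_scaler.scale(loss).backward()
        loss_scaler.unscale_(optimizer)  # bring gradients back to their true scale
        clip_grad.clip_grad_norm_(module.parameters(), max_norm, norm_type)
        loss_scaler.step(optimizer)  # skips the update on inf/nan gradients
        loss_scaler.update()
    else:
        loss.backward()
        clip_grad.clip_grad_norm_(module.parameters(), max_norm, norm_type)
        optimizer.step()
```
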
102 changes: 98 additions & 4 deletions tests/test_models/test_basic_conditional_gan.py
@@ -17,12 +17,12 @@ def setup_class(cls):
generator=dict(
type='SNGANGenerator',
output_scale=32,
base_channels=256,
base_channels=16,
num_classes=10),
discriminator=dict(
type='ProjDiscriminator',
input_scale=32,
base_channels=128,
base_channels=16,
num_classes=10),
gan_loss=dict(type='GANLoss', gan_type='hinge'),
disc_auxiliary_loss=None,
@@ -34,14 +34,14 @@
type='SAGANGenerator',
output_scale=32,
num_classes=10,
base_channels=256,
base_channels=16,
attention_after_nth_block=2,
with_spectral_norm=True)
cls.disc_cfg = dict(
type='SAGANDiscriminator',
input_scale=32,
num_classes=10,
base_channels=128,
base_channels=16,
attention_after_nth_block=1)
cls.gan_loss = dict(type='GANLoss', gan_type='hinge')
cls.disc_auxiliary_loss = [
@@ -137,6 +137,53 @@ def test_default_dcgan_model_cpu(self):
assert isinstance(sagan.disc_auxiliary_losses, nn.ModuleList)
assert isinstance(sagan.gen_auxiliary_losses, nn.ModuleList)

# test grad clip
train_cfg = dict(grad_clip=dict(max_norm=10, norm_type=2))
sagan = BasicConditionalGAN(
self.generator_cfg,
self.disc_cfg,
self.gan_loss,
self.disc_auxiliary_loss,
train_cfg=train_cfg)
# test train step
data = torch.randn((2, 3, 32, 32))
data_input = dict(img=data, gt_label=label)
optimizer_g = torch.optim.SGD(sagan.generator.parameters(), lr=0.01)
optimizer_d = torch.optim.SGD(
sagan.discriminator.parameters(), lr=0.01)
optim_dict = dict(generator=optimizer_g, discriminator=optimizer_d)

model_outputs = sagan.train_step(data_input, optim_dict)
assert 'results' in model_outputs
assert 'log_vars' in model_outputs
assert model_outputs['num_samples'] == 2
assert 'grad_norm_generator' in model_outputs['log_vars']
assert 'grad_norm_discriminator' in model_outputs['log_vars']

# test specific grad clip
train_cfg = dict(
grad_clip=dict(discriminator=dict(max_norm=10, norm_type=2)))
sagan = BasicConditionalGAN(
self.generator_cfg,
self.disc_cfg,
self.gan_loss,
self.disc_auxiliary_loss,
train_cfg=train_cfg)
# test train step
data = torch.randn((2, 3, 32, 32))
data_input = dict(img=data, gt_label=label)
optimizer_g = torch.optim.SGD(sagan.generator.parameters(), lr=0.01)
optimizer_d = torch.optim.SGD(
sagan.discriminator.parameters(), lr=0.01)
optim_dict = dict(generator=optimizer_g, discriminator=optimizer_d)

model_outputs = sagan.train_step(data_input, optim_dict)
assert 'results' in model_outputs
assert 'log_vars' in model_outputs
assert model_outputs['num_samples'] == 2
assert 'grad_norm_generator' not in model_outputs['log_vars']
assert 'grad_norm_discriminator' in model_outputs['log_vars']

@pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
def test_default_dcgan_model_cuda(self):
sngan = build_model(self.default_config).cuda()
@@ -179,3 +226,50 @@ def test_default_dcgan_model_cuda(self):
data_input, optim_dict, running_status=dict(iteration=1))
assert 'loss_disc_fake' in model_outputs['log_vars']
assert 'loss_disc_fake_g' in model_outputs['log_vars']

# test grad clip
train_cfg = dict(grad_clip=dict(max_norm=10, norm_type=2))
sagan = BasicConditionalGAN(
self.generator_cfg,
self.disc_cfg,
self.gan_loss,
self.disc_auxiliary_loss,
train_cfg=train_cfg).cuda()
# test train step
data = torch.randn((2, 3, 32, 32)).cuda()
data_input = dict(img=data, gt_label=label)
optimizer_g = torch.optim.SGD(sagan.generator.parameters(), lr=0.01)
optimizer_d = torch.optim.SGD(
sagan.discriminator.parameters(), lr=0.01)
optim_dict = dict(generator=optimizer_g, discriminator=optimizer_d)

model_outputs = sagan.train_step(data_input, optim_dict)
assert 'results' in model_outputs
assert 'log_vars' in model_outputs
assert model_outputs['num_samples'] == 2
assert 'grad_norm_generator' in model_outputs['log_vars']
assert 'grad_norm_discriminator' in model_outputs['log_vars']

# test specific grad clip
train_cfg = dict(
grad_clip=dict(discriminator=dict(max_norm=10, norm_type=2)))
sagan = BasicConditionalGAN(
self.generator_cfg,
self.disc_cfg,
self.gan_loss,
self.disc_auxiliary_loss,
train_cfg=train_cfg).cuda()
# test train step
data = torch.randn((2, 3, 32, 32)).cuda()
data_input = dict(img=data, gt_label=label)
optimizer_g = torch.optim.SGD(sagan.generator.parameters(), lr=0.01)
optimizer_d = torch.optim.SGD(
sagan.discriminator.parameters(), lr=0.01)
optim_dict = dict(generator=optimizer_g, discriminator=optimizer_d)

model_outputs = sagan.train_step(data_input, optim_dict)
assert 'results' in model_outputs
assert 'log_vars' in model_outputs
assert model_outputs['num_samples'] == 2
assert 'grad_norm_generator' not in model_outputs['log_vars']
assert 'grad_norm_discriminator' in model_outputs['log_vars']