diff --git a/docs/en/tutorials/customize_runtime.md b/docs/en/tutorials/customize_runtime.md
index 7e863a82a..9b5d21746 100644
--- a/docs/en/tutorials/customize_runtime.md
+++ b/docs/en/tutorials/customize_runtime.md
@@ -106,14 +106,28 @@ The default optimizer constructor is implemented [here](https://github.com/open-
 Tricks not implemented by the optimizer should be implemented through optimizer constructor (e.g., set parameter-wise learning rates) or hooks. We list some common settings that could stabilize the training or accelerate the training. Feel free to create PR, issue for more settings.

 - __Use gradient clip to stabilize training__:
-  Some models need gradient clip to clip the gradients to stabilize the training process. An example is as below:
+  MMGeneration directly optimizes parameters in `train_step`, therefore we do not use the `OptimizerHook` provided by `MMCV`.
+  If you want to use gradient clipping to stabilize the training process, please add the config to `train_cfg`. An example is shown below:

   ```python
-  optimizer_config = dict(
-      _delete_=True, grad_clip=dict(max_norm=35, norm_type=2))
+  train_cfg = dict(grad_clip=dict(max_norm=35, norm_type=2))
   ```

-  If your config inherits the base config which already sets the `optimizer_config`, you might need `_delete_=True` to override the unnecessary settings. See the [config documentation](https://mmgeneration.readthedocs.io/en/latest/config.html) for more details.
+  If you want to use a different `max_norm` for the generator and the discriminator, you can refer to the following example:
+
+  ```python
+  train_cfg = dict(
+      grad_clip=dict(
+          generator=dict(max_norm=35, norm_type=2),
+          discriminator=dict(max_norm=10, norm_type=2)))
+  ```
+
+  If you only want to apply gradient clipping to a specific model, you can refer to the following example:
+
+  ```python
+  train_cfg = dict(
+      grad_clip=dict(discriminator=dict(max_norm=10, norm_type=2)))
+  ```

 - __Use momentum schedule to accelerate model convergence__:
   We support momentum scheduler to modify model's momentum according to learning rate, which could make the model converge in a faster way.
diff --git a/mmgen/apis/train.py b/mmgen/apis/train.py
index 685b34f31..29dfa3d32 100644
--- a/mmgen/apis/train.py
+++ b/mmgen/apis/train.py
@@ -5,7 +5,7 @@
 import mmcv
 import torch
 from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
-from mmcv.runner import HOOKS, IterBasedRunner, OptimizerHook, build_runner
+from mmcv.runner import HOOKS, IterBasedRunner, build_runner
 from mmcv.runner import set_random_seed as set_random_seed_mmcv
 from mmcv.utils import build_from_cfg

@@ -127,20 +127,14 @@ def train_model(model,
     runner.timestamp = timestamp

     # fp16 setting
-    fp16_cfg = cfg.get('fp16', None)
+    assert cfg.get('fp16', None) is None, 'Fp16 has not been supported.'

     # In GANs, we can directly optimize parameter in `train_step` function.
-    if cfg.get('optimizer_cfg', None) is None:
-        optimizer_config = None
-    elif fp16_cfg is not None:
-        raise NotImplementedError('Fp16 has not been supported.')
-        # optimizer_config = Fp16OptimizerHook(
-        #     **cfg.optimizer_config, **fp16_cfg, distributed=distributed)
-    # default to use OptimizerHook
-    elif distributed and 'type' not in cfg.optimizer_config:
-        optimizer_config = OptimizerHook(**cfg.optimizer_config)
-    else:
-        optimizer_config = cfg.optimizer_config
+    # Therefore, we do not support OptimizerHook.
+    assert cfg.get('optimizer_config', None) is None, (
+        'MMGen directly optimizes parameters in the `train_step` function. '
+        'If you want to apply gradient clipping, please add the config to '
+        '`train_cfg`')

     # update `out_dir` in ckpt hook
     if cfg.checkpoint_config is not None:
@@ -148,8 +142,8 @@ def train_model(model,
             cfg.work_dir, cfg.checkpoint_config.get('out_dir', 'ckpt'))

     # register hooks
-    runner.register_training_hooks(cfg.lr_config, optimizer_config,
-                                   cfg.checkpoint_config, cfg.log_config,
+    runner.register_training_hooks(cfg.lr_config, None, cfg.checkpoint_config,
+                                   cfg.log_config,
                                    cfg.get('momentum_config', None))

     # # DistSamplerSeedHook should be used with EpochBasedRunner
diff --git a/mmgen/models/gans/base_gan.py b/mmgen/models/gans/base_gan.py
index cd10c2d01..e0fece49f 100644
--- a/mmgen/models/gans/base_gan.py
+++ b/mmgen/models/gans/base_gan.py
@@ -5,6 +5,8 @@
 import torch
 import torch.distributed as dist
 import torch.nn as nn
+from mmcv.utils import is_list_of
+from torch.nn.utils import clip_grad


 class BaseGAN(nn.Module, metaclass=ABCMeta):
@@ -96,6 +98,35 @@ def _get_gen_loss(self, outputs_dict):

         return loss, log_var

+    def clip_grads(self, model, log_vars=None):
+        """Apply gradient clipping to the input model.
+
+        Args:
+            model (str): The name of the input model, e.g. 'generator' or
+                'discriminator'.
+            log_vars (dict, optional): The dict that contains variables to be
+                logged. Defaults to None.
+
+        Returns:
+            float | None: Total gradient norm of the model, or None if
+                gradient clipping is not applied to this model.
+        """
+        if not hasattr(self, 'grad_clip') or self.grad_clip is None:
+            return None
+        if is_list_of(list(self.grad_clip.values()), dict):
+            # use different grad clip configs for different models
+            if model not in self.grad_clip:
+                return None
+            clip_args = self.grad_clip[model]
+        else:
+            # use the same grad clip config for all models
+            clip_args = self.grad_clip
+        params = getattr(self, model).parameters()
+        params = list(
+            filter(lambda p: p.requires_grad and p.grad is not None, params))
+        if len(params) > 0:
+            total_norm = clip_grad.clip_grad_norm_(params, **clip_args).item()
+            if log_vars is not None:
+                log_vars[f'grad_norm_{model}'] = total_norm
+            return total_norm
+
     @abstractmethod
     def train_step(self, data, optimizer, ddp_reducer=None):
         """The iteration step during training.
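For readers skimming the patch, the following standalone sketch (not part of the diff; the helper name `resolve_clip_args` is invented for illustration) shows how the `grad_clip` entry documented above is interpreted by `BaseGAN.clip_grads`: a flat dict of clipping kwargs is shared by every submodel, while a dict keyed by submodel names applies only to the models it lists.

```python
from mmcv.utils import is_list_of


def resolve_clip_args(grad_clip, model):
    """Return the kwargs for clip_grad_norm_ used on ``model``, or None."""
    if grad_clip is None:
        return None
    if is_list_of(list(grad_clip.values()), dict):
        # per-model style, e.g. dict(discriminator=dict(max_norm=10, norm_type=2))
        return grad_clip.get(model, None)
    # global style, e.g. dict(max_norm=35, norm_type=2), shared by all submodels
    return grad_clip


# Shared config: both submodels receive the same clipping arguments.
assert resolve_clip_args(dict(max_norm=35, norm_type=2),
                         'generator') == dict(max_norm=35, norm_type=2)
# Per-model config: the generator is left unclipped.
assert resolve_clip_args(dict(discriminator=dict(max_norm=10, norm_type=2)),
                         'generator') is None
```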
diff --git a/mmgen/models/gans/basic_conditional_gan.py b/mmgen/models/gans/basic_conditional_gan.py
index e07154c44..88e838024 100644
--- a/mmgen/models/gans/basic_conditional_gan.py
+++ b/mmgen/models/gans/basic_conditional_gan.py
@@ -111,6 +111,7 @@ def _parse_train_cfg(self):
         if self.use_ema:
             # use deepcopy to guarantee the consistency
             self.generator_ema = deepcopy(self.generator)
+        self.grad_clip = self.train_cfg.get('grad_clip', None)

     def _parse_test_cfg(self):
         """Parsing test config and set some attributes for testing."""
@@ -231,10 +232,11 @@ def train_step(self,
         if (curr_iter + 1) % self.batch_accumulation_steps == 0:
             if loss_scaler:
                 loss_scaler.unscale_(optimizer['discriminator'])
-                # note that we do not contain clip_grad procedure
+                self.clip_grads('discriminator', log_vars_disc)
                 loss_scaler.step(optimizer['discriminator'])
                 # loss_scaler.update will be called in runner.train()
             else:
+                self.clip_grads('discriminator', log_vars_disc)
                 optimizer['discriminator'].step()

         # skip generator training if only train discriminator for current
@@ -298,9 +300,10 @@ def train_step(self,
         if loss_scaler:
             loss_scaler.unscale_(optimizer['generator'])
-            # note that we do not contain clip_grad procedure
+            self.clip_grads('generator', log_vars_g)
             loss_scaler.step(optimizer['generator'])
             # loss_scaler.update will be called in runner.train()
         else:
+            self.clip_grads('generator', log_vars_g)
             optimizer['generator'].step()

         log_vars = {}
diff --git a/mmgen/models/gans/static_unconditional_gan.py b/mmgen/models/gans/static_unconditional_gan.py
index 41b662082..477b9d8e7 100644
--- a/mmgen/models/gans/static_unconditional_gan.py
+++ b/mmgen/models/gans/static_unconditional_gan.py
@@ -99,6 +99,7 @@ def _parse_train_cfg(self):
             self.generator_ema = deepcopy(self.generator)

         self.real_img_key = self.train_cfg.get('real_img_key', 'real_img')
+        self.grad_clip = self.train_cfg.get('grad_clip', None)

     def _parse_test_cfg(self):
         """Parsing test config and set some attributes for testing."""
@@ -207,10 +208,11 @@ def train_step(self,

         if loss_scaler:
             loss_scaler.unscale_(optimizer['discriminator'])
-            # note that we do not contain clip_grad procedure
+            self.clip_grads('discriminator', log_vars_disc)
             loss_scaler.step(optimizer['discriminator'])
             # loss_scaler.update will be called in runner.train()
         else:
+            self.clip_grads('discriminator', log_vars_disc)
             optimizer['discriminator'].step()

         # skip generator training if only train discriminator for current
@@ -264,10 +266,11 @@

         if loss_scaler:
             loss_scaler.unscale_(optimizer['generator'])
-            # note that we do not contain clip_grad procedure
+            self.clip_grads('generator', log_vars_g)
             loss_scaler.step(optimizer['generator'])
             # loss_scaler.update will be called in runner.train()
         else:
+            self.clip_grads('generator', log_vars_g)
             optimizer['generator'].step()

         log_vars = {}
diff --git a/tests/test_models/test_basic_conditional_gan.py b/tests/test_models/test_basic_conditional_gan.py
index 723e65d24..e30c65962 100644
--- a/tests/test_models/test_basic_conditional_gan.py
+++ b/tests/test_models/test_basic_conditional_gan.py
@@ -17,12 +17,12 @@ def setup_class(cls):
             generator=dict(
                 type='SNGANGenerator',
                 output_scale=32,
-                base_channels=256,
+                base_channels=16,
                 num_classes=10),
             discriminator=dict(
                 type='ProjDiscriminator',
                 input_scale=32,
-                base_channels=128,
+                base_channels=16,
                 num_classes=10),
             gan_loss=dict(type='GANLoss', gan_type='hinge'),
             disc_auxiliary_loss=None,
@@ -34,14 +34,14 @@
             type='SAGANGenerator',
             output_scale=32,
             num_classes=10,
-            base_channels=256,
+            base_channels=16,
             attention_after_nth_block=2,
             with_spectral_norm=True)
         cls.disc_cfg = dict(
             type='SAGANDiscriminator',
             input_scale=32,
             num_classes=10,
-            base_channels=128,
+            base_channels=16,
             attention_after_nth_block=1)
         cls.gan_loss = dict(type='GANLoss', gan_type='hinge')
         cls.disc_auxiliary_loss = [
@@ -137,6 +137,53 @@ def test_default_dcgan_model_cpu(self):
         assert isinstance(sagan.disc_auxiliary_losses, nn.ModuleList)
         assert isinstance(sagan.gen_auxiliary_losses, nn.ModuleList)

+        # test grad clip
+        train_cfg = dict(grad_clip=dict(max_norm=10, norm_type=2))
+        sagan = BasicConditionalGAN(
+            self.generator_cfg,
+            self.disc_cfg,
+            self.gan_loss,
+            self.disc_auxiliary_loss,
+            train_cfg=train_cfg)
+        # test train step
+        data = torch.randn((2, 3, 32, 32))
+        data_input = dict(img=data, gt_label=label)
+        optimizer_g = torch.optim.SGD(sagan.generator.parameters(), lr=0.01)
+        optimizer_d = torch.optim.SGD(
+            sagan.discriminator.parameters(), lr=0.01)
+        optim_dict = dict(generator=optimizer_g, discriminator=optimizer_d)
+
+        model_outputs = sagan.train_step(data_input, optim_dict)
+        assert 'results' in model_outputs
+        assert 'log_vars' in model_outputs
+        assert model_outputs['num_samples'] == 2
+        assert 'grad_norm_generator' in model_outputs['log_vars']
+        assert 'grad_norm_discriminator' in model_outputs['log_vars']
+
+        # test specific grad clip
+        train_cfg = dict(
+            grad_clip=dict(discriminator=dict(max_norm=10, norm_type=2)))
+        sagan = BasicConditionalGAN(
+            self.generator_cfg,
+            self.disc_cfg,
+            self.gan_loss,
+            self.disc_auxiliary_loss,
+            train_cfg=train_cfg)
+        # test train step
+        data = torch.randn((2, 3, 32, 32))
+        data_input = dict(img=data, gt_label=label)
+        optimizer_g = torch.optim.SGD(sagan.generator.parameters(), lr=0.01)
+        optimizer_d = torch.optim.SGD(
+            sagan.discriminator.parameters(), lr=0.01)
+        optim_dict = dict(generator=optimizer_g, discriminator=optimizer_d)
+
+        model_outputs = sagan.train_step(data_input, optim_dict)
+        assert 'results' in model_outputs
+        assert 'log_vars' in model_outputs
+        assert model_outputs['num_samples'] == 2
+        assert 'grad_norm_generator' not in model_outputs['log_vars']
+        assert 'grad_norm_discriminator' in model_outputs['log_vars']
+
     @pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
     def test_default_dcgan_model_cuda(self):
         sngan = build_model(self.default_config).cuda()
@@ -179,3 +226,50 @@ def test_default_dcgan_model_cuda(self):
             data_input, optim_dict, running_status=dict(iteration=1))
         assert 'loss_disc_fake' in model_outputs['log_vars']
         assert 'loss_disc_fake_g' in model_outputs['log_vars']
+
+        # test grad clip
+        train_cfg = dict(grad_clip=dict(max_norm=10, norm_type=2))
+        sagan = BasicConditionalGAN(
+            self.generator_cfg,
+            self.disc_cfg,
+            self.gan_loss,
+            self.disc_auxiliary_loss,
+            train_cfg=train_cfg).cuda()
+        # test train step
+        data = torch.randn((2, 3, 32, 32)).cuda()
+        data_input = dict(img=data, gt_label=label)
+        optimizer_g = torch.optim.SGD(sagan.generator.parameters(), lr=0.01)
+        optimizer_d = torch.optim.SGD(
+            sagan.discriminator.parameters(), lr=0.01)
+        optim_dict = dict(generator=optimizer_g, discriminator=optimizer_d)
+
+        model_outputs = sagan.train_step(data_input, optim_dict)
+        assert 'results' in model_outputs
+        assert 'log_vars' in model_outputs
+        assert model_outputs['num_samples'] == 2
+        assert 'grad_norm_generator' in model_outputs['log_vars']
+        assert 'grad_norm_discriminator' in model_outputs['log_vars']
+
+        # test specific grad clip
+        train_cfg = dict(
+            grad_clip=dict(discriminator=dict(max_norm=10, norm_type=2)))
+        sagan = BasicConditionalGAN(
+            self.generator_cfg,
+            self.disc_cfg,
+            self.gan_loss,
+            self.disc_auxiliary_loss,
+            train_cfg=train_cfg).cuda()
+        # test train step
+        data = torch.randn((2, 3, 32, 32)).cuda()
+        data_input = dict(img=data, gt_label=label)
+        optimizer_g = torch.optim.SGD(sagan.generator.parameters(), lr=0.01)
+        optimizer_d = torch.optim.SGD(
+            sagan.discriminator.parameters(), lr=0.01)
+        optim_dict = dict(generator=optimizer_g, discriminator=optimizer_d)
+
+        model_outputs = sagan.train_step(data_input, optim_dict)
+        assert 'results' in model_outputs
+        assert 'log_vars' in model_outputs
+        assert model_outputs['num_samples'] == 2
+        assert 'grad_norm_generator' not in model_outputs['log_vars']
+        assert 'grad_norm_discriminator' in model_outputs['log_vars']
diff --git a/tests/test_models/test_static_unconditional_gan.py b/tests/test_models/test_static_unconditional_gan.py
index 22fa05edf..fb63efcac 100644
--- a/tests/test_models/test_static_unconditional_gan.py
+++ b/tests/test_models/test_static_unconditional_gan.py
@@ -127,6 +127,49 @@ def test_default_dcgan_model_cpu(self):
         assert isinstance(dcgan.disc_auxiliary_losses, nn.ModuleList)
         assert isinstance(dcgan.gen_auxiliary_losses, nn.ModuleList)

+        # test grad clip
+        train_cfg = dict(grad_clip=dict(max_norm=10, norm_type=2))
+        dcgan = StaticUnconditionalGAN(
+            self.generator_cfg,
+            self.disc_cfg,
+            self.gan_loss,
+            train_cfg=train_cfg)
+        data = torch.randn((2, 3, 16, 16))
+        data_input = dict(real_img=data)
+        optimizer_g = torch.optim.SGD(dcgan.generator.parameters(), lr=0.01)
+        optimizer_d = torch.optim.SGD(
+            dcgan.discriminator.parameters(), lr=0.01)
+        optim_dict = dict(generator=optimizer_g, discriminator=optimizer_d)
+
+        model_outputs = dcgan.train_step(data_input, optim_dict)
+        assert 'results' in model_outputs
+        assert 'log_vars' in model_outputs
+        assert model_outputs['num_samples'] == 2
+        assert 'grad_norm_generator' in model_outputs['log_vars']
+        assert 'grad_norm_discriminator' in model_outputs['log_vars']
+
+        # test model specific grad clip
+        train_cfg = dict(
+            grad_clip=dict(generator=dict(max_norm=10, norm_type=2)))
+        dcgan = StaticUnconditionalGAN(
+            self.generator_cfg,
+            self.disc_cfg,
+            self.gan_loss,
+            train_cfg=train_cfg)
+        data = torch.randn((2, 3, 16, 16))
+        data_input = dict(real_img=data)
+        optimizer_g = torch.optim.SGD(dcgan.generator.parameters(), lr=0.01)
+        optimizer_d = torch.optim.SGD(
+            dcgan.discriminator.parameters(), lr=0.01)
+        optim_dict = dict(generator=optimizer_g, discriminator=optimizer_d)
+
+        model_outputs = dcgan.train_step(data_input, optim_dict)
+        assert 'results' in model_outputs
+        assert 'log_vars' in model_outputs
+        assert model_outputs['num_samples'] == 2
+        assert 'grad_norm_generator' in model_outputs['log_vars']
+        assert 'grad_norm_discriminator' not in model_outputs['log_vars']
+
     @pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
     def test_default_dcgan_model_cuda(self):
         dcgan = build_model(self.default_config).cuda()
@@ -168,3 +211,46 @@ def test_default_dcgan_model_cuda(self):
             data_input, optim_dict, running_status=dict(iteration=1))
         assert 'loss_disc_fake' in model_outputs['log_vars']
         assert 'loss_disc_fake_g' in model_outputs['log_vars']
+
+        # test grad clip
+        train_cfg = dict(grad_clip=dict(max_norm=10, norm_type=2))
+        dcgan = StaticUnconditionalGAN(
+            self.generator_cfg,
+            self.disc_cfg,
+            self.gan_loss,
+            train_cfg=train_cfg).cuda()
+        data = torch.randn((2, 3, 16, 16)).cuda()
+        data_input = dict(real_img=data)
+        optimizer_g = torch.optim.SGD(dcgan.generator.parameters(), lr=0.01)
+        optimizer_d = torch.optim.SGD(
+            dcgan.discriminator.parameters(), lr=0.01)
+        optim_dict = dict(generator=optimizer_g, discriminator=optimizer_d)
+
+        model_outputs = dcgan.train_step(data_input, optim_dict)
+        assert 'results' in model_outputs
+        assert 'log_vars' in model_outputs
+        assert model_outputs['num_samples'] == 2
+        assert 'grad_norm_generator' in model_outputs['log_vars']
+        assert 'grad_norm_discriminator' in model_outputs['log_vars']
+
+        # test model specific grad clip
+        train_cfg = dict(
+            grad_clip=dict(generator=dict(max_norm=10, norm_type=2)))
+        dcgan = StaticUnconditionalGAN(
+            self.generator_cfg,
+            self.disc_cfg,
+            self.gan_loss,
+            train_cfg=train_cfg).cuda()
+        data = torch.randn((2, 3, 16, 16)).cuda()
+        data_input = dict(real_img=data)
+        optimizer_g = torch.optim.SGD(dcgan.generator.parameters(), lr=0.01)
+        optimizer_d = torch.optim.SGD(
+            dcgan.discriminator.parameters(), lr=0.01)
+        optim_dict = dict(generator=optimizer_g, discriminator=optimizer_d)
+
+        model_outputs = dcgan.train_step(data_input, optim_dict)
+        assert 'results' in model_outputs
+        assert 'log_vars' in model_outputs
+        assert model_outputs['num_samples'] == 2
+        assert 'grad_norm_generator' in model_outputs['log_vars']
+        assert 'grad_norm_discriminator' not in model_outputs['log_vars']
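As a closing note on migration: configs that previously set `optimizer_config` for MMCV's `OptimizerHook` will now trip the assertion added in `mmgen/apis/train.py`, so the clipping arguments move into `train_cfg` instead. The sketch below illustrates the change in a user config; the `_base_` path is hypothetical and the clip keys simply mirror `torch.nn.utils.clip_grad_norm_`.

```python
# Hypothetical user config illustrating the migration; only the grad-clip
# related keys are shown.
_base_ = ['./my_gan_base_config.py']  # placeholder base config path

# Old style (now rejected by mmgen/apis/train.py):
# optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))

# New style: move the clip_grad_norm_ arguments into `train_cfg`.
optimizer_config = None  # or omit the key entirely
train_cfg = dict(
    grad_clip=dict(discriminator=dict(max_norm=10, norm_type=2)))
```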