diff --git a/configs/biggan/biggan_ajbrock-sn_imagenet1k_128x128_b32x8_1500k.py b/configs/biggan/biggan_ajbrock-sn_imagenet1k_128x128_b32x8_1500k.py index 9d09137f2..44eda7199 100644 --- a/configs/biggan/biggan_ajbrock-sn_imagenet1k_128x128_b32x8_1500k.py +++ b/configs/biggan/biggan_ajbrock-sn_imagenet1k_128x128_b32x8_1500k.py @@ -37,7 +37,8 @@ pass_training_status=True) # Note set your inception_pkl's path -inception_pkl = 'work_dirs/inception_pkl/imagenet.pkl' +inception_pkl = ('https://download.openmmlab.com/mmgen/evaluation/' + 'fid_inception_pkl/imagenet.pkl') evaluation = dict( type='GenerativeEvalHook', interval=10000, diff --git a/configs/biggan/biggan_torch-sn_imagenet1k_128x128_b32x8_1500k.py b/configs/biggan/biggan_torch-sn_imagenet1k_128x128_b32x8_1500k.py index bafbeeea9..78404617a 100644 --- a/configs/biggan/biggan_torch-sn_imagenet1k_128x128_b32x8_1500k.py +++ b/configs/biggan/biggan_torch-sn_imagenet1k_128x128_b32x8_1500k.py @@ -40,7 +40,8 @@ pass_training_status=True) # Note set your inception_pkl's path -inception_pkl = 'work_dirs/inception_pkl/imagenet.pkl' +inception_pkl = ('https://download.openmmlab.com/mmgen/evaluation/' + 'fid_inception_pkl/imagenet.pkl') evaluation = dict( type='GenerativeEvalHook', interval=10000, diff --git a/configs/improved_ddpm/README.md b/configs/improved_ddpm/README.md index f5cd098ec..3187cd99b 100644 --- a/configs/improved_ddpm/README.md +++ b/configs/improved_ddpm/README.md @@ -49,7 +49,7 @@ Denoising diffusion probabilistic models (DDPM) are a class of generative models For FID evaluation, we follow the pipeline of [BigGAN](https://github.com/ajbrock/BigGAN-PyTorch/blob/98459431a5d618d644d54cd1e9fceb1e5045648d/calculate_inception_moments.py#L52), where the whole training set is adopted to extract inception statistics, and Pytorch Studio GAN uses 50000 randomly selected samples. Besides, we also use [Tero's Inception](https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt) for feature extraction. -You can download the preprocessed inception state by the following url: [CIFAR10](https://download.openmmlab.com/mmgen/evaluation/fid_inception_pkl/cifar10.pkl) and [ImageNet1k-64x64](https://download.openmmlab.com/mmgen/evaluation/fid_inception_pkl/imagenet_64x64.pkl). +MMGen will automatically download the preprocessed inception state by the following url: [CIFAR10](https://download.openmmlab.com/mmgen/evaluation/fid_inception_pkl/cifar10.pkl) and [ImageNet1k-64x64](https://download.openmmlab.com/mmgen/evaluation/fid_inception_pkl/imagenet_64x64.pkl). You can use following commands to extract those inception states by yourself. diff --git a/configs/improved_ddpm/ddpm_cosine_hybird_timestep-4k_drop0.3_cifar10_32x32_b8x16_500k.py b/configs/improved_ddpm/ddpm_cosine_hybird_timestep-4k_drop0.3_cifar10_32x32_b8x16_500k.py index 9f1b125cb..295294762 100644 --- a/configs/improved_ddpm/ddpm_cosine_hybird_timestep-4k_drop0.3_cifar10_32x32_b8x16_500k.py +++ b/configs/improved_ddpm/ddpm_cosine_hybird_timestep-4k_drop0.3_cifar10_32x32_b8x16_500k.py @@ -36,7 +36,8 @@ is_dynamic_ddp=False, # Note that this flag should be False. pass_training_status=True) -inception_pkl = './work_dirs/inception_pkl/cifar10.pkl' +inception_pkl = ('https://download.openmmlab.com/mmgen/evaluation/' + 'fid_inception_pkl/cifar10.pkl') metrics = dict( fid50k=dict( type='FID', diff --git a/configs/improved_ddpm/ddpm_cosine_hybird_timestep-4k_drop0.3_imagenet1k_64x64_b8x16_1500k.py b/configs/improved_ddpm/ddpm_cosine_hybird_timestep-4k_drop0.3_imagenet1k_64x64_b8x16_1500k.py index 3eeb7df0a..5868aa7f3 100644 --- a/configs/improved_ddpm/ddpm_cosine_hybird_timestep-4k_drop0.3_imagenet1k_64x64_b8x16_1500k.py +++ b/configs/improved_ddpm/ddpm_cosine_hybird_timestep-4k_drop0.3_imagenet1k_64x64_b8x16_1500k.py @@ -39,7 +39,8 @@ is_dynamic_ddp=False, # Note that this flag should be False. pass_training_status=True) -inception_pkl = './work_dirs/inception_pkl/imagenet_64x64.pkl' +inception_pkl = ('https://download.openmmlab.com/mmgen/evaluation/' + 'fid_inception_pkl/imagenet_64x64.pkl') metrics = dict( fid50k=dict( type='FID', diff --git a/configs/improved_ddpm/ddpm_cosine_hybird_timestep-4k_imagenet1k_64x64_b8x16_1500k.py b/configs/improved_ddpm/ddpm_cosine_hybird_timestep-4k_imagenet1k_64x64_b8x16_1500k.py index 3bdc54d27..eabb00a2f 100644 --- a/configs/improved_ddpm/ddpm_cosine_hybird_timestep-4k_imagenet1k_64x64_b8x16_1500k.py +++ b/configs/improved_ddpm/ddpm_cosine_hybird_timestep-4k_imagenet1k_64x64_b8x16_1500k.py @@ -36,7 +36,8 @@ is_dynamic_ddp=False, # Note that this flag should be False. pass_training_status=True) -inception_pkl = './work_dirs/inception_pkl/imagenet_64x64.pkl' +inception_pkl = ('https://download.openmmlab.com/mmgen/evaluation/' + 'fid_inception_pkl/imagenet_64x64.pkl') metrics = dict( fid50k=dict( type='FID', diff --git a/configs/sagan/README.md b/configs/sagan/README.md index 2238d8033..ce25f89bd 100644 --- a/configs/sagan/README.md +++ b/configs/sagan/README.md @@ -71,7 +71,7 @@ For IS metric, our implementation is different from PyTorch-Studio GAN in the fo For FID evaluation, we follow the pipeline of [BigGAN](https://github.com/ajbrock/BigGAN-PyTorch/blob/98459431a5d618d644d54cd1e9fceb1e5045648d/calculate_inception_moments.py#L52), where the whole training set is adopted to extract inception statistics, and Pytorch Studio GAN uses 50000 randomly selected samples. Besides, we also use [Tero's Inception](https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt) for feature extraction. -You can download the preprocessed inception state by the following url: [CIFAR10](https://download.openmmlab.com/mmgen/evaluation/fid_inception_pkl/cifar10.pkl) and [ImageNet1k](https://download.openmmlab.com/mmgen/evaluation/fid_inception_pkl/imagenet.pkl). +MMGen will automatically download the preprocessed inception state by the following url: [CIFAR10](https://download.openmmlab.com/mmgen/evaluation/fid_inception_pkl/cifar10.pkl) and [ImageNet1k](https://download.openmmlab.com/mmgen/evaluation/fid_inception_pkl/imagenet.pkl). You can use following commands to extract those inception states by yourself. ``` diff --git a/configs/sagan/sagan_128_woReLUinplace_Glr-1e-4_Dlr-4e-4_ndisc1_imagenet1k_b64x4.py b/configs/sagan/sagan_128_woReLUinplace_Glr-1e-4_Dlr-4e-4_ndisc1_imagenet1k_b64x4.py index ea431e71a..430320548 100644 --- a/configs/sagan/sagan_128_woReLUinplace_Glr-1e-4_Dlr-4e-4_ndisc1_imagenet1k_b64x4.py +++ b/configs/sagan/sagan_128_woReLUinplace_Glr-1e-4_Dlr-4e-4_ndisc1_imagenet1k_b64x4.py @@ -18,7 +18,8 @@ interval=1000) ] -inception_pkl = './work_dirs/inception_pkl/imagenet.pkl' +inception_pkl = ('https://download.openmmlab.com/mmgen/evaluation/' + 'fid_inception_pkl/imagenet.pkl') evaluation = dict( type='GenerativeEvalHook', diff --git a/configs/sagan/sagan_128_woReLUinplace_noaug_bigGAN_Glr-1e-4_Dlr-4e-4_ndisc1_imagenet1k_b32x8.py b/configs/sagan/sagan_128_woReLUinplace_noaug_bigGAN_Glr-1e-4_Dlr-4e-4_ndisc1_imagenet1k_b32x8.py index 5e1088819..33c7234db 100644 --- a/configs/sagan/sagan_128_woReLUinplace_noaug_bigGAN_Glr-1e-4_Dlr-4e-4_ndisc1_imagenet1k_b32x8.py +++ b/configs/sagan/sagan_128_woReLUinplace_noaug_bigGAN_Glr-1e-4_Dlr-4e-4_ndisc1_imagenet1k_b32x8.py @@ -45,7 +45,8 @@ priority='VERY_HIGH') ] -inception_pkl = './work_dirs/inception_pkl/imagenet.pkl' +inception_pkl = ('https://download.openmmlab.com/mmgen/evaluation/' + 'fid_inception_pkl/imagenet.pkl') evaluation = dict( type='GenerativeEvalHook', diff --git a/configs/sagan/sagan_32_wReLUinplace_lr-2e-4_ndisc5_cifar10_b64x1.py b/configs/sagan/sagan_32_wReLUinplace_lr-2e-4_ndisc5_cifar10_b64x1.py index 98abcf45e..5c04d9525 100644 --- a/configs/sagan/sagan_32_wReLUinplace_lr-2e-4_ndisc5_cifar10_b64x1.py +++ b/configs/sagan/sagan_32_wReLUinplace_lr-2e-4_ndisc5_cifar10_b64x1.py @@ -25,7 +25,8 @@ interval=1000) ] -inception_pkl = './work_dirs/inception_pkl/cifar10.pkl' +inception_pkl = ('https://download.openmmlab.com/mmgen/evaluation/' + 'fid_inception_pkl/cifar10.pkl') evaluation = dict( type='GenerativeEvalHook', diff --git a/configs/sagan/sagan_32_woReLUinplace_lr-2e-4_ndisc5_cifar10_b64x1.py b/configs/sagan/sagan_32_woReLUinplace_lr-2e-4_ndisc5_cifar10_b64x1.py index 395ca8ce9..7cb487208 100644 --- a/configs/sagan/sagan_32_woReLUinplace_lr-2e-4_ndisc5_cifar10_b64x1.py +++ b/configs/sagan/sagan_32_woReLUinplace_lr-2e-4_ndisc5_cifar10_b64x1.py @@ -18,7 +18,8 @@ interval=1000) ] -inception_pkl = './work_dirs/inception_pkl/cifar10.pkl' +inception_pkl = ('https://download.openmmlab.com/mmgen/evaluation/' + 'fid_inception_pkl/cifar10.pkl') evaluation = dict( type='GenerativeEvalHook', diff --git a/configs/sngan_proj/README.md b/configs/sngan_proj/README.md index 09a4f93a6..3660b3e65 100644 --- a/configs/sngan_proj/README.md +++ b/configs/sngan_proj/README.md @@ -70,7 +70,7 @@ For IS metric, our implementation is different from PyTorch-Studio GAN in the fo For FID evaluation, we follow the pipeline of [BigGAN](https://github.com/ajbrock/BigGAN-PyTorch/blob/98459431a5d618d644d54cd1e9fceb1e5045648d/calculate_inception_moments.py#L52), where the whole training set is adopted to extract inception statistics, and Pytorch Studio GAN uses 50000 randomly selected samples. Besides, we also use [Tero's Inception](https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt) for feature extraction. -You can download the preprocessed inception state by the following url: [CIFAR10](https://download.openmmlab.com/mmgen/evaluation/fid_inception_pkl/cifar10.pkl) and [ImageNet1k](https://download.openmmlab.com/mmgen/evaluation/fid_inception_pkl/imagenet.pkl). +MMGen will automatically download the preprocessed inception state by the following url: [CIFAR10](https://download.openmmlab.com/mmgen/evaluation/fid_inception_pkl/cifar10.pkl) and [ImageNet1k](https://download.openmmlab.com/mmgen/evaluation/fid_inception_pkl/imagenet.pkl). You can use following commands to extract those inception states by yourself. ``` diff --git a/configs/sngan_proj/sngan_proj_128_wReLUinplace_Glr-2e-4_Dlr-5e-5_ndisc5_imagenet1k_b128x2.py b/configs/sngan_proj/sngan_proj_128_wReLUinplace_Glr-2e-4_Dlr-5e-5_ndisc5_imagenet1k_b128x2.py index 90078d2b8..e51141095 100644 --- a/configs/sngan_proj/sngan_proj_128_wReLUinplace_Glr-2e-4_Dlr-5e-5_ndisc5_imagenet1k_b128x2.py +++ b/configs/sngan_proj/sngan_proj_128_wReLUinplace_Glr-2e-4_Dlr-5e-5_ndisc5_imagenet1k_b128x2.py @@ -30,7 +30,8 @@ log_config = dict(interval=100, hooks=[dict(type='TextLoggerHook')]) -inception_pkl = './work_dirs/inception_pkl/imagenet.pkl' +inception_pkl = ('https://download.openmmlab.com/mmgen/evaluation/' + 'fid_inception_pkl/imagenet.pkl') evaluation = dict( type='GenerativeEvalHook', diff --git a/configs/sngan_proj/sngan_proj_128_woReLUinplace_Glr-2e-4_Dlr-5e-5_ndisc5_imagenet1k_b128x2.py b/configs/sngan_proj/sngan_proj_128_woReLUinplace_Glr-2e-4_Dlr-5e-5_ndisc5_imagenet1k_b128x2.py index 96f86a3cc..5bfed2de3 100644 --- a/configs/sngan_proj/sngan_proj_128_woReLUinplace_Glr-2e-4_Dlr-5e-5_ndisc5_imagenet1k_b128x2.py +++ b/configs/sngan_proj/sngan_proj_128_woReLUinplace_Glr-2e-4_Dlr-5e-5_ndisc5_imagenet1k_b128x2.py @@ -24,7 +24,8 @@ log_config = dict(interval=100, hooks=[dict(type='TextLoggerHook')]) -inception_pkl = './work_dirs/inception_pkl/imagenet.pkl' +inception_pkl = ('https://download.openmmlab.com/mmgen/evaluation/' + 'fid_inception_pkl/imagenet.pkl') evaluation = dict( type='GenerativeEvalHook', diff --git a/configs/sngan_proj/sngan_proj_32_wReLUinplace_lr-2e-4_ndisc5_cifar10_b64x1.py b/configs/sngan_proj/sngan_proj_32_wReLUinplace_lr-2e-4_ndisc5_cifar10_b64x1.py index 4d6bd2d91..699d64f0f 100644 --- a/configs/sngan_proj/sngan_proj_32_wReLUinplace_lr-2e-4_ndisc5_cifar10_b64x1.py +++ b/configs/sngan_proj/sngan_proj_32_wReLUinplace_lr-2e-4_ndisc5_cifar10_b64x1.py @@ -28,7 +28,8 @@ interval=5000) ] -inception_pkl = './work_dirs/inception_pkl/cifar10.pkl' +inception_pkl = ('https://download.openmmlab.com/mmgen/evaluation/' + 'fid_inception_pkl/cifar10.pkl') evaluation = dict( type='GenerativeEvalHook', diff --git a/configs/sngan_proj/sngan_proj_32_woReLUinplace_lr-2e-4_ndisc5_cifar10_b64x1.py b/configs/sngan_proj/sngan_proj_32_woReLUinplace_lr-2e-4_ndisc5_cifar10_b64x1.py index 8db95fea7..86da61a49 100644 --- a/configs/sngan_proj/sngan_proj_32_woReLUinplace_lr-2e-4_ndisc5_cifar10_b64x1.py +++ b/configs/sngan_proj/sngan_proj_32_woReLUinplace_lr-2e-4_ndisc5_cifar10_b64x1.py @@ -22,7 +22,8 @@ interval=5000) ] -inception_pkl = './work_dirs/inception_pkl/cifar10.pkl' +inception_pkl = ('https://download.openmmlab.com/mmgen/evaluation/' + 'fid_inception_pkl/cifar10.pkl') evaluation = dict( type='GenerativeEvalHook', diff --git a/mmgen/core/evaluation/metrics.py b/mmgen/core/evaluation/metrics.py index ba7802bda..647ba76da 100644 --- a/mmgen/core/evaluation/metrics.py +++ b/mmgen/core/evaluation/metrics.py @@ -57,6 +57,11 @@ def load_inception(inception_args, metric): inceptoin_type = _inception_args.pop('type', None) if torch.__version__ < '1.6.0': + # reset inception_args for FID (Inception for IS do not use + # inception_args) + if metric == 'FID': + _inception_args = dict(normalize_input=False) + mmcv.print_log( 'Current Pytorch Version not support script module, load ' 'Inception Model from torch model zoo. If you want to use ' @@ -118,7 +123,7 @@ def _load_inception_torch(inception_args, metric): assert metric in ['FID', 'IS'] if metric == 'FID': inception_model = InceptionV3([3], **inception_args) - elif metric == 'IS': + else: # metric == 'IS' inception_model = inception_v3(pretrained=True, transform_input=False) mmcv.print_log( 'Load Inception V3 Network from Pytorch Model Zoo ' @@ -505,15 +510,23 @@ def __init__(self, def prepare(self): """Prepare for evaluating models with this metric.""" # if `inception_pkl` is provided, read mean and cov stat - if self.inception_pkl is not None and mmcv.is_filepath( - self.inception_pkl): - with open(self.inception_pkl, 'rb') as f: + if self.inception_pkl is not None: + if self.inception_pkl[:4] == 'http': + inception_path = download_from_url(self.inception_pkl) + elif mmcv.is_filepath(self.inception_pkl): + inception_path = self.inception_pkl + else: + raise FileNotFoundError('Cannot load inception pkl from ' + f'{self.inception_pkl}') + + # load from path + with open(inception_path, 'rb') as f: reference = pickle.load(f) self.real_mean = reference['mean'] self.real_cov = reference['cov'] - mmcv.print_log( - f'Load reference inception pkl from {self.inception_pkl}', - 'mmgen') + mmcv.print_log( + f'Load reference inception pkl from {self.inception_pkl}', + 'mmgen') self.num_real_feeded = self.num_images @torch.no_grad() diff --git a/tests/test_cores/test_metrics.py b/tests/test_cores/test_metrics.py index a5aa46a62..2e5d27cc3 100644 --- a/tests/test_cores/test_metrics.py +++ b/tests/test_cores/test_metrics.py @@ -11,32 +11,45 @@ from mmgen.models import build_model from mmgen.models.architectures import InceptionV3 -# def test_inception_download(): -# from mmgen.core.evaluation.metrics import load_inception -# from mmgen.utils import MMGEN_CACHE_DIR -# args_FID_pytorch = dict(type='pytorch', normalize_input=False) -# args_FID_tero = dict(type='StyleGAN', inception_path='') -# args_IS_pytorch = dict(type='pytorch') -# args_IS_tero = dict( -# type='StyleGAN', -# inception_path=osp.join(MMGEN_CACHE_DIR, 'inception-2015-12-05.pt')) +def test_inception_download(): + from mmgen.core.evaluation.metrics import load_inception + from mmgen.utils import MMGEN_CACHE_DIR + + args_FID_pytorch = dict(type='pytorch', normalize_input=False) + args_FID_tero = dict(type='StyleGAN') + args_IS_pytorch = dict(type='pytorch') + args_IS_tero = dict( + type='StyleGAN', + inception_path=osp.join(MMGEN_CACHE_DIR, 'inception-2015-12-05.pt')) + + arg_list = [args_FID_pytorch, args_FID_tero, args_IS_pytorch, args_IS_tero] + metric_list = ['FID', 'FID', 'IS', 'IS'] + tar_style_list = ['pytorch', 'StyleGAN', 'pytorch', 'StyleGAN'] + + for inception_args, metric, tar_style in zip(arg_list, metric_list, + tar_style_list): + model, style = load_inception(inception_args, metric) + + if torch.__version__ < '1.6.0': + print(inception_args, metric, tar_style) + assert style == 'pytorch' + else: + assert style == tar_style + + args_empty = '' + with pytest.raises(TypeError): + load_inception(args_empty, 'FID') + + # pt lower than this version cannot load Tero's inception and direct use + # torch ones, only test this for pt >= 1.6 + if torch.__version__ >= '1.6.0': + args_error_path = dict(type='StyleGAN', inception_path='error-path') + with pytest.raises(RuntimeError): + load_inception(args_error_path, 'FID') -# tar_style_list = ['pytorch', 'StyleGAN', 'pytorch', 'StyleGAN'] - -# for inception_args, metric, tar_style in zip( -# [args_FID_pytorch, args_FID_tero, args_IS_pytorch, args_IS_tero], -# ['FID', 'FID', 'IS', 'IS'], tar_style_list): -# model, style = load_inception(inception_args, metric) -# assert style == tar_style - -# args_empty = '' -# with pytest.raises(TypeError) as exc_info: -# load_inception(args_empty, 'FID') - -# args_error_path = dict(type='StyleGAN', inception_path='error-path') -# with pytest.raises(RuntimeError) as exc_info: -# load_inception(args_error_path, 'FID') + with pytest.raises(AssertionError): + load_inception(dict(type='pytorch', normalize_input=False), 'PPL') def test_swd_metric(): @@ -144,21 +157,52 @@ def test_fid(self): assert fid_score > 0 and mean > 0 and cov > 0 # To reduce the size of git repo, we remove the following test - # fid = FID( - # 3, - # inception_args=dict( - # normalize_input=False, load_fid_inception=False), - # inception_pkl=osp.join( - # osp.dirname(__file__), '..', 'data', 'test_dirty.pkl')) - # assert fid.num_real_feeded == 3 - # for b in self.reals: - # fid.feed(b, 'reals') - - # for b in self.fakes: - # fid.feed(b, 'fakes') - - # fid_score, mean, cov = fid.summary() - # assert fid_score > 0 and mean > 0 and cov > 0 + + inception_pkl = ('https://download.openmmlab.com/mmgen/evaluation/' + 'fid_inception_pkl/cifar10.pkl') + fid = FID( + 3, + inception_args=dict( + normalize_input=False, load_fid_inception=False), + inception_pkl=inception_pkl) + fid.prepare() + assert fid.num_real_feeded == 3 + for b in self.reals: + fid.feed(b, 'reals') + + for b in self.fakes: + fid.feed(b, 'fakes') + + fid_score, mean, cov = fid.summary() + assert fid_score > 0 and mean > 0 and cov > 0 + + # test load + inception_pkl = osp.expanduser('~/.cache/openmmlab/mmgen/cifar10.pkl') + fid = FID( + 3, + inception_args=dict( + normalize_input=False, load_fid_inception=False), + inception_pkl=inception_pkl) + fid.prepare() + assert fid.num_real_feeded == 3 + for b in self.reals: + fid.feed(b, 'reals') + + for b in self.fakes: + fid.feed(b, 'fakes') + + fid_score, mean, cov = fid.summary() + assert fid_score > 0 and mean > 0 and cov > 0 + + # test raise load error + inception_pkl = 'wrong_path' + fid = FID( + 3, + inception_args=dict( + normalize_input=False, load_fid_inception=False), + inception_pkl=inception_pkl) + with pytest.raises(FileNotFoundError): + fid.prepare() class TestPR: