Skip to content

Commit

Permalink
do not recalculate real images for fid score at every interval, by de…
Browse files Browse the repository at this point in the history
…fault. offer means to clear fid real images cache, and be able to set the number of images for fid calculation
  • Loading branch information
lucidrains committed Jan 15, 2021
1 parent 8643aab commit eba115b
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 16 deletions.
4 changes: 4 additions & 0 deletions lightweight_gan/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,8 @@ def train_from_folder(
dataset_aug_prob = 0.,
multi_gpus = False,
calculate_fid_every = None,
calculate_fid_num_images = 12800,
clear_fid_cache = False,
seed = 42,
amp = False
):
Expand Down Expand Up @@ -134,6 +136,8 @@ def train_from_folder(
aug_types = cast_list(aug_types),
dataset_aug_prob = dataset_aug_prob,
calculate_fid_every = calculate_fid_every,
calculate_fid_num_images = calculate_fid_num_images,
clear_fid_cache = clear_fid_cache,
amp = amp
)

Expand Down
41 changes: 26 additions & 15 deletions lightweight_gan/lightweight_gan.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@

NUM_CORES = multiprocessing.cpu_count()
EXTS = ['jpg', 'jpeg', 'png']
CALC_FID_NUM_IMAGES = 12800

# helpers

Expand Down Expand Up @@ -742,6 +741,8 @@ def __init__(
aug_types = ['translation', 'cutout'],
dataset_aug_prob = 0.,
calculate_fid_every = None,
calculate_fid_num_images = 12800,
clear_fid_cache = False,
is_ddp = False,
rank = 0,
world_size = 1,
Expand All @@ -759,6 +760,8 @@ def __init__(
self.base_dir = base_dir
self.results_dir = base_dir / results_dir
self.models_dir = base_dir / models_dir
self.fid_dir = base_dir / 'fid' / name

self.config_path = self.models_dir / name / '.config.json'

assert is_power_of_two(image_size), 'image size must be a power of 2 (64, 128, 256, 512, 1024)'
Expand Down Expand Up @@ -808,6 +811,8 @@ def __init__(
self.dataset_aug_prob = dataset_aug_prob

self.calculate_fid_every = calculate_fid_every
self.calculate_fid_num_images = calculate_fid_num_images
self.clear_fid_cache = clear_fid_cache

self.is_ddp = is_ddp
self.is_main = rank == 0
Expand Down Expand Up @@ -1058,7 +1063,7 @@ def train(self):
self.evaluate(floor(self.steps / self.evaluate_every), num_image_tiles = self.num_image_tiles)

if exists(self.calculate_fid_every) and self.steps % self.calculate_fid_every == 0 and self.steps != 0:
num_batches = math.ceil(CALC_FID_NUM_IMAGES / self.batch_size)
num_batches = math.ceil(self.calculate_fid_num_images / self.batch_size)
fid = self.calculate_fid(num_batches)
self.last_fid = fid

Expand Down Expand Up @@ -1126,21 +1131,25 @@ def calculate_fid(self, num_batches):
from pytorch_fid import fid_score
torch.cuda.empty_cache()

real_path = str(self.results_dir / self.name / 'fid_real') + '/'
fake_path = str(self.results_dir / self.name / 'fid_fake') + '/'
real_path = self.fid_dir / 'real'
fake_path = self.fid_dir / 'fake'

# remove any existing files used for fid calculation and recreate directories
rmtree(real_path, ignore_errors=True)
rmtree(fake_path, ignore_errors=True)
os.makedirs(real_path)
os.makedirs(fake_path)
if not real_path.exists() or self.clear_fid_cache:
rmtree(real_path, ignore_errors=True)
os.makedirs(real_path)

for batch_num in tqdm(range(num_batches), desc='calculating FID - saving reals'):
real_batch = next(self.loader)
for k in range(real_batch.size(0)):
torchvision.utils.save_image(real_batch[k, :, :, :], real_path + '{}.png'.format(k + batch_num * self.batch_size))
for batch_num in tqdm(range(num_batches), desc='calculating FID - saving reals'):
real_batch = next(self.loader)
for k, image in enumerate(real_batch.unbind(0)):
ind = k + batch_num * self.batch_size
torchvision.utils.save_image(image, real_path / f'{ind}.png')

# generate a bunch of fake images in results / name / fid_fake

rmtree(fake_path, ignore_errors=True)
os.makedirs(fake_path)

self.GAN.eval()
ext = self.image_extension

Expand All @@ -1154,10 +1163,11 @@ def calculate_fid(self, num_batches):
# moving averages
generated_images = self.generate_truncated(self.GAN.GE, latents)

for j in range(generated_images.size(0)):
torchvision.utils.save_image(generated_images[j, :, :, :], str(Path(fake_path) / f'{str(j + batch_num * self.batch_size)}-ema.{ext}'))
for j, image in enumerate(generated_images.unbind(0)):
ind = j + batch_num * self.batch_size
torchvision.utils.save_image(image, str(fake_path / f'{str(ind)}-ema.{ext}'))

return fid_score.calculate_fid_given_paths([real_path, fake_path], 256, latents.device, 2048)
return fid_score.calculate_fid_given_paths([str(real_path), str(fake_path)], 256, latents.device, 2048)

@torch.no_grad()
def generate_truncated(self, G, style, trunc_psi = 0.75, num_image_tiles = 8):
Expand Down Expand Up @@ -1224,6 +1234,7 @@ def init_folders(self):
def clear(self):
rmtree(str(self.models_dir / self.name), True)
rmtree(str(self.results_dir / self.name), True)
rmtree(str(self.fid_dir), True)
rmtree(str(self.config_path), True)
self.init_folders()

Expand Down
2 changes: 1 addition & 1 deletion lightweight_gan/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '0.16.3'
__version__ = '0.16.4'

0 comments on commit eba115b

Please sign in to comment.