diff --git a/lightweight_gan/cli.py b/lightweight_gan/cli.py index 74c2207..f7f5534 100644 --- a/lightweight_gan/cli.py +++ b/lightweight_gan/cli.py @@ -106,6 +106,8 @@ def train_from_folder( dataset_aug_prob = 0., multi_gpus = False, calculate_fid_every = None, + calculate_fid_num_images = 12800, + clear_fid_cache = False, seed = 42, amp = False ): @@ -134,6 +136,8 @@ def train_from_folder( aug_types = cast_list(aug_types), dataset_aug_prob = dataset_aug_prob, calculate_fid_every = calculate_fid_every, + calculate_fid_num_images = calculate_fid_num_images, + clear_fid_cache = clear_fid_cache, amp = amp ) diff --git a/lightweight_gan/lightweight_gan.py b/lightweight_gan/lightweight_gan.py index af188f0..59196b5 100644 --- a/lightweight_gan/lightweight_gan.py +++ b/lightweight_gan/lightweight_gan.py @@ -43,7 +43,6 @@ NUM_CORES = multiprocessing.cpu_count() EXTS = ['jpg', 'jpeg', 'png'] -CALC_FID_NUM_IMAGES = 12800 # helpers @@ -742,6 +741,8 @@ def __init__( aug_types = ['translation', 'cutout'], dataset_aug_prob = 0., calculate_fid_every = None, + calculate_fid_num_images = 12800, + clear_fid_cache = False, is_ddp = False, rank = 0, world_size = 1, @@ -759,6 +760,8 @@ def __init__( self.base_dir = base_dir self.results_dir = base_dir / results_dir self.models_dir = base_dir / models_dir + self.fid_dir = base_dir / 'fid' / name + self.config_path = self.models_dir / name / '.config.json' assert is_power_of_two(image_size), 'image size must be a power of 2 (64, 128, 256, 512, 1024)' @@ -808,6 +811,8 @@ def __init__( self.dataset_aug_prob = dataset_aug_prob self.calculate_fid_every = calculate_fid_every + self.calculate_fid_num_images = calculate_fid_num_images + self.clear_fid_cache = clear_fid_cache self.is_ddp = is_ddp self.is_main = rank == 0 @@ -1058,7 +1063,7 @@ def train(self): self.evaluate(floor(self.steps / self.evaluate_every), num_image_tiles = self.num_image_tiles) if exists(self.calculate_fid_every) and self.steps % self.calculate_fid_every == 0 and self.steps != 0: - num_batches = math.ceil(CALC_FID_NUM_IMAGES / self.batch_size) + num_batches = math.ceil(self.calculate_fid_num_images / self.batch_size) fid = self.calculate_fid(num_batches) self.last_fid = fid @@ -1126,21 +1131,25 @@ def calculate_fid(self, num_batches): from pytorch_fid import fid_score torch.cuda.empty_cache() - real_path = str(self.results_dir / self.name / 'fid_real') + '/' - fake_path = str(self.results_dir / self.name / 'fid_fake') + '/' + real_path = self.fid_dir / 'real' + fake_path = self.fid_dir / 'fake' # remove any existing files used for fid calculation and recreate directories - rmtree(real_path, ignore_errors=True) - rmtree(fake_path, ignore_errors=True) - os.makedirs(real_path) - os.makedirs(fake_path) + if not real_path.exists() or self.clear_fid_cache: + rmtree(real_path, ignore_errors=True) + os.makedirs(real_path) - for batch_num in tqdm(range(num_batches), desc='calculating FID - saving reals'): - real_batch = next(self.loader) - for k in range(real_batch.size(0)): - torchvision.utils.save_image(real_batch[k, :, :, :], real_path + '{}.png'.format(k + batch_num * self.batch_size)) + for batch_num in tqdm(range(num_batches), desc='calculating FID - saving reals'): + real_batch = next(self.loader) + for k, image in enumerate(real_batch.unbind(0)): + ind = k + batch_num * self.batch_size + torchvision.utils.save_image(image, real_path / f'{ind}.png') # generate a bunch of fake images in results / name / fid_fake + + rmtree(fake_path, ignore_errors=True) + os.makedirs(fake_path) + self.GAN.eval() ext = self.image_extension @@ -1154,10 +1163,11 @@ def calculate_fid(self, num_batches): # moving averages generated_images = self.generate_truncated(self.GAN.GE, latents) - for j in range(generated_images.size(0)): - torchvision.utils.save_image(generated_images[j, :, :, :], str(Path(fake_path) / f'{str(j + batch_num * self.batch_size)}-ema.{ext}')) + for j, image in enumerate(generated_images.unbind(0)): + ind = j + batch_num * self.batch_size + torchvision.utils.save_image(image, str(fake_path / f'{str(ind)}-ema.{ext}')) - return fid_score.calculate_fid_given_paths([real_path, fake_path], 256, latents.device, 2048) + return fid_score.calculate_fid_given_paths([str(real_path), str(fake_path)], 256, latents.device, 2048) @torch.no_grad() def generate_truncated(self, G, style, trunc_psi = 0.75, num_image_tiles = 8): @@ -1224,6 +1234,7 @@ def init_folders(self): def clear(self): rmtree(str(self.models_dir / self.name), True) rmtree(str(self.results_dir / self.name), True) + rmtree(str(self.fid_dir), True) rmtree(str(self.config_path), True) self.init_folders() diff --git a/lightweight_gan/version.py b/lightweight_gan/version.py index 66e314a..538eb5d 100644 --- a/lightweight_gan/version.py +++ b/lightweight_gan/version.py @@ -1 +1 @@ -__version__ = '0.16.3' +__version__ = '0.16.4'