diff --git a/README.md b/README.md index f23b1e1..0381450 100644 --- a/README.md +++ b/README.md @@ -152,10 +152,16 @@ Also one flag to use `--multi-gpus` [Aim](https://github.com/aimhubio/aim) is an open-source experiment tracker that logs your training runs, enables a beautiful UI to compare them and an API to query them programmatically. -You can specify Aim logs directory with `--aim_repo` flag, otherwise logs will be stored in the current directory +First you need to install `aim` with `pip` ```bash -$ lightweight_gan --data ./path/to/images --image-size 512 --aim_repo ./path/to/logs/ +$ pip install aim +``` + +Next, you can specify Aim logs directory with `--aim_repo` flag, otherwise logs will be stored in the current directory + +```bash +$ lightweight_gan --data ./path/to/images --image-size 512 --use-aim --aim_repo ./path/to/logs/ ``` Execute `aim up --repo ./path/to/logs/` to run Aim UI on your server. diff --git a/lightweight_gan/cli.py b/lightweight_gan/cli.py index 6527008..9a06f13 100644 --- a/lightweight_gan/cli.py +++ b/lightweight_gan/cli.py @@ -116,9 +116,10 @@ def train_from_folder( seed = 42, amp = False, show_progress = False, - use_aim = True, + use_aim = False, aim_repo = None, - aim_run_hash = None + aim_run_hash = None, + load_strict = True ): num_image_tiles = default(num_image_tiles, 4 if image_size > 512 else 8) @@ -149,7 +150,8 @@ def train_from_folder( calculate_fid_every = calculate_fid_every, calculate_fid_num_images = calculate_fid_num_images, clear_fid_cache = clear_fid_cache, - amp = amp + amp = amp, + load_strict = load_strict ) if generate: diff --git a/lightweight_gan/lightweight_gan.py b/lightweight_gan/lightweight_gan.py index 42b0f40..d4c5144 100644 --- a/lightweight_gan/lightweight_gan.py +++ b/lightweight_gan/lightweight_gan.py @@ -1,5 +1,4 @@ import os -import aim import json import multiprocessing from random import random @@ -961,6 +960,7 @@ def __init__( use_aim = True, aim_repo = None, aim_run_hash = None, + load_strict = True, *args, **kwargs ): @@ -1038,14 +1038,23 @@ def __init__( self.syncbatchnorm = is_ddp + self.load_strict = load_strict + self.amp = amp self.G_scaler = GradScaler(enabled = self.amp) self.D_scaler = GradScaler(enabled = self.amp) self.run = None self.hparams = hparams + if self.is_main and use_aim: - self.run = aim.Run(run_hash=aim_run_hash, repo=aim_repo) + try: + import aim + self.aim = aim + except ImportError: + print('unable to import aim experiment tracker - please run `pip install aim` first') + + self.run = self.aim.Run(run_hash=aim_run_hash, repo=aim_repo) self.run['hparams'] = hparams @property @@ -1347,7 +1356,7 @@ def image_to_pil(image): aim_images = [] for image in images: im = image_to_pil(image) - aim_images.append(aim.Image(im, caption=f'#{idx}')) + aim_images.append(self.aim.Image(im, caption=f'#{idx}')) self.run.track(value=aim_images, name='generated', step=self.steps, @@ -1362,7 +1371,7 @@ def image_to_pil(image): aim_images = [] for idx, image in enumerate(generated_images): im = image_to_pil(image) - aim_images.append(aim.Image(im, caption=f'#{idx}')) + aim_images.append(self.aim.Image(im, caption=f'#{idx}')) self.run.track(value=aim_images, name='generated', step=self.steps, @@ -1376,7 +1385,7 @@ def image_to_pil(image): aim_images = [] for idx, image in enumerate(generated_images): im = image_to_pil(image) - aim_images.append(aim.Image(im, caption=f'EMA #{idx}')) + aim_images.append(self.aim.Image(im, caption=f'EMA #{idx}')) self.run.track(value=aim_images, name='generated', step=self.steps, @@ -1597,7 +1606,7 @@ def load(self, num=-1, print_version=True): print(f"loading from version {load_data['version']}") try: - self.GAN.load_state_dict(load_data['GAN']) + self.GAN.load_state_dict(load_data['GAN'], strict = self.load_strict) except Exception as e: saved_version = load_data['version'] print('unable to load save model. please try downgrading the package to the version specified by the saved model (to do so, just run `pip install lightweight-gan=={saved_version}`') diff --git a/lightweight_gan/version.py b/lightweight_gan/version.py index 66d9d1e..eed48b7 100644 --- a/lightweight_gan/version.py +++ b/lightweight_gan/version.py @@ -1 +1 @@ -__version__ = '0.22.1' +__version__ = '0.22.3' diff --git a/setup.py b/setup.py index aaee645..8ec6475 100644 --- a/setup.py +++ b/setup.py @@ -33,8 +33,7 @@ 'retry', 'torch>=1.10', 'torchvision', - 'tqdm', - 'aim' + 'tqdm' ], classifiers=[ 'Development Status :: 4 - Beta',