Fix custom ops bug for pytorch 1.12 and onwards #36

Open · wants to merge 3 commits into base branch main
projector.py (33 changes: 22 additions & 11 deletions)
@@ -362,7 +362,7 @@ def project(
 @click.option('--init-lr', '-lr', 'initial_learning_rate', type=float, help='Initial learning rate of the optimization process', default=0.1, show_default=True)
 @click.option('--constant-lr', 'constant_learning_rate', is_flag=True, help='Add flag to use a constant learning rate throughout the optimization (turn off the rampup/rampdown)')
 @click.option('--reg-noise-weight', '-regw', 'regularize_noise_weight', type=float, help='Noise weight regularization', default=1e5, show_default=True)
-@click.option('--seed', type=int, help='Random seed', default=303, show_default=True)
+@click.option('--seed', type=int, help='Torch random number generator seed used to determine input and synthesis noise', default=303, show_default=True)
 @click.option('--stabilize-projection', is_flag=True, help='Add flag to stabilize the latent space/anchor to w_avg, making it easier to project (only for StyleGAN3 config-r/t models)')
 # Video options
 @click.option('--save-video', '-video', is_flag=True, help='Save an mp4 video of optimization progress')
@@ -371,7 +371,7 @@ def project(
 # Options on which space to project to (W or W+) and where to start: the middle point of W (w_avg) or a specific seed
 @click.option('--project-in-wplus', '-wplus', is_flag=True, help='Project in the W+ latent space')
 @click.option('--start-wavg', '-wavg', type=bool, help='Start with the average W vector, otherwise will start from a random seed (provided by user)', default=True, show_default=True)
-@click.option('--projection-seed', type=int, help='Seed to start projection from', default=None, show_default=True)
+@click.option('--projection-seed', type=int, help='Seed vector to use as the starting point of the projection', default=None, show_default=True)
 @click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi to use in projection when using a projection seed', default=0.7, show_default=True)
 # Decide the loss to use when projecting (all other apart from o.g. StyleGAN2's are experimental, you can select the VGG16 features/layers to use in the im2sgan loss)
 @click.option('--loss-paper', '-loss', type=click.Choice(['sgan2', 'im2sgan', 'discriminator', 'clip']), help='Loss to use (if using "im2sgan", make sure to norm the VGG16 features)', default='sgan2', show_default=True)
@@ -380,6 +380,7 @@ def project(
 @click.option('--vgg-sqrt-normed', 'sqrt_normed', is_flag=True, help='Add flag to norm the VGG16 features by the square root of the number of elements per layer that was used')
 # Extra parameters for saving the results
 @click.option('--save-every-step', '-saveall', is_flag=True, help='Save every step taken in the projection (save both the dlatent as a .npy and its respective image).')
+@click.option('--save-n-step', type=int, help='Save every n steps taken in the projection', default=1, show_default=True)
 @click.option('--outdir', type=click.Path(file_okay=False), help='Directory path to save the results', default=os.path.join(os.getcwd(), 'out', 'projection'), show_default=True, metavar='DIR')
 @click.option('--description', '-desc', type=str, help='Extra description to add to the experiment name', default='')
 def run_projection(
@@ -404,6 +405,7 @@ def run_projection(
         normed: bool,
         sqrt_normed: bool,
         save_every_step: bool,
+        save_n_step: int,
         outdir: str,
         description: str,
 ):
@@ -415,6 +417,7 @@
     python projector.py --target=~/mytarget.png --project-in-wplus --save-video --num-steps=5000 \\
     --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/ffhq.pkl
     """
+    # Set PyTorch random number generator to user provided seed, or the default of 303
     torch.manual_seed(seed)
 
     # If we're not starting from the W midpoint, assert the user fed a seed to start from
@@ -489,6 +492,7 @@ def run_projection(
         'seed': seed,
         'video_fps': fps,
         'save_every_step': save_every_step,
+        'save_n_step': save_n_step,
         'run_config': run_config
     }
     # Save the run configuration
@@ -507,36 +511,43 @@ def run_projection(
         result_name, npy_name = f'{result_name}_seed-{projection_seed}', f'{npy_name}_seed-{projection_seed}'
 
     # Save the target image
-    target_pil.save(os.path.join(run_dir, 'target.jpg'))
+    target_pil.save(os.path.join(run_dir, 'target.png'))
 
     if save_every_step:
-        # Save every projected frame and W vector. TODO: This can be optimized to be saved as training progresses
+        # Save every projected frame and W vector, except the final step.
         n_digits = int(np.log10(num_steps)) + 1 if num_steps > 0 else 1
-        for step in tqdm(range(num_steps), desc='Saving projection results', unit='steps'):
+        for step in tqdm(range(0, num_steps, save_n_step), desc='Saving projection results', unit='steps'):
+            print(f'{step} / {num_steps}')
             w = projected_w_steps[step]
             synth_image = gen_utils.w_to_img(G, dlatents=w, noise_mode='const')[0]
-            PIL.Image.fromarray(synth_image, 'RGB').save(f'{result_name}_step{step:0{n_digits}d}.jpg')
+            PIL.Image.fromarray(synth_image, 'RGB').save(f'{result_name}_step{step:0{n_digits}d}.png')
             np.save(f'{npy_name}_step{step:0{n_digits}d}.npy', w.unsqueeze(0).cpu().numpy())
+        # Save the final projected frame and W vector.
+        print('Saving final projection results...')
+        projected_w = projected_w_steps[-1]
+        synth_image = gen_utils.w_to_img(G, dlatents=projected_w, noise_mode='const')[0]
+        PIL.Image.fromarray(synth_image, 'RGB').save(f'{result_name}_{num_steps}_final.png')
+        np.save(f'{npy_name}_{num_steps}_final.npy', projected_w.unsqueeze(0).cpu().numpy())
     else:
         # Save only the final projected frame and W vector.
         print('Saving projection results...')
         projected_w = projected_w_steps[-1]
         synth_image = gen_utils.w_to_img(G, dlatents=projected_w, noise_mode='const')[0]
-        PIL.Image.fromarray(synth_image, 'RGB').save(f'{result_name}_final.jpg')
-        np.save(f'{npy_name}_final.npy', projected_w.unsqueeze(0).cpu().numpy())
+        PIL.Image.fromarray(synth_image, 'RGB').save(f'{result_name}_{num_steps}_final.png')
+        np.save(f'{npy_name}_{num_steps}_final.npy', projected_w.unsqueeze(0).cpu().numpy())
 
     # Save the optimization video and compress it if so desired
     if save_video:
-        video = imageio.get_writer(f'{result_name}.mp4', mode='I', fps=fps, codec='libx264', bitrate='16M')
-        print(f'Saving optimization progress video "{result_name}.mp4"')
+        video = imageio.get_writer(f'{result_name}_{num_steps}.mp4', mode='I', fps=fps, codec='libx264', bitrate='16M')
+        print(f'Saving optimization progress video "{result_name}_{num_steps}.mp4"')
         for projected_w in projected_w_steps:
             synth_image = gen_utils.w_to_img(G, dlatents=projected_w, noise_mode='const')[0]
             video.append_data(np.concatenate([target_uint8, synth_image], axis=1)) # left side target, right projection
         video.close()
 
     if save_video and compress:
         # Compress the video; might fail, and is a basic command that can also be better optimized
-        gen_utils.compress_video(original_video=f'{result_name}.mp4',
+        gen_utils.compress_video(original_video=f'{result_name}_{num_steps}.mp4',
                                  original_video_name=f'{result_name.split(os.sep)[-1]}',
                                  outdir=run_dir,
                                  ctx=ctx)
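With the new flags wired through, a run that keeps every tenth intermediate step could look like the following invocation (an illustrative example adapted from the docstring above, reusing its placeholder target path and network URL; note `--save-n-step` only takes effect together with `--save-every-step`, since the strided loop lives inside that branch):

    python projector.py --target=~/mytarget.png --num-steps=1000 --save-every-step --save-n-step=10 \
        --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/ffhq.pkl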
torch_utils/ops/grid_sample_gradfix.py (3 changes: 3 additions & 0 deletions)
@@ -22,6 +22,7 @@
 
 enabled = False # Enable the custom op by setting this to true.
 _use_pytorch_1_11_api = parse_version(torch.__version__) >= parse_version('1.11.0a') # Allow prerelease builds of 1.11
+_use_pytorch_1_12_api = parse_version(torch.__version__) >= parse_version('1.12.0a') # Allow prerelease builds of 1.12
 
 #----------------------------------------------------------------------------

Expand Down Expand Up @@ -58,6 +59,8 @@ class _GridSample2dBackward(torch.autograd.Function):
@staticmethod
def forward(ctx, grad_output, input, grid):
op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward')
if _use_pytorch_1_12_api:
op = op[0]
if _use_pytorch_1_11_api:
output_mask = (ctx.needs_input_grad[1], ctx.needs_input_grad[2])
grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False, output_mask)
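Background on the actual fix (my reading of the change, not text from the PR): starting with PyTorch 1.12, `torch._C._jit_get_operation` returns a `(op, overload_names)` tuple instead of the bare callable, so code written against the old single-return API breaks when it tries to call the result. The two added lines unwrap the tuple before the operator is invoked. A minimal standalone sketch of the same pattern (assuming `parse_version` comes from `pkg_resources`; the file imports it above the shown hunk):

    import torch
    from pkg_resources import parse_version  # assumed import; the file may use packaging.version instead

    # Allow prerelease builds of 1.12, matching the check added in the diff
    _use_pytorch_1_12_api = parse_version(torch.__version__) >= parse_version('1.12.0a')

    op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward')
    if _use_pytorch_1_12_api:
        op = op[0]  # 1.12+ returns (op, overload_names); keep only the callable
    # `op` can now be called the same way on every supported PyTorch version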