-
Notifications
You must be signed in to change notification settings - Fork 151
/
gradiodemo.py
84 lines (67 loc) · 2.55 KB
/
gradiodemo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import os
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils import data
from torchvision import transforms, utils
from tqdm import tqdm
torch.backends.cudnn.benchmark = True
import copy
from util import *
from PIL import Image
from model import *
import moviepy.video.io.ImageSequenceClip
import scipy
import kornia.augmentation as K
from base64 import b64encode
import gradio as gr
from torchvision import transforms
# Fetch the two example images that the demo UI offers as clickable examples.
for _example_url, _example_file in (
    ('https://i.imgur.com/HiOTPNg.png', 'mona.png'),
    ('https://i.imgur.com/Cw8HcTN.png', 'painting.png'),
):
    torch.hub.download_url_to_file(_example_url, _example_file)

# Everything in this demo runs on the CPU.
device = 'cpu'

# Generator hyper-parameters.
latent_dim = 8
n_mlp = 5
num_down = 3  # not referenced again in the visible part of this file

# Build the A->B generator, switch it to eval mode and load the weights stored
# under the 'G_A2B_ema' key of the checkpoint.
G_A2B = Generator(
    256,
    4,
    latent_dim,
    n_mlp,
    channel_multiplier=1,
    lr_mlp=.01,
    n_res=1,
)
G_A2B = G_A2B.to(device).eval()

ensure_checkpoint_exists('GNR_checkpoint.pt')
ckpt = torch.load('GNR_checkpoint.pt', map_location=device)
G_A2B.load_state_dict(ckpt['G_A2B_ema'])

# mean latent: average of the mapping network output over 1000 random latents.
# Neither `truncation` nor `mean_style` is referenced again in the visible
# part of this file.
truncation = 1
with torch.no_grad():
    _latent_samples = torch.randn([1000, latent_dim]).to(device)
    mean_style = G_A2B.mapping(_latent_samples).mean(0, keepdim=True)

# Pre-processing applied to every input image: resize to 256x256, convert to a
# tensor, then normalize each of the three channels with mean 0.5 and std 0.5.
test_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=(0.5, 0.5, 0.5),
        std=(0.5, 0.5, 0.5),
        inplace=True,
    ),
])

# NOTE(review): `plt` is not imported explicitly in this file; presumably it
# arrives through one of the star imports above -- confirm.
plt.rcParams['figure.dpi'] = 200

# Seed the RNG right before sampling so the same style code is drawn on every
# start of the demo.
torch.manual_seed(84986)
num_styles = 1
style = torch.randn([num_styles, latent_dim]).to(device)
def inference(input_im):
    """Translate one input image with the generator and the fixed style code.

    Args:
        input_im: PIL image handed over by the Gradio input component (the
            interface declares the input with ``type="pil"``).

    Returns:
        A PIL RGB image built from the first decoded sample.
    """
    # test_transform normalizes with 3-channel statistics, so uploads in other
    # modes (RGBA, grayscale, palette) must be converted to RGB first.
    rgb_im = input_im.convert('RGB')
    real_A = test_transform(rgb_im).unsqueeze(0).to(device)

    with torch.no_grad():
        A2B_content, _ = G_A2B.encode(real_A)
        fake_A2B = G_A2B.decode(A2B_content.repeat(num_styles, 1, 1, 1), style)

        # Invert Normalize(mean=0.5, std=0.5). The constants are created on the
        # same device/dtype as the generator output so the arithmetic also
        # works if `device` is ever changed from 'cpu'.
        std = torch.tensor(
            (0.5, 0.5, 0.5), dtype=fake_A2B.dtype, device=fake_A2B.device
        ).view(3, 1, 1)
        mean = torch.tensor(
            (0.5, 0.5, 0.5), dtype=fake_A2B.dtype, device=fake_A2B.device
        ).view(3, 1, 1)
        z = fake_A2B * std + mean

        # Nothing in this file guarantees the generator output stays within
        # [-1, 1]; clamp so that out-of-range values saturate instead of being
        # corrupted by the 8-bit conversion inside ToPILImage.
        z = z.clamp(0, 1)

    # Index the first sample explicitly rather than relying on squeeze(), which
    # only yields a CxHxW tensor when num_styles == 1.
    return transforms.ToPILImage(mode='RGB')(z[0].cpu())
# Text shown on the demo page.
title = "GANsNRoses"
description = (
    "demo for GANsNRoses. To use it, simply upload your image, "
    "or click one of the examples to load them. "
    "Read more at the links below."
)
article = (
    "<p style='text-align: center'>"
    "<a href='https://arxiv.org/abs/2106.06561'>"
    "GANs N' Roses: Stable, Controllable, Diverse Image to Image Translation (works for videos too!)"
    "</a> | "
    "<a href='https://github.com/mchong6/GANsNRoses'>Github Repo</a>"
    "</p>"
)

# One PIL image in, one PIL image out; the examples are the two files
# downloaded during setup.
_demo_inputs = [gr.inputs.Image(type="pil", label="Input")]
_demo_output = gr.outputs.Image(type="pil", label="Output")
_demo_examples = [
    ["mona.png"],
    ["painting.png"],
]

# Build the interface around `inference` and start serving it.
_interface = gr.Interface(
    inference,
    _demo_inputs,
    _demo_output,
    title=title,
    description=description,
    article=article,
    examples=_demo_examples,
)
_interface.launch()