forked from lucidrains/denoising-diffusion-pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdiffusampler.py
executable file
·182 lines (136 loc) · 5.96 KB
/
diffusampler.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
from denoising_diffusion_pytorch import GaussianDiffusion
import torch
from torchvision.utils import save_image
from torchvision import transforms
from PIL import Image
from tqdm import tqdm
import os
import clip
import argparse
import cv2
from pytorch_msssim import ssim
from postproc import pprocess
'''
pip install denoising_diffusion_pytorch
'''
from cutouts import cut
parser = argparse.ArgumentParser()
# define params and their types with defaults if needed
parser.add_argument('--steps', type=int, default=1000, help='diffusion steps')
parser.add_argument('--dir', type=str, default="out", help='base directory for storing images')
parser.add_argument('--name', type=str, default="test", help='basename for storing images')
parser.add_argument('--show', action="store_true", help='show image in a window')
parser.add_argument('--ema', action="store_true", help='use ema model')
parser.add_argument('--imageSize', type=int, default=512, help='image size')
parser.add_argument('--h', type=int, default=0, help='image height')
parser.add_argument('--w', type=int, default=0, help='image with')
parser.add_argument('--modelSize', type=int, default=512, help='native image size for model')
parser.add_argument('--saveEvery', type=int, default=0, help='image save frequency')
parser.add_argument('--saveAfter', type=int, default=0, help='save images after step')
parser.add_argument('--mults', type=int, nargs='*', default=[1, 1, 2, 2, 4, 8], help='')
parser.add_argument('--model', type=str, default="", help='model architecture: unet0, unetok5, unet1,unetcn0')
parser.add_argument('--load', type=str, default="", help='base name for checkpoint to load')
parser.add_argument('--epoch', type=int, default=0, help='first epoch')
parser.add_argument('--estep', type=int, default=0, help='epoch step')
parser.add_argument('--epoch2', type=int, default=0, help='last epoch')
parser.add_argument('--nsamples', type=int, default=0, help='images per epoch')
parser.add_argument('--batchSize', type=int, default=1, help='batch size for sampling')
parser.add_argument('--postproc', action="store_true", help='apply post processing')
parser.add_argument('--contrast', type=float, default=1, help='contrast, 1 for neutral')
parser.add_argument('--brightness', type=float, default=0, help='brightness, 0 for neutral')
parser.add_argument('--saturation', type=float, default=1, help='saturation, 1 for neutral')
parser.add_argument('--gamma', type=float, default=1, help='gamma, 1 for neutral')
parser.add_argument('--unsharp', type=float, default=0, help='unsharp mask weight')
parser.add_argument('--eqhist', type=float, default=0., help='histogram eq level')
parser.add_argument('--median', type=int, default=0, help='median blur kernel size, 0 for none')
parser.add_argument('--c1', type=float, default=0., help='do not use')
parser.add_argument('--c2', type=float, default=1., help='do not use')
parser.add_argument('--sharpenlast', action="store_true", help='do not use')
parser.add_argument('--sharpkernel', type=int, default=3, help='sharpening kernel, 0 for none')
parser.add_argument('--ovl0', type=float, default=0, help='blend original with blurred image')
parser.add_argument('--bil', type=int, default=0, help='bilateral filter kernel size')
parser.add_argument('--bils1', type=int, default=75, help='bilateral filter sigma for color')
parser.add_argument('--bils2', type=int, default=75, help='bilateral filter sigma for space')
opt = parser.parse_args()
mtype = opt.model
if opt.h == 0:
opt.h = opt.imageSize
if opt.w == 0:
opt.w = opt.imageSize
if mtype == "unet0":
from alt_models.Unet0 import Unet
elif mtype == "unet0k5":
from alt_models.Unet0k5 import Unet
elif mtype == "unet1":
from alt_models.Unet1 import Unet
elif mtype == "unet2":
from alt_models.Unet2 import Unet
elif mtype == "unetcn0":
from alt_models.UnetCN0 import Unet
else:
print("Unsupported model: "+mtype)
exit()
def show_on_screen(image_tensor, window="out", maxsize=720):
im = image_tensor.detach().numpy() # convert from pytorch tensor to numpy array
#print(im.shape)
# pytorch tensors are (C, H, W), rearrange to (H, W, C)
im = im.transpose(1, 2, 0)
# normalize range to 0 .. 1
#im = im/2 + 0.5
im -= im.min()
im /= im.max()
im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
(h, w) = tuple(im.shape[:2])
if h > maxsize:
w = int(w * (maxsize/h))
h = maxsize
im = cv2.resize(im,(w, h))
# show it in a window (this will not work on a remote session)
cv2.imshow(window, im)
cv2.waitKey(100) # display for 100 ms and wait for a keypress (which we ignore here)
name = opt.name #"out5/testcd"
steps = opt.steps
bs = 1
model = Unet(
dim = 64,
dim_mults = opt.mults # (1, 2, 4, 8)
).cuda()
diffusion = GaussianDiffusion(
model,
image_size = opt.modelSize,
timesteps = steps, # number of steps
loss_type = 'l1' # L1 or L2
).cuda()
def load_model(fn):
data = torch.load(fn)
try:
print("loaded "+fn+", correct mults: "+",".join(str(x) for x in data['mults']))
except:
print("loaded "+fn+", no mults stored")
m = "ema" if opt.ema else "model"
diffusion.load_state_dict(data[m], strict=False)
return diffusion
if opt.epoch2 == 0 or opt.estep == 0:
opt.epoch2 = opt.epoch + 1
opt.estep = 1
def num_to_groups(num, divisor):
groups = num // divisor
remainder = num % divisor
arr = [divisor] * groups
if remainder > 0:
arr.append(remainder)
return arr
if opt.postproc:
print("****** post processing not implemented yet, running without it")
for e in range(opt.epoch, opt.epoch2, opt.estep):
fn = opt.load+"-"+str(e)+".pt"
m = load_model(fn)
batches = num_to_groups(opt.nsamples, opt.batchSize)
#print(batches)
i = 0
for b in batches :
images = m.sample(batch_size=b)
for img in images:
img = (img + 1)/2
save_image(img, str(opt.dir+"/"+f'{opt.name}-{e}-{i}.png'))
i += 1