# functions.py
# Helper functions: upload extension check, feature statistics, image
# preprocessing, VGG feature extraction and base64 image encoding.
import base64
import io
from pathlib import Path
from PIL import Image
import torch.nn as nn
from torchvision import transforms
# Module-level configuration.
# File extensions accepted by allowed_file().
ALLOWED_EXTENSIONS = {'jpg', 'jpeg', 'png'}
# Output directory; created at import time if it does not exist yet.
OUTPUT_FOLDER = 'results'
Path(OUTPUT_FOLDER).mkdir(parents=True, exist_ok=True)
# Styles directory; likewise created at import time if it is missing.
STYLES_DIR = 'styles'
Path(STYLES_DIR).mkdir(parents=True, exist_ok=True)
# Check whether a filename carries one of the allowed extensions.
def allowed_file(filename):
    """Return True when ``filename`` has an allowed extension.

    The extension is the text after the last dot, compared
    case-insensitively against ``ALLOWED_EXTENSIONS``. A name without any
    dot is rejected.
    """
    if '.' not in filename:
        return False
    extension = filename.rsplit('.', 1)[1]
    return extension.lower() in ALLOWED_EXTENSIONS
# loss functions
def calc_mean_std(feat, eps=1e-5):
    """Compute the per-sample, per-channel mean and standard deviation.

    Args:
        feat: 4-D tensor. The first two dimensions (N, C) are kept; all
            remaining elements of each (sample, channel) pair are pooled.
        eps: Small value added to the variance to avoid divide-by-zero when
            the standard deviation is later used as a divisor.

    Returns:
        Tuple ``(feat_mean, feat_std)``, each of shape ``(N, C, 1, 1)``.

    Raises:
        AssertionError: If ``feat`` is not 4-dimensional.
    """
    size = feat.size()
    assert len(size) == 4, 'expected a 4-D feature tensor'
    N, C = size[:2]
    # reshape (rather than view) also accepts non-contiguous inputs such as
    # permuted tensors; it only copies when a view is not possible.
    flat = feat.reshape(N, C, -1)
    feat_std = (flat.var(dim=2) + eps).sqrt().view(N, C, 1, 1)
    feat_mean = flat.mean(dim=2).view(N, C, 1, 1)
    return feat_mean, feat_std
def mean_variance_norm(feat):
    """Normalize ``feat`` using the statistics from ``calc_mean_std``.

    The mean is subtracted and the result divided by the standard deviation,
    both expanded to the full shape of ``feat``.
    """
    mean, std = calc_mean_std(feat)
    centered = feat - mean.expand_as(feat)
    return centered / std.expand_as(feat)
# Transforms to perform on an image before using it.
def test_transform():
    """Build the preprocessing pipeline: ``Resize(512)`` then ``ToTensor()``."""
    return transforms.Compose([
        transforms.Resize(512),
        transforms.ToTensor(),
    ])
# extract style and content features
def feat_extractor(vgg, content, style, DEVICE):
    """Run the content and style inputs through five slices of ``vgg``.

    Args:
        vgg: Module whose children are split by index into five consecutive
            encoder stages.
        content: Content input fed to the first stage.
        style: Style input fed to the first stage.
        DEVICE: Device that the stages and both inputs are moved to.

    Returns:
        Tuple ``(Content4_1, Content5_1, Style4_1, Style5_1)``: the outputs
        after the fourth and fifth stages for the content and style inputs.
    """
    # Materialize the child list once instead of once per stage.
    layers = list(vgg.children())
    # Stage boundaries; the relu names are the ones used by the file's
    # original comments.
    enc_1 = nn.Sequential(*layers[:4]).to(DEVICE)     # input -> relu1_1
    enc_2 = nn.Sequential(*layers[4:11]).to(DEVICE)   # relu1_1 -> relu2_1
    enc_3 = nn.Sequential(*layers[11:18]).to(DEVICE)  # relu2_1 -> relu3_1
    enc_4 = nn.Sequential(*layers[18:31]).to(DEVICE)  # relu3_1 -> relu4_1
    enc_5 = nn.Sequential(*layers[31:44]).to(DEVICE)  # relu4_1 -> relu5_1

    # Inputs must live on the same device as the stages; .to() returns the
    # tensor unchanged when it is already there.
    content = content.to(DEVICE)
    style = style.to(DEVICE)

    # extract content features
    Content4_1 = enc_4(enc_3(enc_2(enc_1(content))))
    Content5_1 = enc_5(Content4_1)
    # extract style features
    Style4_1 = enc_4(enc_3(enc_2(enc_1(style))))
    Style5_1 = enc_5(Style4_1)
    return Content4_1, Content5_1, Style4_1, Style5_1
# encode given image as base64
def get_encoded_img(image_path):
    """Load an image, re-encode it as PNG and return it base64-encoded.

    Args:
        image_path: Value passed to ``Image.open`` to locate the image.

    Returns:
        ASCII ``str`` produced by ``base64.encodebytes`` from the PNG bytes.
    """
    img_byte_arr = io.BytesIO()
    # The context manager closes the image once the in-memory PNG copy has
    # been written, so the source file handle is not left open.
    with Image.open(image_path, mode='r') as img:
        img.save(img_byte_arr, format='PNG')
    return base64.encodebytes(img_byte_arr.getvalue()).decode('ascii')