-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
123 lines (90 loc) · 4.38 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import os
from options.test_options import TestOptions #calls options.test_options.py
from data import create_dataset
from models import create_model
import cv2
import torch
import numpy as np
from util import*
# Adapted from https://github.com/bensantos/webcam-CycleGAN/blob/master/webcam.py
if __name__ == '__main__':
    # Script entry point: run a trained generator over every frame of a video
    # file, saving each translated frame as a PNG and appending it to an
    # output video.
    opt = TestOptions().parse()  # get test options
    # hard-code some parameters for test
    opt.num_threads = 0   # test code only supports num_threads = 0
    opt.batch_size = 1    # test code only supports batch_size = 1
    opt.serial_batches = True  # disable data shuffling; comment this line if results on randomly chosen images are needed.
    opt.no_flip = True    # no flip; comment this line if results on flipped images are needed.
    opt.display_id = -1   # no visdom display; the test code saves the results to a HTML file.
    model = create_model(opt)  # create a model given opt.model and other options
    model.setup(opt)           # regular setup: load and print networks; create schedulers
    if opt.eval:
        model.eval()

    # Directory that receives one PNG per processed frame.
    # Change it manually when processing a new video, otherwise the previous
    # frames are overwritten.
    os.makedirs("./result_frames", exist_ok=True)

    print("Opening the video...")
    vc = cv2.VideoCapture("./videos/england_3.mp4")  # TODO: Pass arg to process any video
    if not vc.isOpened():
        raise IOError("Cannot open video.")

    output = None
    try:
        ##########################
        # video and codec params
        ##########################
        # Write at the source resolution. To use a custom size instead
        # (e.g. (512, 512) or (1024, 1024)), change `size` here; every frame
        # is resized to `size` before writing, because the writer drops
        # frames whose dimensions differ from the ones it was opened with.
        size = (
            int(vc.get(cv2.CAP_PROP_FRAME_WIDTH)),
            int(vc.get(cv2.CAP_PROP_FRAME_HEIGHT)),
        )
        print("size:", size)

        # Keep the source frame rate so the output has the same duration;
        # fall back to 25 fps when the container does not report a usable one.
        fps = vc.get(cv2.CAP_PROP_FPS)
        if not fps > 0:
            fps = 25.0

        fourcc = cv2.VideoWriter_fourcc(*'MP4V')
        # TODO: Pass arg to name it, need to bypass CycleGAN TestOpions()
        output = cv2.VideoWriter('output_part1.mp4', fourcc, fps, size)
        if not output.isOpened():
            # Not fatal: the per-frame PNGs are still produced below.
            print("Warning: cannot open 'output_part1.mp4' for writing; "
                  "only the PNG frames will be saved.")

        # Ref: https://towardsdatascience.com/using-cyclegan-to-perform-style-transfer-on-a-webcam-244142effe7f
        # Input dict handed to model.set_input(); only "A" changes per frame.
        data = {"A": None, "A_paths": None}
        currentFrame = 0

        # Read frames until the video ends (or 'q' is pressed).
        while vc.isOpened():
            ret, frame = vc.read()
            if not ret:  # no more frames / read failure
                break

            currentFrame += 1
            print("currentFrame:", currentFrame)

            # resize frame - CycleGAN takes 256x256 only
            frame = cv2.resize(frame, (256, 256), interpolation=cv2.INTER_AREA)
            # OpenCV decodes to BGR; the model is fed RGB.
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

            # HxWxC -> 1xHxWxC (add batch dim) -> BxCxHxW, as the model wants.
            batch = np.array([frame]).transpose([0, 3, 1, 2])
            # Scale 0..255 to -1..1, i.e. ToTensor followed by
            # Normalize(0.5, 0.5).
            # NOTE(review): assumes the training/test dataset transform
            # normalises to [-1, 1] - confirm against the dataset code.
            batch = batch.astype(np.float32) / 127.5 - 1.0
            data['A'] = torch.from_numpy(np.ascontiguousarray(batch))

            model.set_input(data)  # unpack data from data loader
            model.test()

            # get only generated image - indexing dictionary for "fake" key
            result_image = model.get_current_visuals()['fake']
            # use tensor2im provided by util file
            result_image = util.tensor2im(result_image)
            # Back to BGR channel order for the OpenCV writers.
            result_image = cv2.cvtColor(np.array(result_image), cv2.COLOR_RGB2BGR)

            # Very important: the frame must exactly match the size the
            # writer was opened with, otherwise nothing is written.
            result_image = cv2.resize(result_image, size)

            # cv2.imshow("summer2winter", result_image)  # <- cant do on server
            cv2.imwrite("./result_frames/result_image_{}.png".format(currentFrame), result_image)
            output.write(result_image)

            if cv2.waitKey(10) & 0xFF == ord('q'):
                break
    finally:
        # Always release the capture and finalise the output file, even when
        # inference fails part-way through.
        vc.release()
        if output is not None:
            output.release()
        cv2.destroyAllWindows()