-
Notifications
You must be signed in to change notification settings - Fork 8
/
run.py
executable file
·69 lines (57 loc) · 1.97 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
#!/usr/bin/env python
import sys
import os
import time
import subprocess as sp
import itertools
## CV
import cv2
## Model
import numpy as np
import tensorflow as tf
## Tools
import utils
## Parameters
import params ## you can modify the content of params.py
## Test epoch
epoch_ids = [10]
## Load model
model = utils.get_model()
## Preprocess
def img_pre_process(img):
"""
Processes the image and returns it
:param img: The image to be processed
:return: Returns the processed image
"""
## Chop off 1/3 from the top and cut bottom 150px(which contains the head of car)
shape = img.shape
img = img[int(shape[0]/3):shape[0]-150, 0:shape[1]]
## Resize the image
img = cv2.resize(img, (params.FLAGS.img_w, params.FLAGS.img_h), interpolation=cv2.INTER_AREA)
## Return the image sized as a 4D array
return np.resize(img, (params.FLAGS.img_w, params.FLAGS.img_h, params.FLAGS.img_c))
## Process video
for epoch_id in epoch_ids:
print('---------- processing video for epoch {} ----------'.format(epoch_id))
vid_path = utils.join_dir(params.data_dir, 'epoch{:0>2}_front.mkv'.format(epoch_id))
assert os.path.isfile(vid_path)
frame_count = utils.frame_count(vid_path)
cap = cv2.VideoCapture(vid_path)
machine_steering = []
print('performing inference...')
time_start = time.time()
for frame_id in range(frame_count):
ret, img = cap.read()
assert ret
## you can modify here based on your model
img = img_pre_process(img)
img = img[None,:,:,:]
deg = float(model.predict(img, batch_size=1))
machine_steering.append(deg)
cap.release()
fps = frame_count / (time.time() - time_start)
print('completed inference, total frames: {}, average fps: {} Hz'.format(frame_count, round(fps, 1)))
print('performing visualization...')
utils.visualize(epoch_id, machine_steering, params.out_dir,
verbose=True, frame_count_limit=None)