forked from mit-han-lab/temporal-shift-module
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdemo.py
65 lines (51 loc) · 2.08 KB
/
demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import time
import cv2
import torch
from PIL import Image
from torch.nn import Softmax
from src.camera import get_camera_capture
from src.display import setup_window, add_label, is_quit_key, is_fullscreen_key, switch_fullscreen_mode
from src.gestures import Gestures
from src.image_processing import get_transform
from src.model import load_model, init_buffer
from src.prediction_smoothing import PredictionSmoothing
def _classify_frame(model, transform, softmax, frame, shift_buffer):
    """Run one camera frame through the online (shift-buffer) gesture model.

    Args:
        model: network returned by ``load_model``; called as
            ``model(input, *shift_buffer)`` and expected to return the class
            logits followed by the updated shift buffers.
        transform: preprocessing callable from ``get_transform``; it is applied
            to a one-element list of PIL images.
        softmax: ``Softmax(1)`` module that turns logits into probabilities.
        frame: image array as returned by the capture object's ``read()``.
        shift_buffer: shift buffers carried over from the previous frame
            (initially from ``init_buffer``).

    Returns:
        ``(prediction, certainty, shift_buffer)``: the arg-max class index as
        an ``int``, its softmax probability as a ``float``, and the list of
        shift buffers to pass to the next call.
    """
    with torch.no_grad():
        # NOTE(review): OpenCV captures are normally BGR, and
        # ``convert('RGB')`` does not reorder channels, so the model receives
        # BGR data. Left unchanged; confirm against the model's training data.
        pil_image = [Image.fromarray(frame).convert('RGB')]
        image_transformed = transform(pil_image)
        # assumes transform returns a (3, H, W) tensor for a single frame —
        # TODO confirm; the view adds the batch dimension the model expects.
        input_transformed = image_transformed.view(
            1, 3, image_transformed.size(1), image_transformed.size(2))
        # The model returns the logits first; the remaining outputs are the
        # new shift buffers that carry temporal context to the next frame.
        predictions, *shift_buffer = model(input_transformed, *shift_buffer)
        probabilities = softmax(predictions)
        certainty, prediction = probabilities.max(1)
        return prediction.item(), certainty.item(), shift_buffer


def main():
    """Run the live webcam gesture-recognition demo.

    Frames are read from the camera, classified one at a time, smoothed over
    recent predictions, mirrored and displayed with a label, until the quit
    key is pressed or the camera stops delivering frames. The camera and all
    OpenCV windows are released even if an error (or Ctrl-C) interrupts the
    loop.
    """
    model = load_model()
    model.eval()
    # presumably (device index, width, height) — TODO confirm in src.camera
    cap = get_camera_capture(0, 320, 240)
    try:
        full_screen = False
        window_name = 'Video Gesture Recognition'
        setup_window(window_name)
        transform = get_transform()
        shift_buffer = init_buffer()
        gestures = Gestures()
        # presumably the number of recent predictions that are voted over
        prediction_smoothing = PredictionSmoothing(7)
        softmax = Softmax(1)
        while True:
            time_start = time.perf_counter()
            ok, img = cap.read()
            if not ok or img is None:
                # The camera returned no frame (disconnected / stream ended);
                # stop instead of crashing on a None image.
                print('Could not read a frame from the camera; exiting.')
                break
            prediction, certainty, shift_buffer = _classify_frame(
                model, transform, softmax, img, shift_buffer)
            prediction_smoothing.add_prediction(prediction)
            smooth_prediction = prediction_smoothing.get_most_common_prediction()
            # FPS covers capture + inference, as before; perf_counter is a
            # monotonic high-resolution clock, and the max() avoids a
            # division by zero on a degenerate interval.
            elapsed = time.perf_counter() - time_start
            frames_per_second = 1 / max(elapsed, 1e-9)
            img = cv2.resize(img, (640, 480))
            # Mirror horizontally for a natural "selfie" view. cv2.flip returns
            # a contiguous array, unlike the negative-stride view img[:, ::-1].
            img = cv2.flip(img, 1)
            img = add_label(img, gestures.get_name(smooth_prediction), certainty, frames_per_second)
            cv2.imshow(window_name, img)
            key = cv2.waitKey(1)
            if is_quit_key(key):  # exit
                break
            elif is_fullscreen_key(key):  # toggle full screen
                full_screen = not full_screen
                switch_fullscreen_mode(full_screen, window_name)
    finally:
        cap.release()
        cv2.destroyAllWindows()


if __name__ == "__main__":
    main()