import * as posenet from '@tensorflow-models/posenet';
import {drawKeypoints, drawSkeleton} from './demo_util';
import Transform from './tranform';

const videoWidth = 500;
const videoHeight = 500;
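
// Shim for older browsers that only expose prefixed getUserMedia.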
navigator.getUserMedia = navigator.getUserMedia ||
    navigator.webkitGetUserMedia || navigator.mozGetUserMedia;

/**
 * PoseNet class for loading the PoseNet model
 * and running inference with it
 */
export default class PoseNet {
  /**
   * The class constructor
   * @param {Joints} joints processes raw joints data from PoseNet
   * @param {GraphicsEngine} graphicsEngine to which joints data will be fed
   * @param {Object} _htmlelems elements used to present results
   *     (expects `video` and `output` properties)
   */
  constructor(joints, graphicsEngine, _htmlelems) {
    this.state = {
      algorithm: 'single-pose',
      input: {
        outputStride: 16,
        imageScaleFactor: 0.5,
      },
      singlePoseDetection: {
        minPoseConfidence: 0.1,
        minPartConfidence: 0.5,
      },
      net: null,
    };
    this.htmlElements = _htmlelems;
    this.joints = joints;
    this.transform = new Transform(this.joints);
    this.graphics_engine = graphicsEngine;
    this.graphics_engine.render();
  }

  /** Checks whether the device is mobile or not */
  isMobile() {
    const mobile = /Android/i.test(navigator.userAgent) ||
        /iPhone|iPad|iPod/i.test(navigator.userAgent);
    return mobile;
  }

  /** Starts webcam video */
  async loadVideo() {
    const video = await this.setupCamera();
    video.play();
    return video;
  }

  /** Sets up the webcam */
  async setupCamera() {
    if (!navigator.mediaDevices || !navigator.mediaDevices.getUserMedia) {
      throw new Error(
          'Browser API navigator.mediaDevices.getUserMedia not available');
    }
    const video = this.htmlElements.video;
    video.width = videoWidth;
    video.height = videoHeight;
    const mobile = this.isMobile();
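    // On mobile, let the browser pick the native camera resolution;
    // on desktop, request the fixed canvas size.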
    const stream = await navigator.mediaDevices.getUserMedia({
      'audio': false,
      'video': {
        facingMode: 'user',
        width: mobile ? undefined : videoWidth,
        height: mobile ? undefined : videoHeight,
      },
    });
    video.srcObject = stream;
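    // Resolve only once metadata has loaded, so the video dimensions are known.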
    return new Promise((resolve) => {
      video.onloadedmetadata = () => {
        resolve(video);
      };
    });
  }

  /**
   * Detects human pose from the video stream using PoseNet
   * @param {VideoObject} video
   * @param {TFModel} net
   */
  detectPoseInRealTime(video, net) {
    const canvas = this.htmlElements.output;
    const ctx = canvas.getContext('2d');
    // flip horizontally since images are being fed from a webcam
    const flipHorizontal = true;
    canvas.width = videoWidth;
    canvas.height = videoHeight;
    const self = this;
    async function poseDetectionFrame() {
      // Scale the image down by imageScaleFactor; too large an image
      // slows down the GPU.
      const imageScaleFactor = self.state.input.imageScaleFactor;
      const outputStride = +self.state.input.outputStride;
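      // Lower output strides (8, 16, or 32) are more accurate but slower.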
      let poses = [];
      let minPoseConfidence;
      let minPartConfidence;
      const pose = await net.estimateSinglePose(
          video, imageScaleFactor, flipHorizontal, outputStride);
      poses.push(pose);
      minPoseConfidence = +self.state.singlePoseDetection.minPoseConfidence;
      minPartConfidence = +self.state.singlePoseDetection.minPartConfidence;
      ctx.clearRect(0, 0, videoWidth, videoHeight);
      ctx.save();
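      // Mirror the frame horizontally so the drawn video matches the
      // flipped keypoints.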
      ctx.scale(-1, 1);
      ctx.translate(-videoWidth, 0);
      ctx.drawImage(video, 0, 0, videoWidth, videoHeight);
      ctx.restore();
      // For each pose (i.e. person) detected in an image, loop through
      // the poses and draw the resulting skeleton and keypoints if over
      // certain confidence scores
      poses.forEach(({score, keypoints}) => {
        if (score >= minPoseConfidence) {
          self.transform.updateKeypoints(keypoints, minPartConfidence);
          const head = self.transform.head();
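          // Keypoints 0-6 are the face and shoulder joints
          // (nose, eyes, ears, shoulders).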
          const shouldMoveFarther =
              drawKeypoints(keypoints.slice(0, 7), minPartConfidence, ctx);
          if (shouldMoveFarther) {
            ctx.font = '30px Arial';
            ctx.fillText('Please Move Farther',
                Math.round(videoWidth / 2) - 100, Math.round(videoHeight / 2));
          }
          drawSkeleton(keypoints, minPartConfidence, ctx);
        }
      });
      requestAnimationFrame(poseDetectionFrame);
    }
    poseDetectionFrame();
  }

  /** Loads the PoseNet model weights */
  async loadNetwork() {
    this.net = await posenet.load();
  }

  /**
   * Starts predicting human pose from the webcam
   * @return {boolean} true if the webcam loaded and prediction started
   */
  async startPrediction() {
    let video;
    try {
      video = await this.loadVideo();
    } catch (e) {
      return false;
    }
    this.detectPoseInRealTime(video, this.net);
    return true;
  }
}
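
// Example usage (a minimal sketch; `Joints`, `GraphicsEngine`, and the element
// ids below are assumptions about the host page, not defined in this file):
//
//   const poseNet = new PoseNet(joints, graphicsEngine, {
//     video: document.getElementById('video'),
//     output: document.getElementById('output'),
//   });
//   await poseNet.loadNetwork();
//   const ok = await poseNet.startPrediction();
//   if (!ok) console.error('Webcam unavailable');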