-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathdata_collect.py
167 lines (136 loc) · 5.07 KB
/
data_collect.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
import os
import subprocess
import tempfile
import shutil
import random
import argparse
import logging
import fnmatch
from PIL import Image
import sys
# Default log file path; used when the -log option is not given.
LOG_FILE = "data_collect.log"
# Default output path for the collected training data; used when -out is not given.
DATA_FILE = "out.data"
def shell(command, cwd=None, shell=False):
    """Run an external command and return everything it printed.

    Args:
        command: Argument list (or a single string when ``shell`` is true)
            handed to ``subprocess.check_output``.
        cwd: Working directory for the child process; ``None`` keeps the
            current one.
        shell: Whether the command is run through the system shell.

    Returns:
        The child's standard output, with standard error folded into the
        same stream.

    Raises:
        subprocess.CalledProcessError: If the command exits non-zero.
    """
    options = {
        "cwd": cwd,
        "stderr": subprocess.STDOUT,  # capture stderr together with stdout
        "shell": shell,
    }
    return subprocess.check_output(command, **options)
def jpgToRgb(inPath, outPath):
    """Dump an image as a plain-text table of RGB values.

    The output file starts with a ``width,height`` line, followed by one
    line per pixel row. Each row line lists the ``r,g,b`` values of every
    pixel from left to right, all separated by commas.

    Args:
        inPath: Path of the image to read.
        outPath: Path of the text file to write (overwritten if present).
    """
    with Image.open(inPath) as jpg:
        # Greyscale or CMYK JPEGs do not yield (r, g, b) tuples per pixel;
        # normalising to RGB keeps the three-values-per-pixel format valid.
        # convert() returns an independent, fully loaded copy, so the source
        # file handle can be released right away.
        rgbImage = jpg.convert("RGB")

    width, height = rgbImage.size
    pix = rgbImage.load()

    # The context manager guarantees the output file is flushed and closed.
    with open(outPath, "w") as rgb:
        rgb.write("{},{}\n".format(width, height))
        for y in range(height):
            channels = []
            for x in range(width):
                channels.extend(pix[x, y])
            rgb.write(",".join(str(value) for value in channels) + "\n")
def collect(imdir, outfile, window):
    """Build a training data file from every ``*.jpg`` found under ``imdir``.

    Steps: convert each image to the text ``.rgb`` format in a temporary
    directory, rebuild the detector with ``make`` for the requested window
    size, run ``./detect`` on every converted image, gather the samples it
    produces via ``process`` and write them all to ``outfile``.

    Args:
        imdir: Directory that is searched recursively for ``*.jpg`` files.
        outfile: Path of the data file to write.
        window: Window size; passed to the build as ``-DDEFSIZE`` and used
            (squared) as the input count in the output header.

    Raises:
        SystemExit: With status 1 if the build or a detection run fails.
    """
    jpegs = []
    for root, _dirnames, filenames in os.walk(imdir):
        for filename in fnmatch.filter(filenames, '*.jpg'):
            jpegs.append(os.path.join(root, filename))
    logging.info('Found {} images in {}'.format(len(jpegs), imdir))

    rgbDir = tempfile.mkdtemp()
    logging.debug('New directory created: {}'.format(rgbDir))

    # try/finally so the temporary directory is removed on every exit path,
    # including the sys.exit() calls below.
    try:
        rgbs = []
        for index, jpeg in enumerate(jpegs):
            fn = os.path.basename(jpeg)
            # The walk is recursive, so two images may share a base name;
            # the index prefix keeps their converted files from colliding.
            stem = '{:06d}_{}'.format(index, os.path.splitext(fn)[0])
            rgb = os.path.join(rgbDir, stem + '.rgb')
            logging.debug('Converting {} to rgb format'.format(fn))
            jpgToRgb(jpeg, rgb)
            rgbs.append(rgb)
        logging.info('Converted {} jpegs to rgb format'.format(len(rgbs)))

        logging.info('Now compiling face detection program with window size set to {}'.format(window))
        try:
            defineStr = "DEFINES=\"-DDEFSIZE=" + str(window) + "\""
            shell(["make", defineStr])
        except (subprocess.CalledProcessError, OSError) as err:
            logging.error('Compiling face detection program failed')
            logging.debug('{}'.format(getattr(err, 'output', None) or err))
            # Non-zero status so callers and scripts can see the failure.
            sys.exit(1)
        logging.info('Compiled face detection program successfully')

        trainingSamples = []
        for rgb in rgbs:
            # The detector's output files share the .rgb file's path stem.
            dataFile = os.path.splitext(rgb)[0]
            try:
                logging.debug('Running face detection and data collection on {}'.format(dataFile))
                shell(["./detect", rgb, dataFile])
            except (subprocess.CalledProcessError, OSError) as err:
                logging.error('Face detection on {} failed'.format(rgb))
                logging.debug('{}'.format(getattr(err, 'output', None) or err))
                sys.exit(1)
            trainingSamples += process(dataFile)

        with open(outfile, 'w') as f:
            # Header: sample count, inputs per sample, outputs per sample.
            f.write("{} {} {}\n".format(len(trainingSamples), (window*window), 1))
            for dat in trainingSamples:
                f.write("{}\n{}\n".format(dat[0], dat[1]))
    finally:
        shutil.rmtree(rgbDir)
def merge(large, small):
    """Return a shuffled, balanced mix of two sample lists.

    A random subset of ``large`` with as many elements as ``small`` is drawn
    (without replacement), all of ``small`` is added, and the combined list
    is shuffled. Neither input list is modified.

    Args:
        large: The longer list of samples.
        small: The shorter list of samples.

    Returns:
        A new list with ``2 * len(small)`` elements in random order.

    Raises:
        ValueError: If ``small`` has more elements than ``large``.
    """
    # range() works on both Python 2 and 3; xrange() is a NameError on 3.
    sampleIndices = random.sample(range(len(large)), len(small))
    tempData = [large[i] for i in sampleIndices]
    tempData += small
    random.shuffle(tempData)
    return tempData
def process(dataFile):
    """Load one image's positive and negative samples and balance them.

    Reads ``<dataFile>.pos.data`` (label 1) and ``<dataFile>.neg.data``
    (label 0); every line becomes a ``[stripped_line, label]`` pair. The
    longer of the two lists is then passed to ``merge`` as the list to
    sample from, with the shorter one kept in full.

    Args:
        dataFile: Common path prefix of the two sample files.

    Returns:
        The list produced by ``merge``.
    """
    def readSamples(path, label):
        # One [text, label] pair per line of the file.
        with open(path) as handle:
            return [[row.strip(), label] for row in handle]

    positives = readSamples(dataFile + '.pos.data', 1)
    negatives = readSamples(dataFile + '.neg.data', 0)
    logging.debug("Obtained {} positive samples, and {} negative samples".format(len(positives), len(negatives)))

    if len(negatives) > len(positives):
        return merge(negatives, positives)
    return merge(positives, negatives)
def cli():
    """Parse the command line, set up logging and run the data collection.

    Options: ``-dir`` image directory (default ``.``), ``-w`` window size
    (default 20), ``-d`` enable debug logging, ``-log`` log file path
    (default ``LOG_FILE``), ``-out`` output data file (default ``DATA_FILE``).

    Log records go both to the log file and to the console.

    Raises:
        SystemExit: With status 1 if the image directory does not exist.
    """
    parser = argparse.ArgumentParser(
        description='Face detection training data collection script'
    )
    parser.add_argument(
        '-dir', dest='imdir', action='store', type=str, required=False,
        default=".", help='image dataset directory'
    )
    parser.add_argument(
        '-w', dest='window', action='store', type=int, required=False,
        default=20, help='window size of the training data (default 20)'
    )
    parser.add_argument(
        '-d', dest='debug', action='store_true', required=False,
        default=False, help='print out debug messages'
    )
    parser.add_argument(
        '-log', dest='logpath', action='store', type=str, required=False,
        default=LOG_FILE, help='path to log file'
    )
    parser.add_argument(
        '-out', dest='outfile', action='store', type=str, required=False,
        default=DATA_FILE, help='path to output data file'
    )
    args = parser.parse_args()

    # Take care of log formatting
    logFormatter = logging.Formatter("%(asctime)s [%(threadName)-12.12s] [%(levelname)-5.5s] %(message)s", datefmt='%m/%d/%Y %I:%M:%S %p')
    rootLogger = logging.getLogger()

    # FileHandler opens the file in append mode and creates it if missing,
    # so no separate "touch" of the log file is needed.
    fileHandler = logging.FileHandler(args.logpath)
    fileHandler.setFormatter(logFormatter)
    rootLogger.addHandler(fileHandler)

    consoleHandler = logging.StreamHandler()
    consoleHandler.setFormatter(logFormatter)
    rootLogger.addHandler(consoleHandler)

    rootLogger.setLevel(logging.DEBUG if args.debug else logging.INFO)

    if not os.path.isdir(args.imdir):
        # Report on stderr and exit non-zero so calling scripts can detect
        # the failure instead of seeing a successful run.
        sys.stderr.write("Error: Directory {} does not exist".format(args.imdir) + "\n")
        sys.exit(1)

    collect(args.imdir, args.outfile, args.window)
# Start the command-line interface only when run as a script, not on import.
if __name__ == "__main__":
    cli()