-
Notifications
You must be signed in to change notification settings - Fork 48
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add scripts to convert model to mxnet
- Loading branch information
Showing
3 changed files
with
323 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,129 @@ | ||
#!/usr/bin/env python2.7 | ||
|
||
import argparse | ||
import caffe | ||
import mxnet as mx | ||
import numpy as np | ||
from model import pnet, rnet, onet, lnet | ||
|
||
|
||
def get_params(caffe_param): | ||
"""Get all params from caffe, layer type is Convolution, PReLU, InnerProduct | ||
""" | ||
arg_params = {} | ||
for k, v in caffe_param.iteritems(): | ||
if 'conv' in k: | ||
# Convolution | ||
arg_params[k+'_weight'] = mx.nd.array(v[0].data) | ||
arg_params[k+'_bias'] = mx.nd.array(v[1].data) | ||
elif 'prelu' in k: | ||
# PReLU | ||
arg_params[k+'_gamma'] = mx.nd.array(v[0].data) | ||
else: | ||
# InnerProduct | ||
arg_params[k+'_weight'] = mx.nd.array(v[0].data) | ||
arg_params[k+'_bias'] = mx.nd.array(v[1].data) | ||
return arg_params | ||
|
||
|
||
def test_net(mx_net, caffe_net, data): | ||
"""test network | ||
""" | ||
caffe_net.blobs['data'].reshape(*data.shape) | ||
caffe_net.blobs['data'].data[...] = data | ||
caffe_net.forward() | ||
caffe_prob = caffe_net.blobs['prob'].data | ||
caffe_bbox = caffe_net.blobs['bbox_pred'].data | ||
caffe_landmark = caffe_net.blobs['landmark_pred'].data | ||
batch = mx.io.DataBatch(data=[mx.nd.array(data)], label=None) | ||
mx_net.forward(batch, is_train=False) | ||
mx_prob, mx_bbox, mx_landmark = [x.asnumpy() for x in mx_net.get_outputs()] | ||
mse = lambda x, y: np.square(x-y).mean() | ||
print 'prob mse:', mse(caffe_prob, mx_prob) | ||
print 'bbox mse:', mse(caffe_bbox, mx_bbox) | ||
print 'landmark mse:', mse(caffe_landmark, mx_landmark) | ||
|
||
|
||
def test_lnet(mx_net, caffe_net, data): | ||
"""test lnet | ||
""" | ||
caffe_net.blobs['data'].reshape(*data.shape) | ||
caffe_net.blobs['data'].data[...] = data | ||
caffe_net.forward() | ||
caffe_offset = caffe_net.blobs['landmark_offset'].data | ||
batch = mx.io.DataBatch(data=[mx.nd.array(data)], label=None) | ||
mx_net.forward(batch, is_train=False) | ||
mx_offset = mx_net.get_outputs()[0].asnumpy() | ||
mse = lambda x, y: np.square(x-y).mean() | ||
print 'landmark offset mse:', mse(caffe_offset, mx_offset) | ||
|
||
|
||
def convert(net_type, args): | ||
"""Convert a network | ||
""" | ||
if net_type == 'pnet': | ||
mx_net = pnet() | ||
caffe_net = caffe.Net(args.proto_dir + '/p.prototxt', caffe.TEST, weights=args.model_dir + '/p.caffemodel') | ||
input_channel = 3 | ||
input_size = 12 | ||
mode_prefix = 'tmp/pnet' | ||
elif net_type == 'rnet': | ||
mx_net = rnet() | ||
caffe_net = caffe.Net(args.proto_dir + '/r.prototxt', caffe.TEST, weights=args.model_dir + '/r.caffemodel') | ||
input_channel = 3 | ||
input_size = 24 | ||
mode_prefix = 'tmp/rnet' | ||
elif net_type == 'onet': | ||
mx_net = onet() | ||
caffe_net = caffe.Net(args.proto_dir + '/o.prototxt', caffe.TEST, weights=args.model_dir + '/o.caffemodel') | ||
input_channel = 3 | ||
input_size = 48 | ||
mode_prefix = 'tmp/onet' | ||
elif net_type == 'lnet': | ||
mx_net = lnet() | ||
caffe_net = caffe.Net(args.proto_dir + '/l.prototxt', caffe.TEST, weights=args.model_dir + '/l.caffemodel') | ||
input_channel = 15 | ||
input_size = 24 | ||
mode_prefix = 'tmp/lnet' | ||
else: | ||
raise ValueError("No such net type (%s)"%net_type) | ||
|
||
arg_params = get_params(caffe_net.params) | ||
mx_mod = mx.mod.Module(symbol=mx_net, data_names=('data'), label_names=None) | ||
mx_mod.bind(data_shapes=[('data', (100, input_channel, input_size, input_size)),]) | ||
mx_mod.set_params(arg_params=arg_params, aux_params=None, allow_missing=True) | ||
mx.model.save_checkpoint(mode_prefix, 0, mx_net, arg_params, {}) | ||
|
||
# test | ||
data = np.random.rand(100, input_channel, input_size, input_size).astype(np.float32) | ||
if net_type == 'lnet': | ||
test_lnet(mx_mod, caffe_net, data) | ||
else: | ||
test_net(mx_mod, caffe_net, data) | ||
|
||
return mx_mod, caffe_net | ||
|
||
|
||
def main(args): | ||
# pnet | ||
print 'convert pnet' | ||
convert('pnet', args) | ||
# rnet | ||
print 'convert rnet' | ||
convert('rnet', args) | ||
# onet | ||
print 'convert onet' | ||
convert('onet', args) | ||
# lnet | ||
print 'convert lnet' | ||
convert('lnet', args) | ||
|
||
|
||
if __name__ == '__main__': | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument('--proto-dir', type=str, default='../proto', help="caffe proto directory") | ||
parser.add_argument('--model-dir', type=str, default='../model', help="caffe mode directory") | ||
parser.add_argument('--out-dir', type=str, default='./tmp', help="mxnet output model directory") | ||
args = parser.parse_args() | ||
print args | ||
main(args) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
#!/usr/bin/env python2.7 | ||
|
||
import os | ||
import math | ||
import argparse | ||
import cv2 | ||
import numpy as np | ||
import mxnet as mx | ||
from jfda.detector import JfdaDetector | ||
from jfda.utils import crop_face, Timer | ||
|
||
|
||
class MxDetector(JfdaDetector): | ||
"""JfdaDetector using mxnet | ||
""" | ||
|
||
def __init__(self, model_dir='./tmp', ctx=mx.cpu()): | ||
self.pnet = mx.model.FeedForward.load(model_dir+'/pnet', 0, ctx=ctx) | ||
self.rnet = mx.model.FeedForward.load(model_dir+'/rnet', 0, ctx=ctx) | ||
self.onet = mx.model.FeedForward.load(model_dir+'/onet', 0, ctx=ctx) | ||
self.lnet = mx.model.FeedForward.load(model_dir+'/lnet', 0, ctx=ctx) | ||
|
||
def _forward(self, net, data, outs): | ||
'''forward a net with given data, return blobs[out] | ||
''' | ||
output = net.predict(data) | ||
if not isinstance(output, list): | ||
output = [output] | ||
return output | ||
|
||
def _clear_network_buffer(self, net): | ||
pass | ||
|
||
|
||
def main(args): | ||
ctx = mx.gpu(args.gpu) if args.gpu >= 0 else mx.cpu() | ||
detector = MxDetector(ctx=ctx) | ||
param = { | ||
'ths': [0.6, 0.7, 0.8], | ||
'factor': 0.709, | ||
'min_size': 24, | ||
} | ||
timer = Timer() | ||
|
||
def gen(img, bboxes, outname): | ||
for i in range(len(bboxes)): | ||
x1, y1, x2, y2, score = bboxes[i, :5] | ||
x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2) | ||
cv2.rectangle(img, (x1, y1), (x2, y2), (0, 0, 255), 2) | ||
cv2.putText(img, '%.03f'%score, (x1, y1), cv2.FONT_HERSHEY_PLAIN, 1, (0, 255, 0)) | ||
# landmark | ||
landmark = bboxes[i, 9:].reshape((5, 2)) | ||
for j in range(5): | ||
x, y = landmark[j] | ||
x, y = int(x), int(y) | ||
cv2.circle(img, (x, y), 2, (0, 255, 0), -1) | ||
cv2.imwrite(outname, img) | ||
|
||
with open('./tmp/demo.txt', 'r') as fin: | ||
for line in fin.readlines(): | ||
fp = line.strip() | ||
dn = os.path.dirname(fp) | ||
fn = os.path.basename(fp).split('.')[0] | ||
img = cv2.imread(fp, cv2.IMREAD_COLOR) | ||
timer.tic() | ||
bb, ts = detector.detect(img, debug=True, **param) | ||
timer.toc() | ||
print 'detect %s costs %.04lfs'%(fp, timer.elapsed()) | ||
print 'image size = (%d x %d), s1: %.04lfs, s2: %.04lfs, s3: %.04lfs, s4: %.04lf'%( | ||
img.shape[0], img.shape[1], ts[0], ts[1], ts[2], ts[3]) | ||
print 'bboxes, s1: %d, s2: %d, s3: %d, s4: %d'%(len(bb[0]), len(bb[1]), len(bb[2]), len(bb[3])) | ||
out1 = '%s/%s_stage1.jpg'%(dn, fn) | ||
out2 = '%s/%s_stage2.jpg'%(dn, fn) | ||
out3 = '%s/%s_stage3.jpg'%(dn, fn) | ||
out4 = '%s/%s_stage4.jpg'%(dn, fn) | ||
gen(img.copy(), bb[0], out1) | ||
gen(img.copy(), bb[1], out2) | ||
gen(img.copy(), bb[2], out3) | ||
gen(img.copy(), bb[3], out4) | ||
|
||
|
||
if __name__ == '__main__': | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument('--gpu', type=int, default=-1, help='gpu id to use, -1 for cpu') | ||
args = parser.parse_args() | ||
main(args) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
import mxnet as mx | ||
|
||
|
||
def pnet(): | ||
data = mx.sym.Variable('data') | ||
conv1 = mx.sym.Convolution(data=data, kernel=(3, 3), num_filter=10, name='conv1') | ||
prelu1 = mx.sym.LeakyReLU(data=conv1, act_type='prelu', name='prelu1') | ||
pool1 = mx.sym.Pooling(data=prelu1, kernel=(2, 2), stride=(2, 2), pool_type='max', \ | ||
pooling_convention='full', name='pool1') | ||
conv2 = mx.sym.Convolution(data=pool1, kernel=(3, 3), num_filter=16, name='conv2') | ||
prelu2 = mx.sym.LeakyReLU(data=conv2, act_type='prelu', name='prelu2') | ||
conv3 = mx.sym.Convolution(data=prelu2, kernel=(3, 3), num_filter=32, name='conv3') | ||
prelu3 = mx.sym.LeakyReLU(data=conv3, act_type='prelu', name='prelu3') | ||
score = mx.sym.Convolution(data=prelu3, kernel=(1, 1), num_filter=2, name='score') | ||
prob = mx.sym.SoftmaxActivation(data=score, mode='channel', name='prob') | ||
bbox_pred = mx.sym.Convolution(data=prelu3, kernel=(1, 1), num_filter=4, name='bbox_pred') | ||
landmark_pred = mx.sym.Convolution(data=prelu3, kernel=(1, 1), num_filter=10, name='landmark_pred') | ||
out = mx.sym.Group([prob, bbox_pred, landmark_pred]) | ||
return out | ||
|
||
|
||
def rnet(): | ||
data = mx.sym.Variable('data') | ||
conv1 = mx.sym.Convolution(data=data, kernel=(3, 3), num_filter=28, name='conv1') | ||
prelu1 = mx.sym.LeakyReLU(data=conv1, act_type='prelu', name='prelu1') | ||
pool1 = mx.sym.Pooling(data=prelu1, kernel=(3, 3), stride=(2, 2), pool_type='max', \ | ||
pooling_convention='full', name='pool1') | ||
conv2 = mx.sym.Convolution(data=pool1, kernel=(3, 3), num_filter=48, name='conv2') | ||
prelu2 = mx.sym.LeakyReLU(data=conv2, act_type='prelu', name='prelu2') | ||
pool2 = mx.sym.Pooling(data=prelu2, kernel=(3, 3), stride=(2, 2), pool_type='max', \ | ||
pooling_convention='full', name='pool2') | ||
conv3 = mx.sym.Convolution(data=pool2, kernel=(2, 2), num_filter=64, name='conv3') | ||
prelu3 = mx.sym.LeakyReLU(data=conv3, act_type='prelu', name='prelu3') | ||
fc = mx.sym.FullyConnected(data=prelu3, num_hidden=128, name='fc') | ||
prelu4 = mx.sym.LeakyReLU(data=fc, act_type='prelu', name='prelu4') | ||
score = mx.sym.FullyConnected(data=prelu4, num_hidden=2, name='score') | ||
prob = mx.sym.SoftmaxActivation(data=score, name='prob') | ||
bbox_pred = mx.sym.FullyConnected(data=prelu4, num_hidden=4, name='bbox_pred') | ||
landmark_pred = mx.sym.FullyConnected(data=prelu4, num_hidden=10, name='landmark_pred') | ||
out = mx.sym.Group([prob, bbox_pred, landmark_pred]) | ||
return out | ||
|
||
|
||
def onet(): | ||
data = mx.sym.Variable('data') | ||
conv1 = mx.sym.Convolution(data=data, kernel=(3, 3), num_filter=32, name='conv1') | ||
prelu1 = mx.sym.LeakyReLU(data=conv1, act_type='prelu', name='prelu1') | ||
pool1 = mx.sym.Pooling(data=prelu1, kernel=(3, 3), stride=(2, 2), pool_type='max', \ | ||
pooling_convention='full', name='pool1') | ||
conv2 = mx.sym.Convolution(data=pool1, kernel=(3, 3), num_filter=64, name='conv2') | ||
prelu2 = mx.sym.LeakyReLU(data=conv2, act_type='prelu', name='prelu2') | ||
pool2 = mx.sym.Pooling(data=prelu2, kernel=(3, 3), stride=(2, 2), pool_type='max', \ | ||
pooling_convention='full', name='pool2') | ||
conv3 = mx.sym.Convolution(data=pool2, kernel=(2, 2), num_filter=64, name='conv3') | ||
prelu3 = mx.sym.LeakyReLU(data=conv3, act_type='prelu', name='prelu3') | ||
pool3 = mx.sym.Pooling(data=prelu3, kernel=(2, 2), stride=(2, 2), pool_type='max', \ | ||
pooling_convention='full', name='pool3') | ||
conv4 = mx.sym.Convolution(data=pool3, kernel=(2, 2), num_filter=128, name='conv4') | ||
prelu4 = mx.sym.LeakyReLU(data=conv4, act_type='prelu', name='prelu4') | ||
fc = mx.sym.FullyConnected(data=prelu4, num_hidden=256, name='fc') | ||
prelu5 = mx.sym.LeakyReLU(data=fc, act_type='prelu', name='prelu5') | ||
score = mx.sym.FullyConnected(data=prelu5, num_hidden=2, name='score') | ||
prob = mx.sym.SoftmaxActivation(data=score, name='prob') | ||
bbox_pred = mx.sym.FullyConnected(data=prelu5, num_hidden=4, name='bbox_pred') | ||
landmark_pred = mx.sym.FullyConnected(data=prelu5, num_hidden=10, name='landmark_pred') | ||
out = mx.sym.Group([prob, bbox_pred, landmark_pred]) | ||
return out | ||
|
||
|
||
def lnet(): | ||
data = mx.sym.Variable('data') | ||
sliced = mx.sym.SliceChannel(data=data, num_outputs=5) | ||
out = [] | ||
for i in range(1, 6): | ||
conv1 = mx.sym.Convolution(data=sliced[i-1], kernel=(3, 3), num_filter=28, name='conv1_%d'%i) | ||
prelu1 = mx.sym.LeakyReLU(data=conv1, act_type='prelu', name='prelu1_%d'%i) | ||
pool1 = mx.sym.Pooling(data=prelu1, kernel=(3, 3), stride=(2, 2), pool_type='max', \ | ||
pooling_convention='full', name='pool1_%d'%i) | ||
conv2 = mx.sym.Convolution(data=pool1, kernel=(3, 3), num_filter=48, name='conv2_%d'%i) | ||
prelu2 = mx.sym.LeakyReLU(data=conv2, act_type='prelu', name='prelu2_%d'%i) | ||
pool2 = mx.sym.Pooling(data=prelu2, kernel=(3, 3), stride=(2, 2), pool_type='max', \ | ||
pooling_convention='full', name='pool2_%d'%i) | ||
conv3 = mx.sym.Convolution(data=pool2, kernel=(2, 2), num_filter=64, name='conv3_%d'%i) | ||
prelu3 = mx.sym.LeakyReLU(data=conv3, act_type='prelu', name='prelu3_%d'%i) | ||
out.append(prelu3) | ||
concat = mx.sym.Concat(*out, name='concat') | ||
fc4 = mx.sym.FullyConnected(data=concat, num_hidden=256, name='fc4') | ||
prelu4 = mx.sym.LeakyReLU(data=fc4, act_type='prelu', name='prelu4') | ||
out = [] | ||
for i in range(1, 6): | ||
fc5 = mx.sym.FullyConnected(data=prelu4, num_hidden=64, name='fc5_%d'%i) | ||
prelu5 = mx.sym.LeakyReLU(data=fc5, act_type='prelu', name='prelu5_%d'%i) | ||
fc6 = mx.sym.FullyConnected(data=prelu5, num_hidden=2, name='fc6_%d'%i) | ||
out.append(fc6) | ||
out = mx.sym.Concat(*out, name='landmark_offset') | ||
return out | ||
|
||
|
||
if __name__ == '__main__': | ||
p = pnet() | ||
r = rnet() | ||
o = onet() | ||
l = lnet() | ||
mx.viz.plot_network(p, shape={'data': (1, 3, 12, 12)}).render('tmp/pnet') | ||
mx.viz.plot_network(r, shape={'data': (1, 3, 24, 24)}).render('tmp/rnet') | ||
mx.viz.plot_network(o, shape={'data': (1, 3, 48, 48)}).render('tmp/onet') | ||
# mx.viz doesn't support multi-output from an op | ||
#mx.viz.plot_network(l, shape={'data': (1, 15, 24, 24)}).render('tmp/lnet') |