diff --git a/jfda/lnet.py b/jfda/lnet.py index 8ad3864..c8359e5 100644 --- a/jfda/lnet.py +++ b/jfda/lnet.py @@ -11,7 +11,7 @@ import caffe import numpy as np from caffe.proto import caffe_pb2 -import google.protobuf as pb2 +from google.protobuf import text_format from jfda.config import cfg from jfda.utils import load_celeba, get_logger, crop_face @@ -191,7 +191,7 @@ def get_data_size(txt): final_model = 'tmp/lnet_iter_%d.caffemodel'%max_iter solver_param = caffe_pb2.SolverParameter() with open('proto/l_solver.prototxt', 'r') as fin: - pb2.text_format.Merge(fin.read(), solver_param) + text_format.Merge(fin.read(), solver_param) solver_param.max_iter = max_iter solver_param.snapshot = iter_train solver_param.test_interval = iter_train @@ -201,7 +201,7 @@ def get_data_size(txt): solver_param.stepsize = args.lrp * iter_train tmp_solver_prototxt = 'tmp/l_solver.prototxt' with open(tmp_solver_prototxt, 'w') as fout: - fout.write(pb2.text_format.MessageToString(solver_param)) + fout.write(text_format.MessageToString(solver_param)) # solver setup solver = caffe.SGDSolver(tmp_solver_prototxt) # train diff --git a/jfda/train.py b/jfda/train.py index d593791..fabd5ae 100644 --- a/jfda/train.py +++ b/jfda/train.py @@ -7,7 +7,7 @@ import numpy as np import caffe from caffe.proto import caffe_pb2 -import google.protobuf as pb2 +from google.protobuf import text_format from jfda.config import cfg from jfda.minibatch import MiniBatcher @@ -45,7 +45,7 @@ def __init__(self, solver_prototxt, args): self.final_model = 'tmp/%snet_iter_%d.caffemodel'%(net_type, max_iter) solver_param = caffe_pb2.SolverParameter() with open(solver_prototxt, 'r') as fin: - pb2.text_format.Merge(fin.read(), solver_param) + text_format.Merge(fin.read(), solver_param) solver_param.max_iter = max_iter # max training iterations solver_param.snapshot = iter_train # save after an epoch solver_param.test_interval = iter_train @@ -56,7 +56,7 @@ def __init__(self, solver_prototxt, args): solver_param.weight_decay = args.wd tmp_solver_prototxt = 'tmp/%s_solver.prototxt'%net_type with open(tmp_solver_prototxt, 'w') as fout: - fout.write(pb2.text_format.MessageToString(solver_param)) + fout.write(text_format.MessageToString(solver_param)) # solver setup self.solver = caffe.SGDSolver(tmp_solver_prototxt) # data layer setup diff --git a/mxnet/model.py b/mxnet/model.py index fca55ef..c4c44e8 100644 --- a/mxnet/model.py +++ b/mxnet/model.py @@ -97,5 +97,4 @@ def lnet(): mx.viz.plot_network(p, shape={'data': (1, 3, 12, 12)}).render('tmp/pnet') mx.viz.plot_network(r, shape={'data': (1, 3, 24, 24)}).render('tmp/rnet') mx.viz.plot_network(o, shape={'data': (1, 3, 48, 48)}).render('tmp/onet') - # mx.viz doesn't support multi-output from an op - #mx.viz.plot_network(l, shape={'data': (1, 15, 24, 24)}).render('tmp/lnet') + mx.viz.plot_network(l, shape={'data': (1, 15, 24, 24)}).render('tmp/lnet') diff --git a/train.sh b/train.sh index 5b6e205..7ae37f4 100644 --- a/train.sh +++ b/train.sh @@ -1,27 +1,37 @@ #!/usr/bin/env bash +set -e GPU=0 +GL=$GLOG_minloglevel # pnet echo "Generate Data for pNet" +export GLOG_minloglevel=2 python jfda/prepare.py --net p --wider --celeba --worker 8 echo "Train pNet" +export GLOG_minloglevel=$GL python jfda/train.py --net p --gpu $GPU --size 128 --lr 0.05 --lrw 0.1 --lrp 5 --wd 0.0001 --epoch 25 # rnet echo "Generate Data for rNet" +export GLOG_minloglevel=2 python jfda/prepare.py --net r --gpu $GPU --detect --celeba --wider --worker 4 echo "Train rNet" +export GLOG_minloglevel=$GL python jfda/train.py --net r --gpu $GPU --size 128 --lr 0.05 --lrw 0.1 --lrp 5 --wd 0.0001 --epoch 25 # onet echo "Generate Data for oNet" +export GLOG_minloglevel=2 python jfda/prepare.py --net o --gpu $GPU --detect --celeba --wider --worker 4 echo "Train oNet" +export GLOG_minloglevel=$GL python jfda/train.py --net o --gpu $GPU --size 64 --lr 0.05 --lrw 0.1 --lrp 7 --wd 0.0001 --epoch 35 # lnet echo "Generate Data for lNet" +export GLOG_minloglevel=2 python jfda/lnet.py --prepare --worker 8 echo "Train lNet" +export GLOG_minloglevel=$GL python jfda/lnet.py --train --gpu $GPU --lr 0.1 --lrw 0.1 --lrp 2 --epoch 10