Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
luoyetx committed Feb 1, 2018
1 parent 8da0c49 commit 8ef2656
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 8 deletions.
6 changes: 3 additions & 3 deletions jfda/lnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import caffe
import numpy as np
from caffe.proto import caffe_pb2
import google.protobuf as pb2
from google.protobuf import text_format
from jfda.config import cfg
from jfda.utils import load_celeba, get_logger, crop_face

Expand Down Expand Up @@ -191,7 +191,7 @@ def get_data_size(txt):
final_model = 'tmp/lnet_iter_%d.caffemodel'%max_iter
solver_param = caffe_pb2.SolverParameter()
with open('proto/l_solver.prototxt', 'r') as fin:
pb2.text_format.Merge(fin.read(), solver_param)
text_format.Merge(fin.read(), solver_param)
solver_param.max_iter = max_iter
solver_param.snapshot = iter_train
solver_param.test_interval = iter_train
Expand All @@ -201,7 +201,7 @@ def get_data_size(txt):
solver_param.stepsize = args.lrp * iter_train
tmp_solver_prototxt = 'tmp/l_solver.prototxt'
with open(tmp_solver_prototxt, 'w') as fout:
fout.write(pb2.text_format.MessageToString(solver_param))
fout.write(text_format.MessageToString(solver_param))
# solver setup
solver = caffe.SGDSolver(tmp_solver_prototxt)
# train
Expand Down
6 changes: 3 additions & 3 deletions jfda/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import numpy as np
import caffe
from caffe.proto import caffe_pb2
import google.protobuf as pb2
from google.protobuf import text_format
from jfda.config import cfg
from jfda.minibatch import MiniBatcher

Expand Down Expand Up @@ -45,7 +45,7 @@ def __init__(self, solver_prototxt, args):
self.final_model = 'tmp/%snet_iter_%d.caffemodel'%(net_type, max_iter)
solver_param = caffe_pb2.SolverParameter()
with open(solver_prototxt, 'r') as fin:
pb2.text_format.Merge(fin.read(), solver_param)
text_format.Merge(fin.read(), solver_param)
solver_param.max_iter = max_iter # max training iterations
solver_param.snapshot = iter_train # save after an epoch
solver_param.test_interval = iter_train
Expand All @@ -56,7 +56,7 @@ def __init__(self, solver_prototxt, args):
solver_param.weight_decay = args.wd
tmp_solver_prototxt = 'tmp/%s_solver.prototxt'%net_type
with open(tmp_solver_prototxt, 'w') as fout:
fout.write(pb2.text_format.MessageToString(solver_param))
fout.write(text_format.MessageToString(solver_param))
# solver setup
self.solver = caffe.SGDSolver(tmp_solver_prototxt)
# data layer setup
Expand Down
3 changes: 1 addition & 2 deletions mxnet/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,5 +97,4 @@ def lnet():
mx.viz.plot_network(p, shape={'data': (1, 3, 12, 12)}).render('tmp/pnet')
mx.viz.plot_network(r, shape={'data': (1, 3, 24, 24)}).render('tmp/rnet')
mx.viz.plot_network(o, shape={'data': (1, 3, 48, 48)}).render('tmp/onet')
# mx.viz doesn't support multi-output from an op
#mx.viz.plot_network(l, shape={'data': (1, 15, 24, 24)}).render('tmp/lnet')
mx.viz.plot_network(l, shape={'data': (1, 15, 24, 24)}).render('tmp/lnet')
10 changes: 10 additions & 0 deletions train.sh
Original file line number Diff line number Diff line change
@@ -1,27 +1,37 @@
#!/usr/bin/env bash
# Train the full JFDA face-detection cascade: pNet -> rNet -> oNet -> lNet.
# For each stage, data generation runs with glog quieted
# (GLOG_minloglevel=2, errors only) to suppress caffe's console spam,
# then training restores the caller's original verbosity.

set -e
GPU=0
# Remember the caller's glog level so it can be restored after each
# quieted data-generation step. Default to 0 (glog's INFO level) when the
# variable is unset, so `export GLOG_minloglevel=$GL` below never exports
# an empty string (the original `GL=$GLOG_minloglevel` did exactly that
# when the caller had no GLOG_minloglevel in its environment).
GL=${GLOG_minloglevel:-0}

# pnet
echo "Generate Data for pNet"
export GLOG_minloglevel=2
python jfda/prepare.py --net p --wider --celeba --worker 8
echo "Train pNet"
export GLOG_minloglevel=$GL
python jfda/train.py --net p --gpu $GPU --size 128 --lr 0.05 --lrw 0.1 --lrp 5 --wd 0.0001 --epoch 25

# rnet
echo "Generate Data for rNet"
export GLOG_minloglevel=2
python jfda/prepare.py --net r --gpu $GPU --detect --celeba --wider --worker 4
echo "Train rNet"
export GLOG_minloglevel=$GL
python jfda/train.py --net r --gpu $GPU --size 128 --lr 0.05 --lrw 0.1 --lrp 5 --wd 0.0001 --epoch 25

# onet
echo "Generate Data for oNet"
export GLOG_minloglevel=2
python jfda/prepare.py --net o --gpu $GPU --detect --celeba --wider --worker 4
echo "Train oNet"
export GLOG_minloglevel=$GL
python jfda/train.py --net o --gpu $GPU --size 64 --lr 0.05 --lrw 0.1 --lrp 7 --wd 0.0001 --epoch 35

# lnet
echo "Generate Data for lNet"
export GLOG_minloglevel=2
python jfda/lnet.py --prepare --worker 8
echo "Train lNet"
export GLOG_minloglevel=$GL
python jfda/lnet.py --train --gpu $GPU --lr 0.1 --lrw 0.1 --lrp 2 --epoch 10

0 comments on commit 8ef2656

Please sign in to comment.