Skip to content

Commit

Permalink
added support for in-loop Caffe custom measures + generic measures an…
Browse files Browse the repository at this point in the history
…d loss output + unit tests
  • Loading branch information
beniz committed Mar 8, 2015
1 parent 0c2ee82 commit ec99bff
Show file tree
Hide file tree
Showing 8 changed files with 127 additions and 73 deletions.
13 changes: 8 additions & 5 deletions src/caffeinputconns.cc
Original file line number Diff line number Diff line change
Expand Up @@ -312,10 +312,8 @@ namespace dd
std::vector<caffe::Datum> ImgCaffeInputFileConn::get_dv_test(const int &num,
const bool &has_mean_file)
{
static int pass = 0;
static std::vector<float> mean_values;
static Blob<float> data_mean;
static float *mean = nullptr;
Blob<float> data_mean;
float *mean = nullptr;
if (!_test_db_cursor)
{
// open db and create cursor
Expand Down Expand Up @@ -368,8 +366,13 @@ namespace dd
_test_db_cursor->Next();
++i;
}
++pass;
return dv;
}

/**
 * \brief rewinds the test-data stream: releases the test DB cursor so that
 *        the next call to get_dv_test() reopens a cursor at the beginning
 *        of the test database.
 */
void ImgCaffeInputFileConn::reset_dv_test()
{
  //_test_db = std::unique_ptr<caffe::db::DB>();
  _test_db_cursor.reset(); // idiomatic unique_ptr release, same effect as assigning an empty unique_ptr
}

}
29 changes: 23 additions & 6 deletions src/caffeinputconns.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ namespace dd
(void)has_mean_file;
return std::vector<caffe::Datum>(num);
}

// no-op default: input connectors with stateful test-set iteration override this
void reset_dv_test() {}

std::vector<caffe::Datum> _dv; /**< main input datum vector, used for training or prediction */
std::vector<caffe::Datum> _dv_test; /**< test input datum vector, when applicable in training mode */
Expand All @@ -71,7 +73,11 @@ namespace dd
:ImgInputFileConn() {}
ImgCaffeInputFileConn(const ImgCaffeInputFileConn &i)
:ImgInputFileConn(i),CaffeInputInterface(i) {}
// NOTE(review): removed stale duplicate destructor left over from the diff —
// two destructor definitions are ill-formed C++.
~ImgCaffeInputFileConn()
{
  // close the test DB if get_dv_test() opened it (cursor state persists across calls)
  if (_test_db)
    _test_db->Close();
}

// size of each element in Caffe jargon
int channels() const
Expand Down Expand Up @@ -183,6 +189,8 @@ namespace dd
std::vector<caffe::Datum> get_dv_test(const int &num,
const bool &has_mean_file);

void reset_dv_test();

private:
int images_to_db(const std::string &rfolder,
const std::string &traindbname,
Expand Down Expand Up @@ -224,7 +232,10 @@ namespace dd
{
public:
CSVCaffeInputFileConn()
:CSVInputFileConn() {}
:CSVInputFileConn()
{
reset_dv_test();
}
~CSVCaffeInputFileConn() {}

void init(const APIData &ad)
Expand Down Expand Up @@ -275,20 +286,24 @@ namespace dd
/**
 * \brief returns up to 'num' test datums, resuming from the member iterator
 *        _dt_vit so successive calls stream through _dv_test; an empty vector
 *        signals the end of the test set. Call reset_dv_test() to rewind.
 * @param num maximum number of datums to return
 * @param has_mean_file unused for CSV data
 * @return batch of test datums (empty when exhausted)
 */
// NOTE(review): removed the stale 'static ... vit' lines left over from the
// diff — 'vit' is undeclared in this version, which uses the member _dt_vit.
std::vector<caffe::Datum> get_dv_test(const int &num,
				      const bool &has_mean_file)
{
  (void)has_mean_file;
  int i = 0;
  std::vector<caffe::Datum> dv;
  while(_dt_vit!=_dv_test.end()
	&& i < num)
    {
      dv.push_back((*_dt_vit));
      ++i;
      ++_dt_vit;
    }
  return dv;
}

// rewinds the test-data iterator to the beginning of _dv_test so that
// get_dv_test() re-streams the full test set from the start
void reset_dv_test()
{
  _dt_vit = _dv_test.begin();
}

/**
* \brief turns a vector of values into a Caffe Datum structure
* @param vector of values
Expand Down Expand Up @@ -322,6 +337,8 @@ namespace dd
}
return datum;
}

std::vector<caffe::Datum>::const_iterator _dt_vit;
};

}
Expand Down
47 changes: 36 additions & 11 deletions src/caffelib.cc
Original file line number Diff line number Diff line change
Expand Up @@ -234,12 +234,12 @@ namespace dd
if (!inputc._dv.empty())
{
boost::dynamic_pointer_cast<caffe::MemoryDataLayer<float>>(solver->net()->layers()[0])->AddDatumVector(inputc._dv);
if (!solver->test_nets().empty())
/*if (!solver->test_nets().empty())
{
if (!inputc._dv_test.empty())
boost::dynamic_pointer_cast<caffe::MemoryDataLayer<float>>(solver->test_nets().at(0)->layers()[0])->AddDatumVector(inputc._dv_test);
else boost::dynamic_pointer_cast<caffe::MemoryDataLayer<float>>(solver->test_nets().at(0)->layers()[0])->AddDatumVector(inputc._dv);
}
}*/
}
if (!this->_mlmodel._weights.empty())
{
Expand Down Expand Up @@ -271,7 +271,28 @@ namespace dd
if (solver->param_.test_interval() && solver->iter_ % solver->param_.test_interval() == 0
&& (solver->iter_ > 0 || solver->param_.test_initialization()))
{
solver->TestAll();
//solver->TestAll();
/*if (_net)
{
delete _net;
_net = nullptr;
}*/
if (!_net)
{
_net = new Net<float>(this->_mlmodel._def,caffe::TEST);
}
_net->ShareTrainedLayersWith(solver->net().get());
APIData meas_out;
test(_net,ad,inputc,test_batch_size,has_mean_file,meas_out);
APIData meas_obj = meas_out.getobj("measure");
std::vector<std::string> meas_str = meas_obj.list_keys();
for (auto m: meas_str)
{
double mval = meas_obj.get(m).get<double>();
LOG(INFO) << m << "=" << mval;
this->add_meas(m,mval);
this->add_meas_per_iter(m,mval);
}
}
float loss = solver->net_->ForwardBackward(bottom_vec);
if (static_cast<int>(losses.size()) < average_loss)
Expand All @@ -287,7 +308,6 @@ namespace dd
losses[idx] = loss;
}
this->add_meas("loss",smoothed_loss);

if (solver->param_.test_interval() && solver->iter_ % solver->param_.test_interval() == 0)
{
this->add_meas_per_iter("loss",loss); // to avoid filling up with possibly millions of entries...
Expand Down Expand Up @@ -319,31 +339,35 @@ namespace dd
throw MLLibBadParamException("no deploy file in " + this->_mlmodel._repo + " for initializing the net");

// test
test(ad,inputc,test_batch_size,has_mean_file,out);
test(_net,ad,inputc,test_batch_size,has_mean_file,out);

return 0;
}

template <class TInputConnectorStrategy, class TOutputConnectorStrategy, class TMLModel>
void CaffeLib<TInputConnectorStrategy,TOutputConnectorStrategy,TMLModel>::test(const APIData &ad,
void CaffeLib<TInputConnectorStrategy,TOutputConnectorStrategy,TMLModel>::test(caffe::Net<float> *net,
const APIData &ad,
TInputConnectorStrategy &inputc,
const int &test_batch_size,
const bool &has_mean_file,
APIData &out)
{
APIData ad_res;
ad_res.add("loss",this->get_meas("loss"));
APIData ad_out = ad.getobj("parameters").getobj("output");
if (ad_out.has("measure"))
{
APIData ad_res;
float mean_loss = 0.0;
int tresults = 0;
ad_res.add("nclasses",_nclasses);
inputc.reset_dv_test();
std::vector<caffe::Datum> dv;
while(!(dv=inputc.get_dv_test(test_batch_size,has_mean_file)).empty())
{
boost::dynamic_pointer_cast<caffe::MemoryDataLayer<float>>(_net->layers()[0])->set_batch_size(dv.size());
boost::dynamic_pointer_cast<caffe::MemoryDataLayer<float>>(_net->layers()[0])->AddDatumVector(dv);
boost::dynamic_pointer_cast<caffe::MemoryDataLayer<float>>(net->layers()[0])->set_batch_size(dv.size());
boost::dynamic_pointer_cast<caffe::MemoryDataLayer<float>>(net->layers()[0])->AddDatumVector(dv);
float loss = 0.0;
std::vector<Blob<float>*> lresults = _net->ForwardPrefilled(&loss);
std::vector<Blob<float>*> lresults = net->ForwardPrefilled(&loss);
int slot = lresults.size() - 1;
int scount = lresults[slot]->count();
int scperel = scount / dv.size();
Expand All @@ -362,10 +386,11 @@ namespace dd
ad_res.add(std::to_string(tresults+j),vad);
}
tresults += dv.size();
mean_loss += loss;
}
ad_res.add("batch_size",tresults);
this->_outputc.measure(ad_res,ad_out,out);
}
this->_outputc.measure(ad_res,ad_out,out);
}

template <class TInputConnectorStrategy, class TOutputConnectorStrategy, class TMLModel>
Expand Down
3 changes: 2 additions & 1 deletion src/caffelib.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,8 @@ namespace dd
* @param has_mean_file whereas testing set uses a mean file (for images)
* @param out output data object
*/
void test(const APIData &ad,
void test(caffe::Net<float> *net,
const APIData &ad,
TInputConnectorStrategy &inputc,
const int &test_batch_size,
const bool &has_mean_file,
Expand Down
10 changes: 8 additions & 2 deletions src/mllibstrategy.h
Original file line number Diff line number Diff line change
Expand Up @@ -142,13 +142,16 @@ namespace dd
*/
/**
 * \brief collects the per-iteration measure history into a single
 *        "measure_hist" object added to 'ad' (each key suffixed "_hist").
 * @param ad output data object
 */
// NOTE(review): removed the stale 'ad.add(...)' line left over from the diff —
// the new code accumulates into meas_hist, which is then added under "measure_hist".
void collect_measures_history(APIData &ad)
{
  APIData meas_hist;
  std::lock_guard<std::mutex> lock(_meas_per_iter_mutex); // guards _meas_per_iter against concurrent updates
  auto hit = _meas_per_iter.begin();
  while(hit!=_meas_per_iter.end())
    {
      meas_hist.add((*hit).first+"_hist",(*hit).second);
      ++hit;
    }
  std::vector<APIData> vad = { meas_hist };
  ad.add("measure_hist",vad);
}

/**
Expand Down Expand Up @@ -185,13 +188,16 @@ namespace dd
*/
/**
 * \brief collects the current measures into a single "measure" object
 *        added to 'ad'.
 * @param ad output data object
 */
// NOTE(review): removed the stale 'ad.add(...)' line left over from the diff —
// the new code accumulates into meas, which is then added under "measure".
void collect_measures(APIData &ad)
{
  APIData meas;
  std::lock_guard<std::mutex> lock(_meas_mutex); // guards _meas against concurrent updates
  auto hit = _meas.begin();
  while(hit!=_meas.end())
    {
      meas.add((*hit).first,(*hit).second);
      ++hit;
    }
  std::vector<APIData> vad = {meas};
  ad.add("measure",vad);
}

TInputConnectorStrategy _inputc; /**< input connector strategy for channeling data in. */
Expand Down
4 changes: 2 additions & 2 deletions src/mlservice.h
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ namespace dd
else
{
int status = this->train(ad,out);
this->collect_measures(out);
//this->collect_measures(out);
APIData ad_params_out = ad.getobj("parameters").getobj("output");
if (ad_params_out.has("measure_hist") && ad_params_out.get("measure_hist").get<bool>())
this->collect_measures_history(out);
Expand Down Expand Up @@ -187,7 +187,7 @@ namespace dd
if (st == 0)
out.add("status","finished");
else out.add("status","unknown error");
this->collect_measures(out); // XXX: beware if there was a queue, since the job has finished, there might be a new one running.
//this->collect_measures(out); // XXX: beware if there was a queue, since the job has finished, there might be a new one running.
std::chrono::time_point<std::chrono::system_clock> trun = std::chrono::system_clock::now();
out.add("time",std::chrono::duration_cast<std::chrono::seconds>(trun-(*hit).second._tstart).count());
if (ad.has("measure_hist") && ad.get("measure_hist").get<bool>())
Expand Down
69 changes: 39 additions & 30 deletions src/outputconnectorstrategy.h
Original file line number Diff line number Diff line change
Expand Up @@ -191,36 +191,45 @@ namespace dd
// measure
/**
 * \brief computes the evaluation measures requested in the output parameters
 *        ("measure" list: auc/acc/f1/mcll) over the results in ad_res, plus
 *        the loss when present, and stores them all in 'out' under a single
 *        "measure" object.
 * @param ad_res result data (per-sample predictions, optional "loss")
 * @param ad_out output parameters from the API call
 * @param out output data object receiving the "measure" object
 */
// NOTE(review): merged diff residue into the new version, and fixed a key
// mismatch: the condition tested ad_out.has("measures") while the line below
// reads ad_out.get("measure") — callers set "measure", so the requested
// measures were silently skipped.
void measure(const APIData &ad_res, const APIData &ad_out, APIData &out)
{
  APIData meas_out;
  bool loss = ad_res.has("loss");
  if (ad_out.has("measure"))
    {
      std::vector<std::string> measures = ad_out.get("measure").get<std::vector<std::string>>();
      bool bauc = (std::find(measures.begin(),measures.end(),"auc")!=measures.end());
      bool bacc = (std::find(measures.begin(),measures.end(),"acc")!=measures.end());
      bool bf1 = (std::find(measures.begin(),measures.end(),"f1")!=measures.end());
      bool bmcll = (std::find(measures.begin(),measures.end(),"mcll")!=measures.end());
      if (bauc) // XXX: applies to binary classification problems only
	{
	  double mauc = auc(ad_res);
	  meas_out.add("auc",mauc);
	}
      if (bacc)
	{
	  double macc = acc(ad_res);
	  meas_out.add("acc",macc);
	}
      if (bf1)
	{
	  double f1,precision,recall,acc;
	  f1 = mf1(ad_res,precision,recall,acc);
	  meas_out.add("f1",f1);
	  meas_out.add("precision",precision);
	  meas_out.add("recall",recall);
	  meas_out.add("accp",acc);
	  //TODO: confusion matrix ?
	}
      if (bmcll)
	{
	  double mmcll = mcll(ad_res);
	  meas_out.add("mcll",mmcll);
	}
    }
  if (loss)
    meas_out.add("loss",ad_res.get("loss").get<double>()); // 'universal', comes from algorithm
  std::vector<APIData> vad = { meas_out };
  out.add("measure",vad);
}

// measure: ACC
Expand Down
Loading

0 comments on commit ec99bff

Please sign in to comment.