Skip to content

Commit

Permalink
replaced loss with train_loss in Caffe output
Browse files Browse the repository at this point in the history
  • Loading branch information
beniz committed Mar 9, 2015
1 parent 8e43dfa commit bb3e145
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 4 deletions.
9 changes: 5 additions & 4 deletions src/caffelib.cc
Original file line number Diff line number Diff line change
Expand Up @@ -307,11 +307,11 @@ namespace dd
smoothed_loss += (loss - losses[idx]) / average_loss;
losses[idx] = loss;
}
this->add_meas("loss",smoothed_loss);
this->add_meas("train_loss",smoothed_loss);
if (solver->param_.test_interval() && solver->iter_ % solver->param_.test_interval() == 0)
{
this->add_meas_per_iter("loss",loss); // to avoid filling up with possibly millions of entries...
LOG(INFO) << "loss=" << this->get_meas("loss");
this->add_meas_per_iter("train_loss",loss); // to avoid filling up with possibly millions of entries...
//LOG(INFO) << "loss=" << this->get_meas("loss");
}

solver->ComputeUpdateValue();
Expand Down Expand Up @@ -353,7 +353,7 @@ namespace dd
APIData &out)
{
APIData ad_res;
ad_res.add("loss",this->get_meas("loss"));
ad_res.add("train_loss",this->get_meas("train_loss"));
APIData ad_out = ad.getobj("parameters").getobj("output");
if (ad_out.has("measure"))
{
Expand Down Expand Up @@ -389,6 +389,7 @@ namespace dd
mean_loss += loss;
}
ad_res.add("batch_size",tresults);
//ad_res.add("loss",mean_loss / static_cast<double>(tresults)); // XXX: Caffe ForwardPrefilled call above return loss = 0.0
}
this->_outputc.measure(ad_res,ad_out,out);
}
Expand Down
3 changes: 3 additions & 0 deletions src/outputconnectorstrategy.h
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,7 @@ namespace dd
void measure(const APIData &ad_res, const APIData &ad_out, APIData &out)
{
APIData meas_out;
bool tloss = ad_res.has("train_loss");
bool loss = ad_res.has("loss");
if (ad_out.has("measure"))
{
Expand Down Expand Up @@ -228,6 +229,8 @@ namespace dd
}
if (loss)
meas_out.add("loss",ad_res.get("loss").get<double>()); // 'universal', comes from algorithm
if (tloss)
meas_out.add("train_loss",ad_res.get("train_loss").get<double>());
std::vector<APIData> vad = { meas_out };
out.add("measure",vad);
}
Expand Down

0 comments on commit bb3e145

Please sign in to comment.