-
Notifications
You must be signed in to change notification settings - Fork 94
/
Copy pathlearner.cc
37 lines (33 loc) · 939 Bytes
/
learner.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
/**
* Copyright (c) 2015 by Contributors
*/
#include "difacto/learner.h"
#include "./sgd/sgd_param.h"
#include "./sgd/sgd_learner.h"
#include "./bcd/bcd_param.h"
#include "./bcd/bcd_learner.h"
#include "./lbfgs/lbfgs_learner.h"
namespace difacto {

// Register the hyper-parameter structs of the SGD and BCD solvers with the
// dmlc parameter system. These must stay in this namespace, the one where
// SGDLearnerParam / BCDLearnerParam are declared.
DMLC_REGISTER_PARAMETER(SGDLearnerParam);
DMLC_REGISTER_PARAMETER(BCDLearnerParam);

/**
 * \brief Factory that builds a learner from its textual name.
 *
 * Recognized names are "sgd", "bcd" and "lbfgs". Each yields a heap object
 * created with `new`. Presumably the caller takes ownership; confirm against
 * the declaration in difacto/learner.h.
 *
 * \param type learner name
 * \return a freshly allocated learner, or nullptr after LOG(FATAL) for an
 *         unrecognized name. The final return is only reached if LOG(FATAL)
 *         is configured not to terminate the process.
 */
Learner* Learner::Create(const std::string& type) {
  if (type == "sgd") return new SGDLearner();
  if (type == "bcd") return new BCDLearner();
  if (type == "lbfgs") return new LBFGSLearner();

  LOG(FATAL) << "unknown learner type: " << type;
  return nullptr;
}

/**
 * \brief Sets up the job tracker that drives this learner.
 *
 * Creates a new tracker and forwards `kwargs` to it. It then installs
 * Learner::Process, bound to this object, as the tracker's executor.
 *
 * NOTE(review): any previously assigned tracker_ is overwritten. Whether that
 * leaks depends on tracker_'s declared type, which is not visible here.
 *
 * \param kwargs key/value arguments handed to Tracker::Init
 * \return whatever Tracker::Init returns. Presumably these are the arguments
 *         the tracker did not consume (TODO confirm).
 */
KWArgs Learner::Init(const KWArgs& kwargs) {
  tracker_ = Tracker::Create();
  auto leftover = tracker_->Init(kwargs);

  // Route every job the tracker dispatches to this learner's Process().
  tracker_->SetExecutor(std::bind(&Learner::Process, this,
                                  std::placeholders::_1,
                                  std::placeholders::_2));
  return leftover;
}

}  // namespace difacto