Skip to content

Commit

Permalink
tmp
Browse files Browse the repository at this point in the history
  • Loading branch information
KUMAZAKI Hiroki committed Mar 20, 2015
1 parent 3c1a09c commit bc14cd1
Show file tree
Hide file tree
Showing 28 changed files with 275 additions and 200 deletions.
4 changes: 2 additions & 2 deletions jubatus/core/classifier/arow.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,12 @@ namespace jubatus {
namespace core {
namespace classifier {

arow::arow(storage_ptr storage)
arow::arow()
: linear_classifier(storage) {
}

arow::arow(
const classifier_config& config,
const classifier_parameter& config,
storage_ptr storage)
: linear_classifier(storage),
config_(config) {
Expand Down
5 changes: 2 additions & 3 deletions jubatus/core/classifier/arow.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,7 @@ namespace classifier {

class arow : public linear_classifier {
public:
explicit arow(storage_ptr storage);
arow(const classifier_config& config, storage_ptr storage);
arow(const classifier_parameter& config);
void train(const common::sfv_t& fv, const std::string& label);
std::string name() const;
private:
Expand All @@ -38,7 +37,7 @@ class arow : public linear_classifier {
float beta,
const std::string& pos_label,
const std::string& neg_label);
classifier_config config_;
classifier_parameter config_;
};

} // namespace classifier
Expand Down
142 changes: 139 additions & 3 deletions jubatus/core/classifier/classifier_config.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,15 @@
#define JUBATUS_CORE_CLASSIFIER_CLASSIFIER_CONFIG_HPP_

#include "jubatus/util/data/serialization.h"
#include "jubatus/util/data/optional.h"
#include "../unlearner/unlearner_config.hpp"

namespace jubatus {
namespace core {
namespace classifier {

struct classifier_config {
classifier_config()
namespace detail {
struct classifier_parameter {
classifier_parameter()
: regularization_weight(1.0f) {
}

Expand All @@ -36,6 +38,140 @@ struct classifier_config {
}
};

struct unlearner_config {
jubatus::util::data::optional<std::string> unlearner;
jubatus::util::data::optional<unlearner::unlearner_config_base>
unlearner_parameter;

template<typename Ar>
void serialize(Ar& ar) {
ar & JUBA_MEMBER(unlearner) & JUBA_MEMBER(unlearner_parameter);
}
};

struct unlearning_classifier_config
: public classifier_parameter, unlearner_config {
template<typename Ar>
void serialize(Ar& ar) {
classifier_parameter::serialize(ar);
unlearner_config::serialize(ar);
}
};

struct nearest_neighbor_classifier_config
: public unlearner_config {
std::string method;
classifier_parameter parameter;
int nearest_neighbor_num;
float local_sensitivity;

template<typename Ar>
void serialize(Ar& ar) {
ar & JUBA_MEMBER(method)
& JUBA_MEMBER(parameter)
& JUBA_MEMBER(nearest_neighbor_num)
& JUBA_MEMBER(local_sensitivity);
unlearner_config::serialize(ar);
}
};

} // namespace detail

struct classifier_config {
std::string method_;
util::data::optional<detail::unlearner_config> unlearner_conf_;
util::data::optional<unlearning_classifier_config> unlerner_classifier_conf_;
util::data::optional<nearest_neighbor_classifier_config>
nearest_neighbor_conf_;
classifier_config(const std::string& method,
const common::jsonconfig::config& param)
: method_(method) {
if (method_ == "perceptron") {
// perceptron doesn't have parameter
if (param.type() != jubatus::util::text::json::json::Null) {
unlerner_conf_ = config_cast_check<classifier_config>(param);
}
} else if (name == "PA" || name == "passive_aggressive") {
// passive_aggressive doesn't have parameter
if (param.type() != jubatus::util::text::json::json::Null) {
unlearner_conf = config_cast_check<classifier_config>(param);
}
} else if (name == "PA1" || name == "passive_aggressive_1") {
if (param.type() == jubatus::util::text::json::json::Null) {
throw JUBATUS_EXCEPTION(
common::config_exception() << common::exception::error_message(
"parameter block is not specified in config"));
}
unlearning_classifier_conf_ conf
= config_cast_check<unlearning_classifier_config>(param);
unlearner = create_unlearner(conf);
res.reset(new passive_aggressive_1(conf, storage));
} else if (name == "PA2" || name == "passive_aggressive_2") {
if (param.type() == jubatus::util::text::json::json::Null) {
throw JUBATUS_EXCEPTION(
common::config_exception() << common::exception::error_message(
"parameter block is not specified in config"));
}
unlearning_classifier_config conf
= config_cast_check<unlearning_classifier_config>(param);
unlearner = create_unlearner(conf);
res.reset(new passive_aggressive_2(conf, storage));
} else if (name == "CW" || name == "confidence_weighted") {
if (param.type() == jubatus::util::text::json::json::Null) {
throw JUBATUS_EXCEPTION(
common::config_exception() << common::exception::error_message(
"parameter block is not specified in config"));
}
unlearning_classifier_config conf
= config_cast_check<unlearning_classifier_config>(param);
unlearner = create_unlearner(conf);
res.reset(new confidence_weighted(conf, storage));
} else if (name == "AROW" || name == "arow") {
if (param.type() == jubatus::util::text::json::json::Null) {
throw JUBATUS_EXCEPTION(
common::config_exception() << common::exception::error_message(
"parameter block is not specified in config"));
}
unlearning_classifier_config conf
= config_cast_check<unlearning_classifier_config>(param);
unlearner = create_unlearner(conf);
res.reset(new arow(conf, storage));
} else if (name == "NHERD" || name == "normal_herd") {
if (param.type() == jubatus::util::text::json::json::Null) {
throw JUBATUS_EXCEPTION(
common::config_exception() << common::exception::error_message(
"parameter block is not specified in config"));
}
unlearning_classifier_config conf
= config_cast_check<unlearning_classifier_config>(param);
unlearner = create_unlearner(conf);
res.reset(new normal_herd(conf, storage));
} else if (name == "NN" || name == "nearest_neighbor") {
if (param.type() == jubatus::util::text::json::json::Null) {
throw JUBATUS_EXCEPTION(
common::config_exception() << common::exception::error_message(
"parameter block is not specified in config"));
}
nearest_neighbor_classifier_config conf
= config_cast_check<nearest_neighbor_classifier_config>(param);
unlearner = create_unlearner(conf);
shared_ptr<storage::column_table> table(new storage::column_table);
shared_ptr<nearest_neighbor::nearest_neighbor_base>
nearest_neighbor_engine(nearest_neighbor::create_nearest_neighbor(
conf.method, conf.parameter, table, ""));
res.reset(
new nearest_neighbor_classifier(nearest_neighbor_engine,
conf.nearest_neighbor_num,
conf.local_sensitivity));
} else {
throw JUBATUS_EXCEPTION(
common::unsupported_method("classifier(" + name + ")"));
}
}

model
};

} // namespace classifier
} // namespace core
} // namespace jubatus
Expand Down
140 changes: 23 additions & 117 deletions jubatus/core/classifier/classifier_factory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,56 +20,16 @@

#include "classifier.hpp"
#include "../common/exception.hpp"
#include "../common/jsonconfig.hpp"
#include "../storage/storage_base.hpp"
#include "../unlearner/unlearner_factory.hpp"
#include "../nearest_neighbor/nearest_neighbor_factory.hpp"

using jubatus::core::common::jsonconfig::config;
using jubatus::core::common::jsonconfig::config_cast_check;
using jubatus::util::lang::shared_ptr;

namespace jubatus {
namespace core {
namespace classifier {
namespace {

struct unlearner_config {
jubatus::util::data::optional<std::string> unlearner;
jubatus::util::data::optional<config> unlearner_parameter;

template<typename Ar>
void serialize(Ar& ar) {
ar & JUBA_MEMBER(unlearner) & JUBA_MEMBER(unlearner_parameter);
}
};

struct unlearning_classifier_config
: public classifier_config, unlearner_config {
template<typename Ar>
void serialize(Ar& ar) {
classifier_config::serialize(ar);
unlearner_config::serialize(ar);
}
};

struct nearest_neighbor_classifier_config
: public unlearner_config {
std::string method;
config parameter;
int nearest_neighbor_num;
float local_sensitivity;

template<typename Ar>
void serialize(Ar& ar) {
ar & JUBA_MEMBER(method)
& JUBA_MEMBER(parameter)
& JUBA_MEMBER(nearest_neighbor_num)
& JUBA_MEMBER(local_sensitivity);
unlearner_config::serialize(ar);
}
};

jubatus::util::lang::shared_ptr<unlearner::unlearner_base>
create_unlearner(const unlearner_config& conf) {
if (conf.unlearner) {
Expand All @@ -87,97 +47,43 @@ create_unlearner(const unlearner_config& conf) {
} // namespace

shared_ptr<classifier_base> classifier_factory::create_classifier(
const std::string& name,
const common::jsonconfig::config& param,
jubatus::util::lang::shared_ptr<storage::storage_base> storage) {
const classifier_config& conf) {
jubatus::util::lang::shared_ptr<unlearner::unlearner_base> unlearner;
shared_ptr<classifier_base> res;
if (name == "perceptron") {
if (conf.unlearner_conf_) {
unlearner = create_unlearner(*conf.unlearner_conf_);
} else if (conf.unlearner_classifier_conf_) {
unlearner = create_unlearner(*conf.unlearner_classifier_conf_);
}

if (conf.method_ == "perceptron") {
// perceptron doesn't have parameter
if (param.type() != jubatus::util::text::json::json::Null) {
unlearner_config conf = config_cast_check<unlearner_config>(param);
unlearner = create_unlearner(conf);
}
res.reset(new perceptron(storage));
} else if (name == "PA" || name == "passive_aggressive") {
res.reset(new perceptron());
} else if (conf.method_ == "PA" || conf.method_ == "passive_aggressive") {
// passive_aggressive doesn't have parameter
if (param.type() != jubatus::util::text::json::json::Null) {
unlearner_config conf = config_cast_check<unlearner_config>(param);
unlearner = create_unlearner(conf);
}
res.reset(new passive_aggressive(storage));
} else if (name == "PA1" || name == "passive_aggressive_1") {
if (param.type() == jubatus::util::text::json::json::Null) {
throw JUBATUS_EXCEPTION(
common::config_exception() << common::exception::error_message(
"parameter block is not specified in config"));
}
unlearning_classifier_config conf
= config_cast_check<unlearning_classifier_config>(param);
unlearner = create_unlearner(conf);
res.reset(new passive_aggressive_1(conf, storage));
} else if (name == "PA2" || name == "passive_aggressive_2") {
if (param.type() == jubatus::util::text::json::json::Null) {
throw JUBATUS_EXCEPTION(
common::config_exception() << common::exception::error_message(
"parameter block is not specified in config"));
}
unlearning_classifier_config conf
= config_cast_check<unlearning_classifier_config>(param);
unlearner = create_unlearner(conf);
res.reset(new passive_aggressive_2(conf, storage));
} else if (name == "CW" || name == "confidence_weighted") {
if (param.type() == jubatus::util::text::json::json::Null) {
throw JUBATUS_EXCEPTION(
common::config_exception() << common::exception::error_message(
"parameter block is not specified in config"));
}
unlearning_classifier_config conf
= config_cast_check<unlearning_classifier_config>(param);
unlearner = create_unlearner(conf);
res.reset(new confidence_weighted(conf, storage));
} else if (name == "AROW" || name == "arow") {
if (param.type() == jubatus::util::text::json::json::Null) {
throw JUBATUS_EXCEPTION(
common::config_exception() << common::exception::error_message(
"parameter block is not specified in config"));
}
unlearning_classifier_config conf
= config_cast_check<unlearning_classifier_config>(param);
unlearner = create_unlearner(conf);
res.reset(new arow(conf, storage));
} else if (name == "NHERD" || name == "normal_herd") {
if (param.type() == jubatus::util::text::json::json::Null) {
throw JUBATUS_EXCEPTION(
common::config_exception() << common::exception::error_message(
"parameter block is not specified in config"));
}
unlearning_classifier_config conf
= config_cast_check<unlearning_classifier_config>(param);
unlearner = create_unlearner(conf);
res.reset(new normal_herd(conf, storage));
} else if (name == "NN" || name == "nearest_neighbor") {
if (param.type() == jubatus::util::text::json::json::Null) {
throw JUBATUS_EXCEPTION(
common::config_exception() << common::exception::error_message(
"parameter block is not specified in config"));
}
nearest_neighbor_classifier_config conf
= config_cast_check<nearest_neighbor_classifier_config>(param);
unlearner = create_unlearner(conf);
res.reset(new passive_aggressive());
} else if (conf.method_ == "PA1" || conf.method_ == "passive_aggressive_1") {
res.reset(new passive_aggressive_1(conf.unlearner_classifier_conf_));
} else if (conf.method_ == "PA2" || conf.method_ == "passive_aggressive_2") {
res.reset(new passive_aggressive_2(conf.unlearner_classifier_conf_));
} else if (conf.method_ == "CW" || conf.method_ == "confidence_weighted") {
res.reset(new confidence_weighted(conf.unlearner_classifier_conf_));
} else if (conf.method_ == "AROW" || conf.method_ == "arow") {
res.reset(new arow(conf.unlearner_classifier_conf_));
} else if (conf.method_ == "NHERD" || conf.method_ == "normal_herd") {
res.reset(new normal_herd(conf.unlearner_classifier_conf_));
} else if (conf.method_ == "NN" || conf.method_ == "nearest_neighbor") {
shared_ptr<storage::column_table> table(new storage::column_table);
shared_ptr<nearest_neighbor::nearest_neighbor_base>
nearest_neighbor_engine(nearest_neighbor::create_nearest_neighbor(
conf.method, conf.parameter, table, ""));
res.reset(
new nearest_neighbor_classifier(nearest_neighbor_engine,
conf.nearest_neighbor_num,
conf.local_sensitivity));
conf.nearest_neighbor_conf_));
} else {
throw JUBATUS_EXCEPTION(
common::unsupported_method("classifier(" + name + ")"));
}

if (unlearner) {
res->set_label_unlearner(unlearner);
}
Expand Down
2 changes: 1 addition & 1 deletion jubatus/core/classifier/classifier_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ INSTANTIATE_TYPED_TEST_CASE_P(cl, classifier_test, classifier_types);

TEST(classifier_config_test, regularization_weight) {
storage_ptr s(new local_storage);
classifier_config c;
classifier_parameter c;

c.regularization_weight = std::numeric_limits<float>::quiet_NaN();
ASSERT_THROW(passive_aggressive_1 p1(c, s), common::invalid_parameter);
Expand Down
9 changes: 2 additions & 7 deletions jubatus/core/classifier/confidence_weighted.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,9 @@ namespace jubatus {
namespace core {
namespace classifier {

confidence_weighted::confidence_weighted(storage_ptr storage)
: linear_classifier(storage) {
}

confidence_weighted::confidence_weighted(
const classifier_config& config,
storage_ptr storage)
: linear_classifier(storage),
const classifier_parameter& config)
: linear_classifier(),
config_(config) {

if (!(0.f < config.regularization_weight)) {
Expand Down
Loading

0 comments on commit bc14cd1

Please sign in to comment.