diff --git a/jubatus/core/anomaly/anomaly_factory.cpp b/jubatus/core/anomaly/anomaly_factory.cpp index 1fc1bc35..d2924fe9 100644 --- a/jubatus/core/anomaly/anomaly_factory.cpp +++ b/jubatus/core/anomaly/anomaly_factory.cpp @@ -24,6 +24,7 @@ #include "../common/exception.hpp" #include "../common/jsonconfig.hpp" #include "../nearest_neighbor/nearest_neighbor_factory.hpp" +#include "../unlearner/unlearner_config.hpp" #include "../unlearner/unlearner_factory.hpp" #include "../storage/column_table.hpp" #include "../recommender/recommender_factory.hpp" @@ -100,10 +101,11 @@ shared_ptr anomaly_factory::create_anomaly( << common::exception::error_message( "unlearner is set but unlearner_parameter is not found")); } + shared_ptr unl_conf( + unlearner::create_unlearner_config(*conf.unlearner, + *conf.unlearner_parameter)); jubatus::util::lang::shared_ptr unlearner( - unlearner::create_unlearner( - *conf.unlearner, - *conf.unlearner_parameter)); + unlearner::create_unlearner(unl_conf)); return shared_ptr( new light_lof(conf, id, nearest_neighbor_engine, unlearner)); } diff --git a/jubatus/core/classifier/arow.cpp b/jubatus/core/classifier/arow.cpp index 314d0a49..0a67342d 100644 --- a/jubatus/core/classifier/arow.cpp +++ b/jubatus/core/classifier/arow.cpp @@ -22,22 +22,18 @@ #include "classifier_util.hpp" #include "../common/exception.hpp" +#include "../storage/storage_base.hpp" using std::string; +using jubatus::core::storage_ptr; namespace jubatus { namespace core { namespace classifier { -arow::arow() - : linear_classifier(storage) { -} - -arow::arow( - const classifier_parameter& config, - storage_ptr storage) - : linear_classifier(storage), - config_(config) { +arow::arow(const classifier_parameter& config) + : linear_classifier(), + config_(config) { if (!(0.f < config.regularization_weight)) { throw JUBATUS_EXCEPTION( diff --git a/jubatus/core/classifier/arow.hpp b/jubatus/core/classifier/arow.hpp index 57f92177..900549f9 100644 --- a/jubatus/core/classifier/arow.hpp +++ b/jubatus/core/classifier/arow.hpp @@ -24,6 +24,7 @@ namespace jubatus { namespace core { namespace classifier { +struct classifier_parameter; class arow : public linear_classifier { public: diff --git a/jubatus/core/classifier/classifier_config.cpp b/jubatus/core/classifier/classifier_config.cpp new file mode 100644 index 00000000..484be90f --- /dev/null +++ b/jubatus/core/classifier/classifier_config.cpp @@ -0,0 +1,119 @@ +// Jubatus: Online machine learning framework for distributed environment +// Copyright (C) 2012 Preferred Networks and Nippon Telegraph and Telephone Corporation. +// +// This library is free software; you can redistribute it and/or +// modify it under the terms of the GNU Lesser General Public +// License version 2.1 as published by the Free Software Foundation. +// +// This library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +// Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public +// License along with this library; if not, write to the Free Software +// Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA + +#include "classifier_config.hpp" + +#include "jubatus/util/data/serialization.h" +#include "jubatus/util/data/optional.h" +#include "jubatus/util/lang/shared_ptr.h" +#include "../common/jsonconfig.hpp" +#include "../unlearner/unlearner_config.hpp" +#include "../storage/column_table.hpp" +#include "../nearest_neighbor/nearest_neighbor_base.hpp" + +using jubatus::util::lang::shared_ptr; + +namespace jubatus { +namespace core { +namespace classifier { +classifier_config::classifier_config(const std::string& method, + const common::jsonconfig::config& param) { + typedef util::lang::shared_ptr conf_ptr; + + if (method == "perceptron" || + method == "PA" || method == "passive_aggressive") { + // perceptron passive_aggressive doesn't have parameter + conf_ = conf_ptr(new classifier_parameter( + common::config::config_cast_check(param))); + } else if (name == "PA1" || name == "passive_aggressive_1") { + if (param.type() == jubatus::util::text::json::json::Null) { + throw JUBATUS_EXCEPTION( + common::config_exception() << common::exception::error_message( + "parameter block is not specified in config")); + } + + unlearning_classifier_conf_ conf + = config_cast_check(param); + unlearner = create_unlearner(conf); + res.reset(new passive_aggressive_1(conf, storage)); + } else if (name == "PA2" || name == "passive_aggressive_2") { + if (param.type() == jubatus::util::text::json::json::Null) { + throw JUBATUS_EXCEPTION( + common::config_exception() << common::exception::error_message( + "parameter block is not specified in config")); + } + unlearning_classifier_config conf + = config_cast_check(param); + unlearner = create_unlearner(conf); + res.reset(new passive_aggressive_2(conf, storage)); + } else if (name == "CW" || name == "confidence_weighted") { + if (param.type() == jubatus::util::text::json::json::Null) { + throw JUBATUS_EXCEPTION( + common::config_exception() << common::exception::error_message( + "parameter block is not specified in config")); + } + unlearning_classifier_config conf + = config_cast_check(param); + unlearner = create_unlearner(conf); + res.reset(new confidence_weighted(conf, storage)); + } else if (name == "AROW" || name == "arow") { + if (param.type() == jubatus::util::text::json::json::Null) { + throw JUBATUS_EXCEPTION( + common::config_exception() << common::exception::error_message( + "parameter block is not specified in config")); + } + unlearning_classifier_config conf + = config_cast_check(param); + unlearner = create_unlearner(conf); + res.reset(new arow(conf, storage)); + } else if (name == "NHERD" || name == "normal_herd") { + if (param.type() == jubatus::util::text::json::json::Null) { + throw JUBATUS_EXCEPTION( + common::config_exception() << common::exception::error_message( + "parameter block is not specified in config")); + + } + unlearning_classifier_config conf + = config_cast_check(param); + unlearner = create_unlearner(conf); + res.reset(new normal_herd(conf, storage)); + } else if (name == "NN" || name == "nearest_neighbor") { + if (param.type() == jubatus::util::text::json::json::Null) { + throw JUBATUS_EXCEPTION( + common::config_exception() << common::exception::error_message( + "parameter block is not specified in config")); + } + nearest_neighbor_classifier_config conf + = config_cast_check(param); + unlearner = create_unlearner(conf); + shared_ptr table(new storage::column_table); + shared_ptr + nearest_neighbor_engine(nearest_neighbor::create_nearest_neighbor( + conf.method, conf.parameter, table, "")); + res.reset( + new nearest_neighbor_classifier(nearest_neighbor_engine, + conf.nearest_neighbor_num, + conf.local_sensitivity)); + } else { + throw JUBATUS_EXCEPTION( + common::unsupported_method("classifier(" + name + ")")); + } +} +}; + +} // namespace classifier +} // namespace core +} // namespace jubatus diff --git a/jubatus/core/classifier/classifier_config.hpp b/jubatus/core/classifier/classifier_config.hpp index c3cbb80c..d48985da 100644 --- a/jubatus/core/classifier/classifier_config.hpp +++ b/jubatus/core/classifier/classifier_config.hpp @@ -19,157 +19,91 @@ #include "jubatus/util/data/serialization.h" #include "jubatus/util/data/optional.h" +#include "jubatus/util/lang/shared_ptr.h" #include "../unlearner/unlearner_config.hpp" namespace jubatus { namespace core { namespace classifier { namespace detail { -struct classifier_parameter { - classifier_parameter() - : regularization_weight(1.0f) { - } - float regularization_weight; +struct classifier_config_base { + std::string method; + virtual ~classifier_config_base() {} template void serialize(Ar& ar) { - ar & JUBA_NAMED_MEMBER("regularization_weight", regularization_weight); + ar & JUBA_MEMBER(method); } }; -struct unlearner_config { - jubatus::util::data::optional unlearner; - jubatus::util::data::optional - unlearner_parameter; +struct classifier_parameter : public classifier_config_base { + classifier_parameter() + : regularization_weight(1.0f) { + } + float regularization_weight; template void serialize(Ar& ar) { - ar & JUBA_MEMBER(unlearner) & JUBA_MEMBER(unlearner_parameter); + classifier_config_base::serialize(ar); + ar & JUBA_MEMBER(regularization_weight); } }; -struct unlearning_classifier_config - : public classifier_parameter, unlearner_config { +float get_reguralization_weight(const classifier_config_base& conf) { + const classifier_parameter* c = + dynamic_cast(&conf); + if (!c) { + throw JUBATUS_EXCEPTION( + common::config_exception() << common::exception::error_message( + "invalid classifier parameter")); + } + return c->regularization_weight; +} + +struct unlearning_classifier_config : public classifier_config_base { template void serialize(Ar& ar) { - classifier_parameter::serialize(ar); - unlearner_config::serialize(ar); + classifier_config_base::serialize(ar); + unlearner_config_->serialize(ar); + } + unlearning_classifier_config(const std::string& method, + const common::jsonconfig::config& param) + : classifier_config_base(method) { + param[""] + if (param.type() == jubatus::util::text::json::json::Null) { + throw JUBATUS_EXCEPTION( + common::config_exception() << common::exception::error_message( + "parameter block is not specified in config")); + } } + util::lang::shared_ptr unlearner_config_; }; -struct nearest_neighbor_classifier_config - : public unlearner_config { +struct nearest_neighbor_classifier_config : public classifier_config_base { std::string method; classifier_parameter parameter; int nearest_neighbor_num; float local_sensitivity; + util::lang::shared_ptr unlearner_config_; template void serialize(Ar& ar) { + classifier_config_base::serialize(ar); ar & JUBA_MEMBER(method) & JUBA_MEMBER(parameter) & JUBA_MEMBER(nearest_neighbor_num) & JUBA_MEMBER(local_sensitivity); - unlearner_config::serialize(ar); + unlearner_config_->serialize(ar); } }; } // namespace detail struct classifier_config { - std::string method_; - util::data::optional unlearner_conf_; - util::data::optional unlerner_classifier_conf_; - util::data::optional - nearest_neighbor_conf_; + util::lang::shared_ptr conf_; classifier_config(const std::string& method, - const common::jsonconfig::config& param) - : method_(method) { - if (method_ == "perceptron") { - // perceptron doesn't have parameter - if (param.type() != jubatus::util::text::json::json::Null) { - unlerner_conf_ = config_cast_check(param); - } - } else if (name == "PA" || name == "passive_aggressive") { - // passive_aggressive doesn't have parameter - if (param.type() != jubatus::util::text::json::json::Null) { - unlearner_conf = config_cast_check(param); - } - } else if (name == "PA1" || name == "passive_aggressive_1") { - if (param.type() == jubatus::util::text::json::json::Null) { - throw JUBATUS_EXCEPTION( - common::config_exception() << common::exception::error_message( - "parameter block is not specified in config")); - } - unlearning_classifier_conf_ conf - = config_cast_check(param); - unlearner = create_unlearner(conf); - res.reset(new passive_aggressive_1(conf, storage)); - } else if (name == "PA2" || name == "passive_aggressive_2") { - if (param.type() == jubatus::util::text::json::json::Null) { - throw JUBATUS_EXCEPTION( - common::config_exception() << common::exception::error_message( - "parameter block is not specified in config")); - } - unlearning_classifier_config conf - = config_cast_check(param); - unlearner = create_unlearner(conf); - res.reset(new passive_aggressive_2(conf, storage)); - } else if (name == "CW" || name == "confidence_weighted") { - if (param.type() == jubatus::util::text::json::json::Null) { - throw JUBATUS_EXCEPTION( - common::config_exception() << common::exception::error_message( - "parameter block is not specified in config")); - } - unlearning_classifier_config conf - = config_cast_check(param); - unlearner = create_unlearner(conf); - res.reset(new confidence_weighted(conf, storage)); - } else if (name == "AROW" || name == "arow") { - if (param.type() == jubatus::util::text::json::json::Null) { - throw JUBATUS_EXCEPTION( - common::config_exception() << common::exception::error_message( - "parameter block is not specified in config")); - } - unlearning_classifier_config conf - = config_cast_check(param); - unlearner = create_unlearner(conf); - res.reset(new arow(conf, storage)); - } else if (name == "NHERD" || name == "normal_herd") { - if (param.type() == jubatus::util::text::json::json::Null) { - throw JUBATUS_EXCEPTION( - common::config_exception() << common::exception::error_message( - "parameter block is not specified in config")); - } - unlearning_classifier_config conf - = config_cast_check(param); - unlearner = create_unlearner(conf); - res.reset(new normal_herd(conf, storage)); - } else if (name == "NN" || name == "nearest_neighbor") { - if (param.type() == jubatus::util::text::json::json::Null) { - throw JUBATUS_EXCEPTION( - common::config_exception() << common::exception::error_message( - "parameter block is not specified in config")); - } - nearest_neighbor_classifier_config conf - = config_cast_check(param); - unlearner = create_unlearner(conf); - shared_ptr table(new storage::column_table); - shared_ptr - nearest_neighbor_engine(nearest_neighbor::create_nearest_neighbor( - conf.method, conf.parameter, table, "")); - res.reset( - new nearest_neighbor_classifier(nearest_neighbor_engine, - conf.nearest_neighbor_num, - conf.local_sensitivity)); - } else { - throw JUBATUS_EXCEPTION( - common::unsupported_method("classifier(" + name + ")")); - } - } - - model + const common::jsonconfig::config& param); }; } // namespace classifier diff --git a/jubatus/core/classifier/classifier_factory.cpp b/jubatus/core/classifier/classifier_factory.cpp index b4f5cb54..cb279068 100644 --- a/jubatus/core/classifier/classifier_factory.cpp +++ b/jubatus/core/classifier/classifier_factory.cpp @@ -22,6 +22,7 @@ #include "../common/exception.hpp" #include "../storage/storage_base.hpp" #include "../unlearner/unlearner_factory.hpp" +#include "../unlearner/unlearner_config.hpp" #include "../nearest_neighbor/nearest_neighbor_factory.hpp" using jubatus::util::lang::shared_ptr; @@ -37,13 +38,14 @@ create_unlearner(const unlearner_config& conf) { throw JUBATUS_EXCEPTION(common::exception::runtime_error( "Unlearner is set but unlearner_parameter is not found")); } + shared_ptr uconf = + unlearner::create_unlearner_config(conf); return unlearner::create_unlearner( *conf.unlearner, *conf.unlearner_parameter); } else { return jubatus::util::lang::shared_ptr(); } } - } // namespace shared_ptr classifier_factory::create_classifier( diff --git a/jubatus/core/classifier/confidence_weighted.cpp b/jubatus/core/classifier/confidence_weighted.cpp index 60f29940..453d7ed6 100644 --- a/jubatus/core/classifier/confidence_weighted.cpp +++ b/jubatus/core/classifier/confidence_weighted.cpp @@ -22,6 +22,7 @@ #include "classifier_util.hpp" #include "../common/exception.hpp" +#include "../storage/local_storage_mixture.hpp" using std::string; @@ -31,9 +32,8 @@ namespace classifier { confidence_weighted::confidence_weighted( const classifier_parameter& config) - : linear_classifier(), + : linear_classifier() { config_(config) { - if (!(0.f < config.regularization_weight)) { throw JUBATUS_EXCEPTION( common::invalid_parameter("0.0 < regularization_weight")); diff --git a/jubatus/core/classifier/confidence_weighted.hpp b/jubatus/core/classifier/confidence_weighted.hpp index b50b9f2e..5c655470 100644 --- a/jubatus/core/classifier/confidence_weighted.hpp +++ b/jubatus/core/classifier/confidence_weighted.hpp @@ -24,12 +24,12 @@ namespace jubatus { namespace core { namespace classifier { +struct classifier_parameter; class confidence_weighted : public linear_classifier { public: confidence_weighted( - const classifier_parameter& config, - storage_ptr storage); + const classifier_parameter& config); void train(const common::sfv_t& fv, const std::string& label); std::string name() const; private: diff --git a/jubatus/core/classifier/linear_classifier.cpp b/jubatus/core/classifier/linear_classifier.cpp index 433d0088..cb251ed8 100644 --- a/jubatus/core/classifier/linear_classifier.cpp +++ b/jubatus/core/classifier/linear_classifier.cpp @@ -28,6 +28,8 @@ #include "../common/exception.hpp" #include "classifier_util.hpp" +#include "../storage/storage_base.hpp" +#include "../storage/local_storage_mixture.hpp" using std::string; using std::vector; @@ -39,7 +41,8 @@ namespace core { namespace classifier { linear_classifier::linear_classifier() - : storage_(storage), mixable_storage_(storage_) { + : storage_(new storage::local_storage_mixture()), + mixable_storage_(storage_) { } linear_classifier::~linear_classifier() { @@ -210,9 +213,11 @@ float linear_classifier::squared_norm(const common::sfv_t& fv) { void linear_classifier::pack(framework::packer& pk) const { storage_->pack(pk); } + void linear_classifier::unpack(msgpack::object o) { storage_->unpack(o); } + void linear_classifier::export_model(framework::packer& pk) const { // TODO } diff --git a/jubatus/core/classifier/linear_classifier.hpp b/jubatus/core/classifier/linear_classifier.hpp index a5a1e7f0..bb31e5ec 100644 --- a/jubatus/core/classifier/linear_classifier.hpp +++ b/jubatus/core/classifier/linear_classifier.hpp @@ -40,6 +40,7 @@ class linear_classifier : public classifier_base { virtual ~linear_classifier(); virtual void train(const common::sfv_t& fv, const std::string& label) = 0; + linear_classifier(); void set_label_unlearner( jubatus::util::lang::shared_ptr label_unlearner); diff --git a/jubatus/core/classifier/normal_herd.cpp b/jubatus/core/classifier/normal_herd.cpp index 09ed98b9..b599a133 100644 --- a/jubatus/core/classifier/normal_herd.cpp +++ b/jubatus/core/classifier/normal_herd.cpp @@ -29,16 +29,9 @@ namespace jubatus { namespace core { namespace classifier { -normal_herd::normal_herd() - : linear_classifier(storage) { - config_.regularization_weight = 0.1f; -} - normal_herd::normal_herd( - const classifier_parameter& config, - storage_ptr storage) - : linear_classifier(storage), - config_(config) { + const classifier_parameter& config) + : config_(config) { if (!(0.f < config.regularization_weight)) { throw JUBATUS_EXCEPTION( diff --git a/jubatus/core/classifier/normal_herd.hpp b/jubatus/core/classifier/normal_herd.hpp index 83e1bf3f..fa942f12 100644 --- a/jubatus/core/classifier/normal_herd.hpp +++ b/jubatus/core/classifier/normal_herd.hpp @@ -24,11 +24,11 @@ namespace jubatus { namespace core { namespace classifier { +struct classifier_parameter; class normal_herd : public linear_classifier { public: - normal_herd( - const classifier_parameter& config); + normal_herd(const classifier_parameter& config); void train(const common::sfv_t& fv, const std::string& label); std::string name() const; private: diff --git a/jubatus/core/classifier/passive_aggressive.cpp b/jubatus/core/classifier/passive_aggressive.cpp index 3b4a0f1f..69239eeb 100644 --- a/jubatus/core/classifier/passive_aggressive.cpp +++ b/jubatus/core/classifier/passive_aggressive.cpp @@ -17,6 +17,7 @@ #include "passive_aggressive.hpp" #include +#include "../storage/local_storage_mixture.hpp" using std::string; @@ -25,7 +26,7 @@ namespace core { namespace classifier { passive_aggressive::passive_aggressive() - : linear_classifier(storage) { + : linear_classifier(storage::storage_ptr(new local_storage_mixture())) { } void passive_aggressive::train(const common::sfv_t& sfv, const string& label) { diff --git a/jubatus/core/classifier/passive_aggressive_1.cpp b/jubatus/core/classifier/passive_aggressive_1.cpp index 05356ced..10b32d3c 100644 --- a/jubatus/core/classifier/passive_aggressive_1.cpp +++ b/jubatus/core/classifier/passive_aggressive_1.cpp @@ -20,6 +20,7 @@ #include #include "../common/exception.hpp" +#include "../storage/local_storage_mixture.hpp" using std::string; using std::min; @@ -28,16 +29,11 @@ namespace jubatus { namespace core { namespace classifier { -passive_aggressive_1::passive_aggressive_1() - : linear_classifier(storage) { -} - passive_aggressive_1::passive_aggressive_1( const classifier_parameter& config, storage_ptr storage) - : linear_classifier(storage), + : linear_classifier(), config_(config) { - if (!(0.f < config.regularization_weight)) { throw JUBATUS_EXCEPTION( common::invalid_parameter("0.0 < regularization_weight")); diff --git a/jubatus/core/classifier/passive_aggressive_1.hpp b/jubatus/core/classifier/passive_aggressive_1.hpp index e3103b93..1b1929b7 100644 --- a/jubatus/core/classifier/passive_aggressive_1.hpp +++ b/jubatus/core/classifier/passive_aggressive_1.hpp @@ -18,16 +18,15 @@ #define JUBATUS_CORE_CLASSIFIER_PASSIVE_AGGRESSIVE_1_HPP_ #include - #include "linear_classifier.hpp" namespace jubatus { namespace core { namespace classifier { +struct classifier_parameter; class passive_aggressive_1 : public linear_classifier { public: - explicit passive_aggressive_1(); passive_aggressive_1( const classifier_parameter& config, storage_ptr storage); diff --git a/jubatus/core/classifier/passive_aggressive_2.cpp b/jubatus/core/classifier/passive_aggressive_2.cpp index 01d135b8..0e8fdc01 100644 --- a/jubatus/core/classifier/passive_aggressive_2.cpp +++ b/jubatus/core/classifier/passive_aggressive_2.cpp @@ -20,6 +20,7 @@ #include #include "../common/exception.hpp" +#include "../storage/local_storage_mixture.hpp" using std::string; @@ -27,17 +28,11 @@ namespace jubatus { namespace core { namespace classifier { -passive_aggressive_2::passive_aggressive_2() - : linear_classifier(storage) { -} - passive_aggressive_2::passive_aggressive_2( - const classifier_parameter& config, - storage_ptr storage) - : linear_classifier(storage), - config_(config) { + const classifier_config_base& config) + : regularization_weight_(detail::get_reguralization_weight(config)) { - if (!(0.f < config.regularization_weight)) { + if (!(0.f < regularization_weight_)) { throw JUBATUS_EXCEPTION( common::invalid_parameter("0.0 < regularization_weight")); } @@ -62,7 +57,7 @@ void passive_aggressive_2::train(const common::sfv_t& sfv, } update_weight( sfv, - loss / (2 * sfv_norm + 1 / (2 * config_.regularization_weight)), + loss / (2 * sfv_norm + 1 / (2 * regularization_weight_)), label, incorrect_label); touch(label); diff --git a/jubatus/core/classifier/passive_aggressive_2.hpp b/jubatus/core/classifier/passive_aggressive_2.hpp index d7dc7dbb..1b4f1821 100644 --- a/jubatus/core/classifier/passive_aggressive_2.hpp +++ b/jubatus/core/classifier/passive_aggressive_2.hpp @@ -18,8 +18,8 @@ #define JUBATUS_CORE_CLASSIFIER_PASSIVE_AGGRESSIVE_2_HPP_ #include - #include "linear_classifier.hpp" +#include "classifier_config.hpp" namespace jubatus { namespace core { @@ -27,15 +27,13 @@ namespace classifier { class passive_aggressive_2 : public linear_classifier { public: - explicit passive_aggressive_2(); passive_aggressive_2( - const classifier_parameter& config, - storage_ptr storage); + const detail::classifier_config_base& config); void train(const common::sfv_t& sfv, const std::string& label); std::string name() const; private: - classifier_parameter config_; + float regularization_weight_; }; } // namespace classifier diff --git a/jubatus/core/classifier/perceptron.cpp b/jubatus/core/classifier/perceptron.cpp index 7f905d16..ed329f38 100644 --- a/jubatus/core/classifier/perceptron.cpp +++ b/jubatus/core/classifier/perceptron.cpp @@ -17,6 +17,7 @@ #include "perceptron.hpp" #include +#include "../storage/local_storage_mixture.hpp" using std::string; @@ -25,7 +26,7 @@ namespace core { namespace classifier { perceptron::perceptron() - : linear_classifier(storage) { + : linear_classifier() { } void perceptron::train(const common::sfv_t& sfv, const std::string& label) { diff --git a/jubatus/core/classifier/wscript b/jubatus/core/classifier/wscript index aea91d09..c7479293 100644 --- a/jubatus/core/classifier/wscript +++ b/jubatus/core/classifier/wscript @@ -15,6 +15,7 @@ def build(bld): 'normal_herd.cpp', 'nearest_neighbor_classifier.cpp', 'classifier_factory.cpp', + 'classifier_config.cpp', ] headers = [ 'classifier_base.hpp', diff --git a/jubatus/core/driver/classifier.hpp b/jubatus/core/driver/classifier.hpp index 400d955b..20a04f34 100644 --- a/jubatus/core/driver/classifier.hpp +++ b/jubatus/core/driver/classifier.hpp @@ -23,8 +23,10 @@ #include "jubatus/util/lang/shared_ptr.h" #include "../common/byte_buffer.hpp" #include "../classifier/classifier_type.hpp" +#include "../classifier/classifier_config.hpp" #include "../framework/mixable.hpp" #include "../fv_converter/mixable_weight_manager.hpp" +#include "../fv_converter/converter_config.hpp" #include "driver.hpp" namespace jubatus { @@ -40,7 +42,7 @@ namespace driver { struct classifier_driver_config { std::string method; - jubatus::util::data::optional parameter; + jubatus::util::data::optional parameter; core::fv_converter::converter_config converter; template @@ -49,7 +51,6 @@ struct classifier_driver_config { } }; - class classifier : public driver_base { public: typedef core::classifier::classifier_base classifier_base; diff --git a/jubatus/core/framework/linear_function_mixer.hpp b/jubatus/core/framework/linear_function_mixer.hpp index 9ad84d51..e9347f5b 100644 --- a/jubatus/core/framework/linear_function_mixer.hpp +++ b/jubatus/core/framework/linear_function_mixer.hpp @@ -51,9 +51,6 @@ class linear_function_mixer : public linear_mixable { model_ptr get_model() const { return model_; } - void set_model(model_ptr model) { - model_ = model; - } void mix(const diffv& lhs, diffv& mixed) const; void get_diff(diffv&) const; diff --git a/jubatus/core/recommender/recommender_factory.cpp b/jubatus/core/recommender/recommender_factory.cpp index 63e1ea97..04feb571 100644 --- a/jubatus/core/recommender/recommender_factory.cpp +++ b/jubatus/core/recommender/recommender_factory.cpp @@ -82,9 +82,11 @@ shared_ptr recommender_factory::create_recommender( common::config_exception() << common::exception::error_message( "unlearner is set but unlearner_parameter is not found")); } - shared_ptr unl(unlearner::create_unlearner( - *conf.unlearner, common::jsonconfig::config( - *conf.unlearner_parameter))); + shared_ptr uconf = + unlearner::create_unlearner_config(*conf.unlearner, + *conf.unlearner_parameter); + shared_ptr unl( + unlearner::create_unlearner(uconf)); return shared_ptr( new nearest_neighbor_recommender(nearest_neighbor_engine, unl)); } diff --git a/jubatus/core/unlearner/lru_unlearner.cpp b/jubatus/core/unlearner/lru_unlearner.cpp index 39488abc..c8948f08 100644 --- a/jubatus/core/unlearner/lru_unlearner.cpp +++ b/jubatus/core/unlearner/lru_unlearner.cpp @@ -32,22 +32,23 @@ namespace jubatus { namespace core { namespace unlearner { namespace { -const lru_unlearner_config& as_lru_unlearner_config(const unlearner_config_base& conf) { - return dynamic_cast(conf); +const lru_unlearner::lru_unlearner_config& +as_lru_unlearner_config(const unlearner_config_base& conf) { + return dynamic_cast(conf); } } // namespace lru_unlearner::lru_unlearner(const unlearner_config_base& conf) : max_size_(as_lru_unlearner_config(conf).max_size) { const lru_unlearner_config& lconfig = as_lru_unlearner_config(conf); - if (lconf.max_size <= 0) { + if (lconfig.max_size <= 0) { throw JUBATUS_EXCEPTION( common::config_exception() << common::exception::error_message( "max_size must be a positive integer")); } - entry_map_.reserve(max_size_); + entry_map_.reserve(lconfig.max_size_); - if (lconf.sticky_pattern) { + if (lconfig.sticky_pattern) { key_matcher_factory f; sticky_matcher_ = f.create_matcher(*lconf.sticky_pattern); } diff --git a/jubatus/core/unlearner/lru_unlearner.hpp b/jubatus/core/unlearner/lru_unlearner.hpp index c1054dbf..3844e7f2 100644 --- a/jubatus/core/unlearner/lru_unlearner.hpp +++ b/jubatus/core/unlearner/lru_unlearner.hpp @@ -34,13 +34,12 @@ namespace jubatus { namespace core { namespace fv_converter { class key_matcher; -} +} // namespace fv_converter namespace unlearner { // Unlearner based on Least Recently Used algorithm. class lru_unlearner : public unlearner_base { public: - struct lru_unlearner_config : public unlearner_config_base { int32_t max_size; jubatus::util::data::optional sticky_pattern; diff --git a/jubatus/core/unlearner/random_unlearner.cpp b/jubatus/core/unlearner/random_unlearner.cpp index e30960e7..cd8ad154 100644 --- a/jubatus/core/unlearner/random_unlearner.cpp +++ b/jubatus/core/unlearner/random_unlearner.cpp @@ -24,14 +24,15 @@ namespace jubatus { namespace core { namespace unlearner { -const random_unlearner_config& as_random_unlearner_config(const unlearner_config_base& orig) { - return dynamic_cast(orig); +const random_unlearner::random_unlearner_config& +as_random_unlearner_config(const unlearner_config_base& orig) { + return dynamic_cast(orig); } random_unlearner::random_unlearner(const unlearner_config_base& conf) : max_size_(as_random_unlearner_config(conf).max_size) { - const random_unlearner& rconf = as_random_unlearner_config(conf); - if (rconf->max_size <= 0) { + const random_unlearner_config& rconf = as_random_unlearner_config(conf); + if (rconf.max_size <= 0) { throw JUBATUS_EXCEPTION( common::config_exception() << common::exception::error_message( "max_size must be a positive integer")); diff --git a/jubatus/core/unlearner/unlearner_config.hpp b/jubatus/core/unlearner/unlearner_config.hpp index 4f6be342..51516472 100644 --- a/jubatus/core/unlearner/unlearner_config.hpp +++ b/jubatus/core/unlearner/unlearner_config.hpp @@ -20,7 +20,7 @@ #include #include "../common/jsonconfig.hpp" #include "jubatus/util/lang/shared_ptr.h" -#include "jubatus/util/data/optional.h" +#include "jubatus/util/data/serialization.h" namespace jubatus { namespace core { @@ -29,8 +29,17 @@ namespace unlearner { struct unlearner_config_base { std::string name; virtual ~unlearner_config_base() = 0; + + template + void serialize(Ar& ar) { + ar & JUBA_NAMED_MEMBER("name", name); + } }; +util::lang::shared_ptr +create_unlearner_config(const std::string, + const common::jsonconfig::config& param); + } // namespace unlearner } // namespace core } // namespace jubatus diff --git a/jubatus/core/unlearner/unlearner_factory.cpp b/jubatus/core/unlearner/unlearner_factory.cpp index 0d969db8..6354fd6d 100644 --- a/jubatus/core/unlearner/unlearner_factory.cpp +++ b/jubatus/core/unlearner/unlearner_factory.cpp @@ -29,19 +29,30 @@ namespace core { namespace unlearner { shared_ptr create_unlearner( - const std::string& name, - const common::jsonconfig::config& config) { - if (name == "lru") { - return shared_ptr( - new lru_unlearner(common::jsonconfig::config_cast_check< - lru_unlearner::config>(config))); - } else if (name == "random") { - return shared_ptr( - new random_unlearner(common::jsonconfig::config_cast_check< - random_unlearner::config>(config))); + const shared_ptr conf) { + if (conf->name == "lru") { + lru_unlearner::lru_unlearner_config* lconf = + dynamic_cast(conf.get()); + if (lconf) { + return shared_ptr( + new lru_unlearner(*lconf)); + } else { + throw JUBATUS_EXCEPTION(common::unsupported_method( + "invaild lru unlearner config")); + } + } else if (conf->name == "random") { + random_unlearner::random_unlearner_config* rconf = + dynamic_cast(conf.get()); + if (rconf) { + return shared_ptr( + new random_unlearner(*rconf)); + } else { + throw JUBATUS_EXCEPTION(common::unsupported_method( + "invaild random unlearner config")); + } } else { throw JUBATUS_EXCEPTION(common::unsupported_method( - "unlearner(" + name + ')')); + "unlearner(" + conf->name + ')')); } }