diff --git a/jubatus/core/classifier/classifier_base.hpp b/jubatus/core/classifier/classifier_base.hpp index 826a5bb4..88936cd8 100644 --- a/jubatus/core/classifier/classifier_base.hpp +++ b/jubatus/core/classifier/classifier_base.hpp @@ -59,6 +59,8 @@ class classifier_base { virtual void pack(framework::packer& pk) const = 0; virtual void unpack(msgpack::object o) = 0; + virtual void export_model(framework::packer& pk) const = 0; + virtual void import_model(msgpack::object o) = 0; virtual void clear() = 0; virtual framework::mixable* get_mixable() = 0; diff --git a/jubatus/core/classifier/linear_classifier.cpp b/jubatus/core/classifier/linear_classifier.cpp index ea4723f0..ed3b5e28 100644 --- a/jubatus/core/classifier/linear_classifier.cpp +++ b/jubatus/core/classifier/linear_classifier.cpp @@ -207,6 +207,13 @@ void linear_classifier::pack(framework::packer& pk) const { void linear_classifier::unpack(msgpack::object o) { storage_->unpack(o); } +void linear_classifier::export_model(framework::packer& pk) const { + // TODO +} + +void linear_classifier::import_model(msgpack::object o) { + // TODO +} framework::mixable* linear_classifier::get_mixable() { return &mixable_storage_; diff --git a/jubatus/core/classifier/linear_classifier.hpp b/jubatus/core/classifier/linear_classifier.hpp index 020ea160..b704f198 100644 --- a/jubatus/core/classifier/linear_classifier.hpp +++ b/jubatus/core/classifier/linear_classifier.hpp @@ -69,6 +69,8 @@ class linear_classifier : public classifier_base { void pack(framework::packer& pk) const; void unpack(msgpack::object o); + void export_model(framework::packer& pk) const; + void import_model(msgpack::object o); framework::mixable* get_mixable(); diff --git a/jubatus/core/classifier/nearest_neighbor_classifier.cpp b/jubatus/core/classifier/nearest_neighbor_classifier.cpp index dc4cdcdf..d35d8fac 100644 --- a/jubatus/core/classifier/nearest_neighbor_classifier.cpp +++ b/jubatus/core/classifier/nearest_neighbor_classifier.cpp @@ -225,6 +225,14 @@ void nearest_neighbor_classifier::unpack(msgpack::object o) { } } +void nearest_neighbor_classifier::export_model(framework::packer& pk) const { + // TODO +} +void nearest_neighbor_classifier::import_model(msgpack::object o) { + // TODO +} + + framework::mixable* nearest_neighbor_classifier::get_mixable() { return nearest_neighbor_engine_->get_mixable(); } diff --git a/jubatus/core/classifier/nearest_neighbor_classifier.hpp b/jubatus/core/classifier/nearest_neighbor_classifier.hpp index 4f47ff91..37893775 100644 --- a/jubatus/core/classifier/nearest_neighbor_classifier.hpp +++ b/jubatus/core/classifier/nearest_neighbor_classifier.hpp @@ -63,6 +63,8 @@ class nearest_neighbor_classifier : public classifier_base { void pack(framework::packer& pk) const; void unpack(msgpack::object o); + void export_model(framework::packer& pk) const; + void import_model(msgpack::object o); framework::mixable* get_mixable(); diff --git a/jubatus/core/common/byte_buffer.hpp b/jubatus/core/common/byte_buffer.hpp index 9becfbf7..2a4fe4be 100644 --- a/jubatus/core/common/byte_buffer.hpp +++ b/jubatus/core/common/byte_buffer.hpp @@ -20,6 +20,7 @@ #include #include #include +#include #include #include "jubatus/util/lang/shared_ptr.h" @@ -41,6 +42,12 @@ class byte_buffer { buf_.reset(new std::vector(first, first+size)); } + void write(const char* buf, unsigned int len) { + const size_t old_tail = buf_->size(); + buf_->resize(buf_->size() + len); + std::memcpy(buf_->data() + old_tail, buf, len); + } + // following member functions are implicily defined: // byte_buffer(const byte_buffer& b) = default; // byte_buffer& operator=(const byte_buffer& b) = default; diff --git a/jubatus/core/common/export_model.hpp b/jubatus/core/common/export_model.hpp new file mode 100644 index 00000000..a934fd43 --- /dev/null +++ b/jubatus/core/common/export_model.hpp @@ -0,0 +1,28 @@ +// Jubatus: Online machine learning framework for distributed environment +// Copyright (C) 2014 Preferred Infrastructure and Nippon Telegraph and Telephone Corporation. +// +// This library is free software; you can redistribute it and/or +// modify it under the terms of the GNU Lesser General Public +// License version 2.1 as published by the Free Software Foundation. +// +// This library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +// Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public +// License along with this library; if not, write to the Free Software +// Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA + +#ifndef JUBATUS_CORE_COMMON_EXPORT_MODEL_HPP__ +#define JUBATUS_CORE_COMMON_EXPORT_MODEL_HPP__ + +#include +#include "../framework/packer.hpp" + +#define JUBATUS_EXPORT_MODEL(...) \ + void export_model(framework::packer& pk) const { \ + msgpack::type::make_define(__VA_ARGS__).msgpack_pack(pk); \ + } + +#endif // JUBATUS_CORE_COMMON_EXPORT_MODEL_HPP_ diff --git a/jubatus/core/driver/classifier.cpp b/jubatus/core/driver/classifier.cpp index b4293ca0..78895429 100644 --- a/jubatus/core/driver/classifier.cpp +++ b/jubatus/core/driver/classifier.cpp @@ -7,12 +7,12 @@ // // This library is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU // Lesser General Public License for more details. // // You should have received a copy of the GNU Lesser General Public // License along with this library; if not, write to the Free Software -// Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA +// Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA #include "classifier.hpp" @@ -24,6 +24,9 @@ #include "../classifier/classifier_factory.hpp" #include "../classifier/classifier_base.hpp" #include "../common/vector_util.hpp" +#include "../common/byte_buffer.hpp" +#include "../framework/stream_writer.hpp" +#include "../framework/packer.hpp" #include "../fv_converter/datum.hpp" #include "../fv_converter/datum_to_fv_converter.hpp" #include "../fv_converter/converter_config.hpp" @@ -31,6 +34,7 @@ using std::string; using std::vector; +using std::make_pair; using jubatus::util::lang::shared_ptr; using jubatus::core::fv_converter::weight_manager; using jubatus::core::fv_converter::mixable_weight_manager; @@ -55,40 +59,40 @@ classifier::~classifier() { } void classifier::train(const string& label, const fv_converter::datum& data) { - common::sfv_t v; - converter_->convert_and_update_weight(data, v); - common::sort_and_merge(v); - classifier_->train(v, label); + common::sfv_t v; + converter_->convert_and_update_weight(data, v); + common::sort_and_merge(v); + classifier_->train(v, label); } jubatus::core::classifier::classify_result classifier::classify( - const fv_converter::datum& data) const { - common::sfv_t v; - converter_->convert(data, v); + const fv_converter::datum& data) const { + common::sfv_t v; + converter_->convert(data, v); - jubatus::core::classifier::classify_result scores; - classifier_->classify_with_scores(v, scores); - return scores; + jubatus::core::classifier::classify_result scores; + classifier_->classify_with_scores(v, scores); + return scores; } void classifier::get_status(std::map& status) const { - classifier_->get_status(status); + classifier_->get_status(status); } bool classifier::delete_label(const std::string& label) { - return classifier_->delete_label(label); + return classifier_->delete_label(label); } void classifier::clear() { - classifier_->clear(); - converter_->clear_weights(); + classifier_->clear(); + converter_->clear_weights(); } std::vector classifier::get_labels() const { - return classifier_->get_labels(); + return classifier_->get_labels(); } bool classifier::set_label(const std::string& label) { - return classifier_->set_label(label); + return classifier_->set_label(label); } void classifier::pack(framework::packer& pk) const { @@ -109,6 +113,38 @@ void classifier::unpack(msgpack::object o) { wm_.get_model()->unpack(o.via.array.ptr[1]); } -} // namespace driver -} // namespace core -} // namespace jubatus +struct versioned_model { + // this struct is version compatible + std::string version; + common::byte_buffer buffer; + versioned_model(const std::string& v, const common::byte_buffer& b) + : version(v), buffer(b) { + } + MSGPACK_DEFINE(version, buffer); +}; + +void classifier::import_model(common::byte_buffer& from) const { + // not implemented yet +} + +common::byte_buffer classifier::export_model() const { + common::byte_buffer model; + framework::stream_writer model_writer(model); + core::framework::jubatus_packer jp(model_writer); + core::framework::packer pk(jp); +// msgpack::packer pk(model); + classifier_->export_model(pk); + wm_->export_model(pk); + { + common::byte_buffer ret; + msgpack::pack(ret, + versioned_model(JUBATUS_CORE_VERSION, + model)); + return ret; + } +} + + +} // namespace driver +} // namespace core +} // namespace jubatus diff --git a/jubatus/core/driver/classifier.hpp b/jubatus/core/driver/classifier.hpp index 34e9b93c..5e5607b5 100644 --- a/jubatus/core/driver/classifier.hpp +++ b/jubatus/core/driver/classifier.hpp @@ -21,6 +21,7 @@ #include #include #include "jubatus/util/lang/shared_ptr.h" +#include "../common/byte_buffer.hpp" #include "../classifier/classifier_type.hpp" #include "../framework/mixable.hpp" #include "../fv_converter/mixable_weight_manager.hpp" @@ -60,6 +61,9 @@ class classifier : public driver_base { void pack(framework::packer& pk) const; void unpack(msgpack::object o); + void import_model(common::byte_buffer& from) const; + common::byte_buffer export_model() const; + std::vector get_labels() const; bool set_label(const std::string& label); diff --git a/jubatus/core/fv_converter/datum_to_fv_converter.cpp b/jubatus/core/fv_converter/datum_to_fv_converter.cpp index 57103d20..d3b5b27e 100644 --- a/jubatus/core/fv_converter/datum_to_fv_converter.cpp +++ b/jubatus/core/fv_converter/datum_to_fv_converter.cpp @@ -343,6 +343,10 @@ class datum_to_fv_converter_impl { } } + void export_model(framework::packer& pk) { + + } + private: void filter_strings( const datum::sv_t& string_values, @@ -658,6 +662,11 @@ void datum_to_fv_converter::clear_weights() { pimpl_->clear_weights(); } +void datum_to_fv_converter::export_model(framework::packer& pk) { + pimpl_->export_model(pk); +} + + } // namespace fv_converter } // namespace core } // namespace jubatus diff --git a/jubatus/core/fv_converter/datum_to_fv_converter.hpp b/jubatus/core/fv_converter/datum_to_fv_converter.hpp index 20f96ff6..a1e78bee 100644 --- a/jubatus/core/fv_converter/datum_to_fv_converter.hpp +++ b/jubatus/core/fv_converter/datum_to_fv_converter.hpp @@ -25,6 +25,7 @@ #include "jubatus/util/lang/scoped_ptr.h" #include "../common/type.hpp" #include "../framework/mixable.hpp" +#include "../framework/packer.hpp" namespace jubatus { namespace core { @@ -120,6 +121,9 @@ class datum_to_fv_converter { void set_weight_manager(jubatus::util::lang::shared_ptr wm); void clear_weights(); + void import_model(framework::packer& pk); + void export_model(framework::packer& pk); + private: jubatus::util::lang::scoped_ptr pimpl_; }; diff --git a/jubatus/core/fv_converter/weight_manager.hpp b/jubatus/core/fv_converter/weight_manager.hpp index 0e854519..bdf731cc 100644 --- a/jubatus/core/fv_converter/weight_manager.hpp +++ b/jubatus/core/fv_converter/weight_manager.hpp @@ -26,6 +26,7 @@ #include "../framework/model.hpp" #include "../common/type.hpp" #include "../common/version.hpp" +#include "../common/export_model.hpp" #include "keyword_weights.hpp" namespace jubatus { @@ -89,6 +90,8 @@ class weight_manager : public framework::model { } MSGPACK_DEFINE(version_, diff_weights_, master_weights_); + JUBATUS_EXPORT_MODEL(version_, master_weights_); + void import_model(msgpack::object o); void pack(framework::packer& pk) const { pk.pack(*this); diff --git a/jubatus/core/storage/local_storage.cpp b/jubatus/core/storage/local_storage.cpp index 263466af..930032e2 100644 --- a/jubatus/core/storage/local_storage.cpp +++ b/jubatus/core/storage/local_storage.cpp @@ -221,6 +221,10 @@ void local_storage::unpack(msgpack::object o) { o.convert(this); } +void local_storage::import_model(msgpack::object o) { + o.convert(this); +} + std::string local_storage::type() const { return "local_storage"; } diff --git a/jubatus/core/storage/local_storage.hpp b/jubatus/core/storage/local_storage.hpp index d25b46c0..b5a9fd66 100644 --- a/jubatus/core/storage/local_storage.hpp +++ b/jubatus/core/storage/local_storage.hpp @@ -24,6 +24,7 @@ #include "storage_base.hpp" #include "../common/key_manager.hpp" #include "../common/version.hpp" +#include "../common/export_model.hpp" namespace jubatus { namespace core { @@ -80,12 +81,15 @@ class local_storage : public storage_base { void pack(framework::packer& packer) const; void unpack(msgpack::object o); + storage::version get_version() const { return storage::version(); } std::string type() const; MSGPACK_DEFINE(tbl_, class2id_); + JUBATUS_EXPORT_MODEL(tbl_, class2id_); + void import_model(msgpack::object o); private: // map_features3_t tbl_; diff --git a/jubatus/core/storage/local_storage_mixture.hpp b/jubatus/core/storage/local_storage_mixture.hpp index c25d3553..22b1c7d9 100644 --- a/jubatus/core/storage/local_storage_mixture.hpp +++ b/jubatus/core/storage/local_storage_mixture.hpp @@ -88,6 +88,8 @@ class local_storage_mixture : public storage_base { std::string type() const; MSGPACK_DEFINE(tbl_, class2id_, tbl_diff_, model_version_); + JUBATUS_EXPORT_MODEL(tbl_, class2id_, model_version_); + void import_model(msgpack::object o); private: bool get_internal(const std::string& feature, id_feature_val3_t& ret) const; diff --git a/jubatus/core/storage/storage_base.hpp b/jubatus/core/storage/storage_base.hpp index f5ed4e4c..4b214214 100644 --- a/jubatus/core/storage/storage_base.hpp +++ b/jubatus/core/storage/storage_base.hpp @@ -29,6 +29,7 @@ #include "../common/exception.hpp" #include "../common/type.hpp" #include "../framework/model.hpp" +#include "../framework/packer.hpp" namespace jubatus { namespace core { @@ -63,6 +64,8 @@ class storage_base : public framework::model { virtual void pack(framework::packer& packer) const = 0; virtual void unpack(msgpack::object o) = 0; + virtual void export_model(framework::packer& pk) const = 0; + virtual void import_model(msgpack::object o) = 0; virtual version get_version() const = 0;