From 3c1a09ce697ea3a101bb7d855b7aed48cce472ed Mon Sep 17 00:00:00 2001 From: KUMAZAKI Hiroki Date: Thu, 12 Jun 2014 21:18:10 +0900 Subject: [PATCH] export_model interface of driver::classifier --- jubatus/core/classifier/classifier_base.hpp | 2 + jubatus/core/classifier/linear_classifier.cpp | 13 +++ jubatus/core/classifier/linear_classifier.hpp | 6 +- .../nearest_neighbor_classifier.cpp | 13 +++ .../nearest_neighbor_classifier.hpp | 4 + jubatus/core/common/byte_buffer.hpp | 7 ++ jubatus/core/common/export_model.hpp | 36 +++++++ jubatus/core/common/key_manager.hpp | 6 ++ jubatus/core/driver/classifier.cpp | 96 +++++++++++++++---- jubatus/core/driver/classifier.hpp | 4 + jubatus/core/driver/recommender.cpp | 6 ++ jubatus/core/driver/recommender.hpp | 3 + .../core/framework/linear_function_mixer.hpp | 3 + .../fv_converter/datum_to_fv_converter.cpp | 1 + .../fv_converter/datum_to_fv_converter.hpp | 1 + jubatus/core/fv_converter/weight_manager.hpp | 3 + jubatus/core/storage/local_storage.cpp | 4 + jubatus/core/storage/local_storage.hpp | 4 + .../core/storage/local_storage_mixture.hpp | 1 + jubatus/core/storage/storage_base.hpp | 3 + jubatus/core/unlearner/lru_unlearner.cpp | 11 +++ jubatus/core/unlearner/lru_unlearner.hpp | 7 +- jubatus/core/unlearner/random_unlearner.hpp | 5 + jubatus/core/unlearner/unlearner_base.hpp | 7 ++ 24 files changed, 221 insertions(+), 25 deletions(-) create mode 100644 jubatus/core/common/export_model.hpp diff --git a/jubatus/core/classifier/classifier_base.hpp b/jubatus/core/classifier/classifier_base.hpp index 826a5bb4..ec98bd43 100644 --- a/jubatus/core/classifier/classifier_base.hpp +++ b/jubatus/core/classifier/classifier_base.hpp @@ -48,6 +48,8 @@ class classifier_base { virtual void set_label_unlearner( jubatus::util::lang::shared_ptr label_unlearner) = 0; + virtual jubatus::util::lang::shared_ptr + get_label_unlearner() const = 0; virtual bool delete_label(const std::string& label) = 0; virtual std::vector get_labels() const = 0; diff --git a/jubatus/core/classifier/linear_classifier.cpp b/jubatus/core/classifier/linear_classifier.cpp index ea4723f0..d79dfdc6 100644 --- a/jubatus/core/classifier/linear_classifier.cpp +++ b/jubatus/core/classifier/linear_classifier.cpp @@ -64,6 +64,12 @@ void linear_classifier::set_label_unlearner( unlearner_ = label_unlearner; } +jubatus::util::lang::shared_ptr +linear_classifier::get_label_unlearner() const { + return unlearner_; +} + + void linear_classifier::classify_with_scores( const common::sfv_t& sfv, classify_result& scores) const { @@ -207,6 +213,13 @@ void linear_classifier::pack(framework::packer& pk) const { void linear_classifier::unpack(msgpack::object o) { storage_->unpack(o); } +void linear_classifier::export_model(framework::packer& pk) const { + // TODO +} + +void linear_classifier::import_model(msgpack::object o) { + // TODO +} framework::mixable* linear_classifier::get_mixable() { return &mixable_storage_; diff --git a/jubatus/core/classifier/linear_classifier.hpp b/jubatus/core/classifier/linear_classifier.hpp index 020ea160..0826d0a5 100644 --- a/jubatus/core/classifier/linear_classifier.hpp +++ b/jubatus/core/classifier/linear_classifier.hpp @@ -46,9 +46,7 @@ class linear_classifier : public classifier_base { label_unlearner); jubatus::util::lang::shared_ptr - label_unlearner() const { - return unlearner_; - } + get_label_unlearner() const; std::string classify(const common::sfv_t& fv) const; void classify_with_scores(const common::sfv_t& fv, @@ -69,6 +67,8 @@ class linear_classifier : public classifier_base { void pack(framework::packer& pk) const; void unpack(msgpack::object o); + void export_model(framework::packer& pk) const; + void import_model(msgpack::object o); framework::mixable* get_mixable(); diff --git a/jubatus/core/classifier/nearest_neighbor_classifier.cpp b/jubatus/core/classifier/nearest_neighbor_classifier.cpp index dc4cdcdf..87f87a71 100644 --- a/jubatus/core/classifier/nearest_neighbor_classifier.cpp +++ b/jubatus/core/classifier/nearest_neighbor_classifier.cpp @@ -100,6 +100,11 @@ void nearest_neighbor_classifier::set_label_unlearner( unlearner_ = label_unlearner; } +shared_ptr +nearest_neighbor_classifier::get_label_unlearner() const { + return unlearner_; +} + std::string nearest_neighbor_classifier::classify( const common::sfv_t& fv) const { classify_result result; @@ -225,6 +230,14 @@ void nearest_neighbor_classifier::unpack(msgpack::object o) { } } +void nearest_neighbor_classifier::export_model(framework::packer& pk) const { + // TODO +} +void nearest_neighbor_classifier::import_model(msgpack::object o) { + // TODO +} + + framework::mixable* nearest_neighbor_classifier::get_mixable() { return nearest_neighbor_engine_->get_mixable(); } diff --git a/jubatus/core/classifier/nearest_neighbor_classifier.hpp b/jubatus/core/classifier/nearest_neighbor_classifier.hpp index 4f47ff91..1d17aec5 100644 --- a/jubatus/core/classifier/nearest_neighbor_classifier.hpp +++ b/jubatus/core/classifier/nearest_neighbor_classifier.hpp @@ -47,6 +47,8 @@ class nearest_neighbor_classifier : public classifier_base { void set_label_unlearner( jubatus::util::lang::shared_ptr label_unlearner); + jubatus::util::lang::shared_ptr + get_label_unlearner() const; std::string classify(const common::sfv_t& fv) const; void classify_with_scores(const common::sfv_t& fv, @@ -63,6 +65,8 @@ class nearest_neighbor_classifier : public classifier_base { void pack(framework::packer& pk) const; void unpack(msgpack::object o); + void export_model(framework::packer& pk) const; + void import_model(msgpack::object o); framework::mixable* get_mixable(); diff --git a/jubatus/core/common/byte_buffer.hpp b/jubatus/core/common/byte_buffer.hpp index 9becfbf7..2a4fe4be 100644 --- a/jubatus/core/common/byte_buffer.hpp +++ b/jubatus/core/common/byte_buffer.hpp @@ -20,6 +20,7 @@ #include #include #include +#include #include #include "jubatus/util/lang/shared_ptr.h" @@ -41,6 +42,12 @@ class byte_buffer { buf_.reset(new std::vector(first, first+size)); } + void write(const char* buf, unsigned int len) { + const size_t old_tail = buf_->size(); + buf_->resize(buf_->size() + len); + std::memcpy(buf_->data() + old_tail, buf, len); + } + // following member functions are implicily defined: // byte_buffer(const byte_buffer& b) = default; // byte_buffer& operator=(const byte_buffer& b) = default; diff --git a/jubatus/core/common/export_model.hpp b/jubatus/core/common/export_model.hpp new file mode 100644 index 00000000..b353cda6 --- /dev/null +++ b/jubatus/core/common/export_model.hpp @@ -0,0 +1,36 @@ +// Jubatus: Online machine learning framework for distributed environment +// Copyright (C) 2014 Preferred Infrastructure and Nippon Telegraph and Telephone Corporation. +// +// This library is free software; you can redistribute it and/or +// modify it under the terms of the GNU Lesser General Public +// License version 2.1 as published by the Free Software Foundation. +// +// This library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +// Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public +// License along with this library; if not, write to the Free Software +// Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA + +#ifndef JUBATUS_CORE_COMMON_EXPORT_MODEL_HPP__ +#define JUBATUS_CORE_COMMON_EXPORT_MODEL_HPP__ + +#include +#include "../framework/packer.hpp" + +#define JUBATUS_EXPORT_MODEL(...) \ + void export_model(framework::packer& pk) const { \ + msgpack::type::make_define(__VA_ARGS__).msgpack_pack(pk); \ + } +#define JUBATUS_IMPORT_MODEL(...) \ + void import_model(msgpack::object o) { \ + this->clear(); \ + msgpack::type::make_define(__VA_ARGS__).msgpack_unpack(o); \ + } + +#define JUBATUS_PORTING_MODEL(...) \ + JUBATUS_IMPORT_MODEL(__VA_ARGS__) \ + JUBATUS_EXPORT_MODEL(__VA_ARGS__) +#endif // JUBATUS_CORE_COMMON_EXPORT_MODEL_HPP_ diff --git a/jubatus/core/common/key_manager.hpp b/jubatus/core/common/key_manager.hpp index 7f899607..85293f90 100644 --- a/jubatus/core/common/key_manager.hpp +++ b/jubatus/core/common/key_manager.hpp @@ -46,6 +46,12 @@ class key_manager { key2id_.swap(km.key2id_); id2key_.swap(km.id2key_); } + key_manager& operator=(const key_manager& rhs) { + key2id_ = rhs.key2id_; + id2key_ = rhs.id2key_; + next_id_ = rhs.next_id_; + return *this; + } size_t size() const { return key2id_.size(); diff --git a/jubatus/core/driver/classifier.cpp b/jubatus/core/driver/classifier.cpp index b4293ca0..ff4d2820 100644 --- a/jubatus/core/driver/classifier.cpp +++ b/jubatus/core/driver/classifier.cpp @@ -7,12 +7,12 @@ // // This library is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU // Lesser General Public License for more details. // // You should have received a copy of the GNU Lesser General Public // License along with this library; if not, write to the Free Software -// Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA +// Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA #include "classifier.hpp" @@ -24,6 +24,9 @@ #include "../classifier/classifier_factory.hpp" #include "../classifier/classifier_base.hpp" #include "../common/vector_util.hpp" +#include "../common/byte_buffer.hpp" +#include "../framework/stream_writer.hpp" +#include "../framework/packer.hpp" #include "../fv_converter/datum.hpp" #include "../fv_converter/datum_to_fv_converter.hpp" #include "../fv_converter/converter_config.hpp" @@ -31,6 +34,7 @@ using std::string; using std::vector; +using std::make_pair; using jubatus::util::lang::shared_ptr; using jubatus::core::fv_converter::weight_manager; using jubatus::core::fv_converter::mixable_weight_manager; @@ -55,40 +59,40 @@ classifier::~classifier() { } void classifier::train(const string& label, const fv_converter::datum& data) { - common::sfv_t v; - converter_->convert_and_update_weight(data, v); - common::sort_and_merge(v); - classifier_->train(v, label); + common::sfv_t v; + converter_->convert_and_update_weight(data, v); + common::sort_and_merge(v); + classifier_->train(v, label); } jubatus::core::classifier::classify_result classifier::classify( - const fv_converter::datum& data) const { - common::sfv_t v; - converter_->convert(data, v); + const fv_converter::datum& data) const { + common::sfv_t v; + converter_->convert(data, v); - jubatus::core::classifier::classify_result scores; - classifier_->classify_with_scores(v, scores); - return scores; + jubatus::core::classifier::classify_result scores; + classifier_->classify_with_scores(v, scores); + return scores; } void classifier::get_status(std::map& status) const { - classifier_->get_status(status); + classifier_->get_status(status); } bool classifier::delete_label(const std::string& label) { - return classifier_->delete_label(label); + return classifier_->delete_label(label); } void classifier::clear() { - classifier_->clear(); - converter_->clear_weights(); + classifier_->clear(); + converter_->clear_weights(); } std::vector classifier::get_labels() const { - return classifier_->get_labels(); + return classifier_->get_labels(); } bool classifier::set_label(const std::string& label) { - return classifier_->set_label(label); + return classifier_->set_label(label); } void classifier::pack(framework::packer& pk) const { @@ -109,6 +113,56 @@ void classifier::unpack(msgpack::object o) { wm_.get_model()->unpack(o.via.array.ptr[1]); } -} // namespace driver -} // namespace core -} // namespace jubatus +struct versioned_model { + // this struct is version compatible + std::string version; + common::byte_buffer buffer; + versioned_model(const std::string& v, const common::byte_buffer& b) + : version(v), buffer(b) { + } + MSGPACK_DEFINE(version, buffer); +}; + +void classifier::import_model(common::byte_buffer& from) const { + msgpack::unpacked packed; + msgpack::unpack(&packed, from.ptr(), from.size()); + msgpack::object o = packed.get(); + if (o.type != msgpack::type::ARRAY || o.via.array.size != 2) { + throw msgpack::type_error(); + } + const std::string& version = o.via.array.ptr[0].as(); + common::byte_buffer serialized_model = + o.via.array.ptr[1].as(); + + msgpack::unpacked unpacked_model; + msgpack::unpack(&unpacked_model, serialized_model.ptr(), serialized_model.size()); + msgpack::object model = unpacked_model.get(); + + if (version == "0.0.1") { + } else { + throw JUBATUS_EXCEPTION( + common::invalid_parameter("unknown version number: " + version)); + } +} + +common::byte_buffer classifier::export_model() const { + common::byte_buffer model; + framework::stream_writer model_writer(model); + core::framework::jubatus_packer jp(model_writer); + core::framework::packer pk(jp); +// msgpack::packer pk(model); + classifier_->pack(pk); + wm_.get_model()->pack(pk); + { + common::byte_buffer ret; + msgpack::pack(ret, + versioned_model(JUBATUS_CORE_VERSION, + model)); + return ret; + } +} + + +} // namespace driver +} // namespace core +} // namespace jubatus diff --git a/jubatus/core/driver/classifier.hpp b/jubatus/core/driver/classifier.hpp index 34e9b93c..5e5607b5 100644 --- a/jubatus/core/driver/classifier.hpp +++ b/jubatus/core/driver/classifier.hpp @@ -21,6 +21,7 @@ #include #include #include "jubatus/util/lang/shared_ptr.h" +#include "../common/byte_buffer.hpp" #include "../classifier/classifier_type.hpp" #include "../framework/mixable.hpp" #include "../fv_converter/mixable_weight_manager.hpp" @@ -60,6 +61,9 @@ class classifier : public driver_base { void pack(framework::packer& pk) const; void unpack(msgpack::object o); + void import_model(common::byte_buffer& from) const; + common::byte_buffer export_model() const; + std::vector get_labels() const; bool set_label(const std::string& label); diff --git a/jubatus/core/driver/recommender.cpp b/jubatus/core/driver/recommender.cpp index 42cc693c..581b4f69 100644 --- a/jubatus/core/driver/recommender.cpp +++ b/jubatus/core/driver/recommender.cpp @@ -161,6 +161,12 @@ void recommender::unpack(msgpack::object o) { wm_.get_model()->unpack(o.via.array.ptr[1]); } +void recommender::import_model(common::byte_buffer& from) const { + +} +common::byte_buffer recommender::export_model() const { +} + } // namespace driver } // namespace core } // namespace jubatus diff --git a/jubatus/core/driver/recommender.hpp b/jubatus/core/driver/recommender.hpp index 29e57b09..f9d549b5 100644 --- a/jubatus/core/driver/recommender.hpp +++ b/jubatus/core/driver/recommender.hpp @@ -71,6 +71,9 @@ class recommender : public driver_base { fv_converter::datum decode_row(const std::string& id); std::vector get_all_rows(); + void import_model(common::byte_buffer& from) const; + common::byte_buffer export_model() const; + private: jubatus::util::lang::shared_ptr converter_; diff --git a/jubatus/core/framework/linear_function_mixer.hpp b/jubatus/core/framework/linear_function_mixer.hpp index e9347f5b..9ad84d51 100644 --- a/jubatus/core/framework/linear_function_mixer.hpp +++ b/jubatus/core/framework/linear_function_mixer.hpp @@ -51,6 +51,9 @@ class linear_function_mixer : public linear_mixable { model_ptr get_model() const { return model_; } + void set_model(model_ptr model) { + model_ = model; + } void mix(const diffv& lhs, diffv& mixed) const; void get_diff(diffv&) const; diff --git a/jubatus/core/fv_converter/datum_to_fv_converter.cpp b/jubatus/core/fv_converter/datum_to_fv_converter.cpp index 57103d20..6dd9333f 100644 --- a/jubatus/core/fv_converter/datum_to_fv_converter.cpp +++ b/jubatus/core/fv_converter/datum_to_fv_converter.cpp @@ -658,6 +658,7 @@ void datum_to_fv_converter::clear_weights() { pimpl_->clear_weights(); } + } // namespace fv_converter } // namespace core } // namespace jubatus diff --git a/jubatus/core/fv_converter/datum_to_fv_converter.hpp b/jubatus/core/fv_converter/datum_to_fv_converter.hpp index 20f96ff6..c73d9fc5 100644 --- a/jubatus/core/fv_converter/datum_to_fv_converter.hpp +++ b/jubatus/core/fv_converter/datum_to_fv_converter.hpp @@ -25,6 +25,7 @@ #include "jubatus/util/lang/scoped_ptr.h" #include "../common/type.hpp" #include "../framework/mixable.hpp" +#include "../framework/packer.hpp" namespace jubatus { namespace core { diff --git a/jubatus/core/fv_converter/weight_manager.hpp b/jubatus/core/fv_converter/weight_manager.hpp index 0e854519..bdf731cc 100644 --- a/jubatus/core/fv_converter/weight_manager.hpp +++ b/jubatus/core/fv_converter/weight_manager.hpp @@ -26,6 +26,7 @@ #include "../framework/model.hpp" #include "../common/type.hpp" #include "../common/version.hpp" +#include "../common/export_model.hpp" #include "keyword_weights.hpp" namespace jubatus { @@ -89,6 +90,8 @@ class weight_manager : public framework::model { } MSGPACK_DEFINE(version_, diff_weights_, master_weights_); + JUBATUS_EXPORT_MODEL(version_, master_weights_); + void import_model(msgpack::object o); void pack(framework::packer& pk) const { pk.pack(*this); diff --git a/jubatus/core/storage/local_storage.cpp b/jubatus/core/storage/local_storage.cpp index 263466af..930032e2 100644 --- a/jubatus/core/storage/local_storage.cpp +++ b/jubatus/core/storage/local_storage.cpp @@ -221,6 +221,10 @@ void local_storage::unpack(msgpack::object o) { o.convert(this); } +void local_storage::import_model(msgpack::object o) { + o.convert(this); +} + std::string local_storage::type() const { return "local_storage"; } diff --git a/jubatus/core/storage/local_storage.hpp b/jubatus/core/storage/local_storage.hpp index d25b46c0..b5a9fd66 100644 --- a/jubatus/core/storage/local_storage.hpp +++ b/jubatus/core/storage/local_storage.hpp @@ -24,6 +24,7 @@ #include "storage_base.hpp" #include "../common/key_manager.hpp" #include "../common/version.hpp" +#include "../common/export_model.hpp" namespace jubatus { namespace core { @@ -80,12 +81,15 @@ class local_storage : public storage_base { void pack(framework::packer& packer) const; void unpack(msgpack::object o); + storage::version get_version() const { return storage::version(); } std::string type() const; MSGPACK_DEFINE(tbl_, class2id_); + JUBATUS_EXPORT_MODEL(tbl_, class2id_); + void import_model(msgpack::object o); private: // map_features3_t tbl_; diff --git a/jubatus/core/storage/local_storage_mixture.hpp b/jubatus/core/storage/local_storage_mixture.hpp index c25d3553..801be3d8 100644 --- a/jubatus/core/storage/local_storage_mixture.hpp +++ b/jubatus/core/storage/local_storage_mixture.hpp @@ -88,6 +88,7 @@ class local_storage_mixture : public storage_base { std::string type() const; MSGPACK_DEFINE(tbl_, class2id_, tbl_diff_, model_version_); + JUBATUS_PORTING_MODEL(tbl_, class2id_, model_version_); private: bool get_internal(const std::string& feature, id_feature_val3_t& ret) const; diff --git a/jubatus/core/storage/storage_base.hpp b/jubatus/core/storage/storage_base.hpp index f5ed4e4c..4b214214 100644 --- a/jubatus/core/storage/storage_base.hpp +++ b/jubatus/core/storage/storage_base.hpp @@ -29,6 +29,7 @@ #include "../common/exception.hpp" #include "../common/type.hpp" #include "../framework/model.hpp" +#include "../framework/packer.hpp" namespace jubatus { namespace core { @@ -63,6 +64,8 @@ class storage_base : public framework::model { virtual void pack(framework::packer& packer) const = 0; virtual void unpack(msgpack::object o) = 0; + virtual void export_model(framework::packer& pk) const = 0; + virtual void import_model(msgpack::object o) = 0; virtual version get_version() const = 0; diff --git a/jubatus/core/unlearner/lru_unlearner.cpp b/jubatus/core/unlearner/lru_unlearner.cpp index fadffe7a..19157bd2 100644 --- a/jubatus/core/unlearner/lru_unlearner.cpp +++ b/jubatus/core/unlearner/lru_unlearner.cpp @@ -23,6 +23,7 @@ #include "../fv_converter/key_matcher.hpp" #include "../fv_converter/key_matcher_factory.hpp" #include "../common/exception.hpp" +#include "../common/unordered_set.hpp" using jubatus::util::data::unordered_set; using jubatus::core::fv_converter::key_matcher_factory; @@ -126,6 +127,16 @@ bool lru_unlearner::exists_in_memory(const std::string& id) const { return entry_map_.count(id) > 0 || sticky_ids_.count(id) > 0; } +void lru_unlearner::import_model(msgpack::object o) { + this->clear(); + msgpack::type::make_define(lru_).msgpack_unpack(o); + for (lru::iterator it = lru_.begin(); + it != lru_.end(); + ++it) { + entry_map_[*it] = it; + } +} + // private void lru_unlearner::rebuild_entry_map() { diff --git a/jubatus/core/unlearner/lru_unlearner.hpp b/jubatus/core/unlearner/lru_unlearner.hpp index 13b8cf37..8c4ce672 100644 --- a/jubatus/core/unlearner/lru_unlearner.hpp +++ b/jubatus/core/unlearner/lru_unlearner.hpp @@ -26,6 +26,8 @@ #include "jubatus/util/data/optional.h" #include "jubatus/util/lang/shared_ptr.h" #include "unlearner_base.hpp" +#include "../common/porting_model.hpp" +#include "../common/unordered_map.hpp" namespace jubatus { namespace core { @@ -54,7 +56,6 @@ class lru_unlearner : public unlearner_base { void clear() { lru_.clear(); entry_map_.clear(); - sticky_ids_.clear(); } explicit lru_unlearner(const config& conf); @@ -64,6 +65,10 @@ class lru_unlearner : public unlearner_base { bool remove(const std::string& id); bool exists_in_memory(const std::string& id) const; + JUBATUS_EXPORT_MODEL(lru_); + // CAUTION!: JUBATUS_IMPORT_MODEL should be hand-written + void import_model(msgpack::object o); + private: typedef std::list lru; typedef jubatus::util::data::unordered_map diff --git a/jubatus/core/unlearner/random_unlearner.hpp b/jubatus/core/unlearner/random_unlearner.hpp index b26f3170..94dcbed7 100644 --- a/jubatus/core/unlearner/random_unlearner.hpp +++ b/jubatus/core/unlearner/random_unlearner.hpp @@ -19,11 +19,14 @@ #include #include +#include "msgpack.hpp" #include "jubatus/util/data/optional.h" #include "jubatus/util/data/serialization.h" #include "jubatus/util/data/unordered_map.h" #include "jubatus/util/math/random.h" #include "unlearner_base.hpp" +#include "../common/porting_model.hpp" +#include "../common/unordered_set.hpp" namespace jubatus { namespace core { @@ -58,6 +61,8 @@ class random_unlearner : public unlearner_base { bool remove(const std::string& id); bool exists_in_memory(const std::string& id) const; + JUBATUS_PORTING_MODEL(id_set_, ids_, max_size_); + private: /** * Map of ID and its position in ids_. diff --git a/jubatus/core/unlearner/unlearner_base.hpp b/jubatus/core/unlearner/unlearner_base.hpp index 39051aca..a92a2618 100644 --- a/jubatus/core/unlearner/unlearner_base.hpp +++ b/jubatus/core/unlearner/unlearner_base.hpp @@ -19,6 +19,7 @@ #include #include "jubatus/util/lang/function.h" +#include "jubatus/core/framework/packer.hpp" namespace jubatus { namespace core { @@ -72,6 +73,12 @@ class unlearner_base { // touched and not unlearned since then, it returns true. virtual bool exists_in_memory(const std::string& id) const = 0; + // Export the innner model as msgpack format + virtual void export_model(framework::jubatus_packer& pk) const = 0; + + // Overwrite the innner model from serialised data + virtual void import_model(msgpack::object o) = 0; + protected: void unlearn(const std::string& id) const { callback_(id);