Skip to content

Commit

Permalink
implement export_model of classifier
Browse files Browse the repository at this point in the history
  • Loading branch information
KUMAZAKI Hiroki committed Mar 19, 2015
1 parent 39ae1b1 commit a424712
Show file tree
Hide file tree
Showing 16 changed files with 146 additions and 21 deletions.
2 changes: 2 additions & 0 deletions jubatus/core/classifier/classifier_base.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ class classifier_base {

virtual void pack(framework::packer& pk) const = 0;
virtual void unpack(msgpack::object o) = 0;
virtual void export_model(framework::packer& pk) const = 0;
virtual void import_model(msgpack::object o) = 0;
virtual void clear() = 0;

virtual framework::mixable* get_mixable() = 0;
Expand Down
7 changes: 7 additions & 0 deletions jubatus/core/classifier/linear_classifier.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,13 @@ void linear_classifier::pack(framework::packer& pk) const {
void linear_classifier::unpack(msgpack::object o) {
storage_->unpack(o);
}
void linear_classifier::export_model(framework::packer& pk) const {
// TODO
}

void linear_classifier::import_model(msgpack::object o) {
// TODO
}

framework::mixable* linear_classifier::get_mixable() {
return &mixable_storage_;
Expand Down
2 changes: 2 additions & 0 deletions jubatus/core/classifier/linear_classifier.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ class linear_classifier : public classifier_base {

void pack(framework::packer& pk) const;
void unpack(msgpack::object o);
void export_model(framework::packer& pk) const;
void import_model(msgpack::object o);

framework::mixable* get_mixable();

Expand Down
8 changes: 8 additions & 0 deletions jubatus/core/classifier/nearest_neighbor_classifier.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,14 @@ void nearest_neighbor_classifier::unpack(msgpack::object o) {
}
}

void nearest_neighbor_classifier::export_model(framework::packer& pk) const {
// TODO
}
void nearest_neighbor_classifier::import_model(msgpack::object o) {
// TODO
}


framework::mixable* nearest_neighbor_classifier::get_mixable() {
return nearest_neighbor_engine_->get_mixable();
}
Expand Down
2 changes: 2 additions & 0 deletions jubatus/core/classifier/nearest_neighbor_classifier.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ class nearest_neighbor_classifier : public classifier_base {

void pack(framework::packer& pk) const;
void unpack(msgpack::object o);
void export_model(framework::packer& pk) const;
void import_model(msgpack::object o);

framework::mixable* get_mixable();

Expand Down
7 changes: 7 additions & 0 deletions jubatus/core/common/byte_buffer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include <stdint.h>
#include <vector>
#include <cstring>
#include <memory>
#include <msgpack.hpp>
#include "jubatus/util/lang/shared_ptr.h"

Expand All @@ -41,6 +42,12 @@ class byte_buffer {
buf_.reset(new std::vector<char>(first, first+size));
}

void write(const char* buf, unsigned int len) {
const size_t old_tail = buf_->size();
buf_->resize(buf_->size() + len);
std::memcpy(buf_->data() + old_tail, buf, len);
}

// following member functions are implicily defined:
// byte_buffer(const byte_buffer& b) = default;
// byte_buffer& operator=(const byte_buffer& b) = default;
Expand Down
28 changes: 28 additions & 0 deletions jubatus/core/common/export_model.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
// Jubatus: Online machine learning framework for distributed environment
// Copyright (C) 2014 Preferred Infrastructure and Nippon Telegraph and Telephone Corporation.
//
// This library is free software; you can redistribute it and/or
// modify it under the terms of the GNU Lesser General Public
// License version 2.1 as published by the Free Software Foundation.
//
// This library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
// Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public
// License along with this library; if not, write to the Free Software
// Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA

#ifndef JUBATUS_CORE_COMMON_EXPORT_MODEL_HPP__
#define JUBATUS_CORE_COMMON_EXPORT_MODEL_HPP__

#include <msgpack.hpp>
#include "../framework/packer.hpp"

#define JUBATUS_EXPORT_MODEL(...) \
void export_model(framework::packer& pk) const { \
msgpack::type::make_define(__VA_ARGS__).msgpack_pack(pk); \
}

#endif // JUBATUS_CORE_COMMON_EXPORT_MODEL_HPP_
78 changes: 57 additions & 21 deletions jubatus/core/driver/classifier.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@
//
// This library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
// Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public
// License along with this library; if not, write to the Free Software
// Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
// Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA

#include "classifier.hpp"

Expand All @@ -24,13 +24,17 @@
#include "../classifier/classifier_factory.hpp"
#include "../classifier/classifier_base.hpp"
#include "../common/vector_util.hpp"
#include "../common/byte_buffer.hpp"
#include "../framework/stream_writer.hpp"
#include "../framework/packer.hpp"
#include "../fv_converter/datum.hpp"
#include "../fv_converter/datum_to_fv_converter.hpp"
#include "../fv_converter/converter_config.hpp"
#include "../storage/storage_factory.hpp"

using std::string;
using std::vector;
using std::make_pair;
using jubatus::util::lang::shared_ptr;
using jubatus::core::fv_converter::weight_manager;
using jubatus::core::fv_converter::mixable_weight_manager;
Expand All @@ -55,40 +59,40 @@ classifier::~classifier() {
}

void classifier::train(const string& label, const fv_converter::datum& data) {
common::sfv_t v;
converter_->convert_and_update_weight(data, v);
common::sort_and_merge(v);
classifier_->train(v, label);
common::sfv_t v;
converter_->convert_and_update_weight(data, v);
common::sort_and_merge(v);
classifier_->train(v, label);
}

jubatus::core::classifier::classify_result classifier::classify(
const fv_converter::datum& data) const {
common::sfv_t v;
converter_->convert(data, v);
const fv_converter::datum& data) const {
common::sfv_t v;
converter_->convert(data, v);

jubatus::core::classifier::classify_result scores;
classifier_->classify_with_scores(v, scores);
return scores;
jubatus::core::classifier::classify_result scores;
classifier_->classify_with_scores(v, scores);
return scores;
}

void classifier::get_status(std::map<string, string>& status) const {
classifier_->get_status(status);
classifier_->get_status(status);
}

bool classifier::delete_label(const std::string& label) {
return classifier_->delete_label(label);
return classifier_->delete_label(label);
}

void classifier::clear() {
classifier_->clear();
converter_->clear_weights();
classifier_->clear();
converter_->clear_weights();
}

std::vector<std::string> classifier::get_labels() const {
return classifier_->get_labels();
return classifier_->get_labels();
}
bool classifier::set_label(const std::string& label) {
return classifier_->set_label(label);
return classifier_->set_label(label);
}

void classifier::pack(framework::packer& pk) const {
Expand All @@ -109,6 +113,38 @@ void classifier::unpack(msgpack::object o) {
wm_.get_model()->unpack(o.via.array.ptr[1]);
}

} // namespace driver
} // namespace core
} // namespace jubatus
struct versioned_model {
// this struct is version compatible
std::string version;
common::byte_buffer buffer;
versioned_model(const std::string& v, const common::byte_buffer& b)
: version(v), buffer(b) {
}
MSGPACK_DEFINE(version, buffer);
};

void classifier::import_model(common::byte_buffer& from) const {
// not implemented yet
}

common::byte_buffer classifier::export_model() const {
common::byte_buffer model;
framework::stream_writer<common::byte_buffer> model_writer(model);
core::framework::jubatus_packer jp(model_writer);
core::framework::packer pk(jp);
// msgpack::packer<common::byte_buffer> pk(model);
classifier_->export_model(pk);
wm_->export_model(pk);
{
common::byte_buffer ret;
msgpack::pack(ret,
versioned_model(JUBATUS_CORE_VERSION,
model));
return ret;
}
}


} // namespace driver
} // namespace core
} // namespace jubatus
4 changes: 4 additions & 0 deletions jubatus/core/driver/classifier.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include <string>
#include <vector>
#include "jubatus/util/lang/shared_ptr.h"
#include "../common/byte_buffer.hpp"
#include "../classifier/classifier_type.hpp"
#include "../framework/mixable.hpp"
#include "../fv_converter/mixable_weight_manager.hpp"
Expand Down Expand Up @@ -60,6 +61,9 @@ class classifier : public driver_base {
void pack(framework::packer& pk) const;
void unpack(msgpack::object o);

void import_model(common::byte_buffer& from) const;
common::byte_buffer export_model() const;

std::vector<std::string> get_labels() const;
bool set_label(const std::string& label);

Expand Down
9 changes: 9 additions & 0 deletions jubatus/core/fv_converter/datum_to_fv_converter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,10 @@ class datum_to_fv_converter_impl {
}
}

void export_model(framework::packer& pk) {

}

private:
void filter_strings(
const datum::sv_t& string_values,
Expand Down Expand Up @@ -658,6 +662,11 @@ void datum_to_fv_converter::clear_weights() {
pimpl_->clear_weights();
}

void datum_to_fv_converter::export_model(framework::packer& pk) {
pimpl_->export_model(pk);
}


} // namespace fv_converter
} // namespace core
} // namespace jubatus
4 changes: 4 additions & 0 deletions jubatus/core/fv_converter/datum_to_fv_converter.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include "jubatus/util/lang/scoped_ptr.h"
#include "../common/type.hpp"
#include "../framework/mixable.hpp"
#include "../framework/packer.hpp"

namespace jubatus {
namespace core {
Expand Down Expand Up @@ -120,6 +121,9 @@ class datum_to_fv_converter {
void set_weight_manager(jubatus::util::lang::shared_ptr<weight_manager> wm);
void clear_weights();

void import_model(framework::packer& pk);
void export_model(framework::packer& pk);

private:
jubatus::util::lang::scoped_ptr<datum_to_fv_converter_impl> pimpl_;
};
Expand Down
3 changes: 3 additions & 0 deletions jubatus/core/fv_converter/weight_manager.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include "../framework/model.hpp"
#include "../common/type.hpp"
#include "../common/version.hpp"
#include "../common/export_model.hpp"
#include "keyword_weights.hpp"

namespace jubatus {
Expand Down Expand Up @@ -89,6 +90,8 @@ class weight_manager : public framework::model {
}

MSGPACK_DEFINE(version_, diff_weights_, master_weights_);
JUBATUS_EXPORT_MODEL(version_, master_weights_);
void import_model(msgpack::object o);

void pack(framework::packer& pk) const {
pk.pack(*this);
Expand Down
4 changes: 4 additions & 0 deletions jubatus/core/storage/local_storage.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,10 @@ void local_storage::unpack(msgpack::object o) {
o.convert(this);
}

void local_storage::import_model(msgpack::object o) {
o.convert(this);
}

std::string local_storage::type() const {
return "local_storage";
}
Expand Down
4 changes: 4 additions & 0 deletions jubatus/core/storage/local_storage.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include "storage_base.hpp"
#include "../common/key_manager.hpp"
#include "../common/version.hpp"
#include "../common/export_model.hpp"

namespace jubatus {
namespace core {
Expand Down Expand Up @@ -80,12 +81,15 @@ class local_storage : public storage_base {

void pack(framework::packer& packer) const;
void unpack(msgpack::object o);

storage::version get_version() const {
return storage::version();
}
std::string type() const;

MSGPACK_DEFINE(tbl_, class2id_);
JUBATUS_EXPORT_MODEL(tbl_, class2id_);
void import_model(msgpack::object o);

private:
// map_features3_t tbl_;
Expand Down
2 changes: 2 additions & 0 deletions jubatus/core/storage/local_storage_mixture.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,8 @@ class local_storage_mixture : public storage_base {
std::string type() const;

MSGPACK_DEFINE(tbl_, class2id_, tbl_diff_, model_version_);
JUBATUS_EXPORT_MODEL(tbl_, class2id_, model_version_);
void import_model(msgpack::object o);

private:
bool get_internal(const std::string& feature, id_feature_val3_t& ret) const;
Expand Down
3 changes: 3 additions & 0 deletions jubatus/core/storage/storage_base.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include "../common/exception.hpp"
#include "../common/type.hpp"
#include "../framework/model.hpp"
#include "../framework/packer.hpp"

namespace jubatus {
namespace core {
Expand Down Expand Up @@ -63,6 +64,8 @@ class storage_base : public framework::model {

virtual void pack(framework::packer& packer) const = 0;
virtual void unpack(msgpack::object o) = 0;
virtual void export_model(framework::packer& pk) const = 0;
virtual void import_model(msgpack::object o) = 0;

virtual version get_version() const = 0;

Expand Down

0 comments on commit a424712

Please sign in to comment.