Skip to content

Commit

Permalink
export_model interface of driver::classifier
Browse files Browse the repository at this point in the history
  • Loading branch information
KUMAZAKI Hiroki committed Mar 19, 2015
1 parent 39ae1b1 commit 3c1a09c
Show file tree
Hide file tree
Showing 24 changed files with 221 additions and 25 deletions.
2 changes: 2 additions & 0 deletions jubatus/core/classifier/classifier_base.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ class classifier_base {
virtual void set_label_unlearner(
jubatus::util::lang::shared_ptr<unlearner::unlearner_base>
label_unlearner) = 0;
virtual jubatus::util::lang::shared_ptr<unlearner::unlearner_base>
get_label_unlearner() const = 0;

virtual bool delete_label(const std::string& label) = 0;
virtual std::vector<std::string> get_labels() const = 0;
Expand Down
13 changes: 13 additions & 0 deletions jubatus/core/classifier/linear_classifier.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,12 @@ void linear_classifier::set_label_unlearner(
unlearner_ = label_unlearner;
}

// Returns the label unlearner held by this classifier, i.e. the pointer
// stored in unlearner_ by set_label_unlearner().
jubatus::util::lang::shared_ptr<unlearner::unlearner_base>
linear_classifier::get_label_unlearner() const {
return unlearner_;
}


void linear_classifier::classify_with_scores(
const common::sfv_t& sfv,
classify_result& scores) const {
Expand Down Expand Up @@ -207,6 +213,13 @@ void linear_classifier::pack(framework::packer& pk) const {
void linear_classifier::unpack(msgpack::object o) {
storage_->unpack(o);
}
// Exports this classifier's model through `pk`.
// Not implemented yet: the body is empty, so nothing is written to `pk`.
void linear_classifier::export_model(framework::packer& pk) const {
// TODO: implement model export
}

// Imports a model from the msgpack object `o`.
// Not implemented yet: the body is empty, so `o` is ignored and the
// classifier is left unchanged.
void linear_classifier::import_model(msgpack::object o) {
// TODO: implement model import
}

framework::mixable* linear_classifier::get_mixable() {
return &mixable_storage_;
Expand Down
6 changes: 3 additions & 3 deletions jubatus/core/classifier/linear_classifier.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,7 @@ class linear_classifier : public classifier_base {
label_unlearner);

jubatus::util::lang::shared_ptr<unlearner::unlearner_base>
label_unlearner() const {
return unlearner_;
}
get_label_unlearner() const;

std::string classify(const common::sfv_t& fv) const;
void classify_with_scores(const common::sfv_t& fv,
Expand All @@ -69,6 +67,8 @@ class linear_classifier : public classifier_base {

void pack(framework::packer& pk) const;
void unpack(msgpack::object o);
void export_model(framework::packer& pk) const;
void import_model(msgpack::object o);

framework::mixable* get_mixable();

Expand Down
13 changes: 13 additions & 0 deletions jubatus/core/classifier/nearest_neighbor_classifier.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,11 @@ void nearest_neighbor_classifier::set_label_unlearner(
unlearner_ = label_unlearner;
}

// Returns the label unlearner held by this classifier, i.e. the pointer
// stored in unlearner_ by set_label_unlearner().
shared_ptr<unlearner::unlearner_base>
nearest_neighbor_classifier::get_label_unlearner() const {
return unlearner_;
}

std::string nearest_neighbor_classifier::classify(
const common::sfv_t& fv) const {
classify_result result;
Expand Down Expand Up @@ -225,6 +230,14 @@ void nearest_neighbor_classifier::unpack(msgpack::object o) {
}
}

// Exports this classifier's model through `pk`.
// Not implemented yet: the body is empty, so nothing is written to `pk`.
void nearest_neighbor_classifier::export_model(framework::packer& pk) const {
// TODO: implement model export
}
// Imports a model from the msgpack object `o`.
// Not implemented yet: the body is empty, so `o` is ignored and the
// classifier is left unchanged.
void nearest_neighbor_classifier::import_model(msgpack::object o) {
// TODO: implement model import
}


framework::mixable* nearest_neighbor_classifier::get_mixable() {
return nearest_neighbor_engine_->get_mixable();
}
Expand Down
4 changes: 4 additions & 0 deletions jubatus/core/classifier/nearest_neighbor_classifier.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ class nearest_neighbor_classifier : public classifier_base {
void set_label_unlearner(
jubatus::util::lang::shared_ptr<unlearner::unlearner_base>
label_unlearner);
jubatus::util::lang::shared_ptr<unlearner::unlearner_base>
get_label_unlearner() const;

std::string classify(const common::sfv_t& fv) const;
void classify_with_scores(const common::sfv_t& fv,
Expand All @@ -63,6 +65,8 @@ class nearest_neighbor_classifier : public classifier_base {

void pack(framework::packer& pk) const;
void unpack(msgpack::object o);
void export_model(framework::packer& pk) const;
void import_model(msgpack::object o);

framework::mixable* get_mixable();

Expand Down
7 changes: 7 additions & 0 deletions jubatus/core/common/byte_buffer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include <stdint.h>
#include <vector>
#include <cstring>
#include <memory>
#include <msgpack.hpp>
#include "jubatus/util/lang/shared_ptr.h"

Expand All @@ -41,6 +42,12 @@ class byte_buffer {
buf_.reset(new std::vector<char>(first, first+size));
}

// Appends `len` bytes starting at `buf` to the end of this buffer.
//
// buf: pointer to the bytes to append; not read when `len` is 0.
// len: number of bytes to append.
//
// The backing vector is created on demand, so writing into a buffer whose
// storage pointer has not been set yet does not dereference a null pointer.
// NOTE(review): assumes a default-constructed byte_buffer leaves buf_ empty;
// confirm against the default constructor.
void write(const char* buf, unsigned int len) {
  if (len == 0) {
    // Nothing to append; also avoids touching a possibly null `buf`.
    return;
  }
  if (!buf_) {
    buf_.reset(new std::vector<char>());
  }
  buf_->insert(buf_->end(), buf, buf + len);
}

// following member functions are implicitly defined:
// byte_buffer(const byte_buffer& b) = default;
// byte_buffer& operator=(const byte_buffer& b) = default;
Expand Down
36 changes: 36 additions & 0 deletions jubatus/core/common/export_model.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
// Jubatus: Online machine learning framework for distributed environment
// Copyright (C) 2014 Preferred Infrastructure and Nippon Telegraph and Telephone Corporation.
//
// This library is free software; you can redistribute it and/or
// modify it under the terms of the GNU Lesser General Public
// License version 2.1 as published by the Free Software Foundation.
//
// This library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
// Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public
// License along with this library; if not, write to the Free Software
// Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA

// Include guard. The name ends in a single underscore: identifiers that
// contain a double underscore are reserved for the implementation.
#ifndef JUBATUS_CORE_COMMON_EXPORT_MODEL_HPP_
#define JUBATUS_CORE_COMMON_EXPORT_MODEL_HPP_

#include <msgpack.hpp>
#include "../framework/packer.hpp"

// Defines `void export_model(packer&) const` in the enclosing class. The
// arguments are the members to serialize; they are handed to msgpack's
// make_define helper and packed into `pk`.
// Type names are fully qualified so the expansion does not depend on the
// namespace of the class that uses the macro.
#define JUBATUS_EXPORT_MODEL(...) \
  void export_model(::jubatus::core::framework::packer& pk) const { \
    ::msgpack::type::make_define(__VA_ARGS__).msgpack_pack(pk); \
  }

// Defines `void import_model(msgpack::object)` in the enclosing class. The
// enclosing class must provide clear(); it is called before the listed
// members are unpacked from `o`, so the object is already cleared if
// unpacking throws.
#define JUBATUS_IMPORT_MODEL(...) \
  void import_model(::msgpack::object o) { \
    this->clear(); \
    ::msgpack::type::make_define(__VA_ARGS__).msgpack_unpack(o); \
  }

// Convenience macro: defines both import_model() and export_model() over the
// same list of members.
#define JUBATUS_PORTING_MODEL(...) \
  JUBATUS_IMPORT_MODEL(__VA_ARGS__) \
  JUBATUS_EXPORT_MODEL(__VA_ARGS__)
#endif // JUBATUS_CORE_COMMON_EXPORT_MODEL_HPP_
6 changes: 6 additions & 0 deletions jubatus/core/common/key_manager.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,12 @@ class key_manager {
key2id_.swap(km.key2id_);
id2key_.swap(km.id2key_);
}
// Copy assignment: copies both lookup tables (key2id_, id2key_) and the
// next_id_ counter from `rhs`, member by member, and returns *this.
key_manager& operator=(const key_manager& rhs) {
key2id_ = rhs.key2id_;
id2key_ = rhs.id2key_;
next_id_ = rhs.next_id_;
return *this;
}

size_t size() const {
return key2id_.size();
Expand Down
96 changes: 75 additions & 21 deletions jubatus/core/driver/classifier.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@
//
// This library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
// Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public
// License along with this library; if not, write to the Free Software
// Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
// Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA

#include "classifier.hpp"

Expand All @@ -24,13 +24,17 @@
#include "../classifier/classifier_factory.hpp"
#include "../classifier/classifier_base.hpp"
#include "../common/vector_util.hpp"
#include "../common/byte_buffer.hpp"
#include "../framework/stream_writer.hpp"
#include "../framework/packer.hpp"
#include "../fv_converter/datum.hpp"
#include "../fv_converter/datum_to_fv_converter.hpp"
#include "../fv_converter/converter_config.hpp"
#include "../storage/storage_factory.hpp"

using std::string;
using std::vector;
using std::make_pair;
using jubatus::util::lang::shared_ptr;
using jubatus::core::fv_converter::weight_manager;
using jubatus::core::fv_converter::mixable_weight_manager;
Expand All @@ -55,40 +59,40 @@ classifier::~classifier() {
}

void classifier::train(const string& label, const fv_converter::datum& data) {
common::sfv_t v;
converter_->convert_and_update_weight(data, v);
common::sort_and_merge(v);
classifier_->train(v, label);
common::sfv_t v;
converter_->convert_and_update_weight(data, v);
common::sort_and_merge(v);
classifier_->train(v, label);
}

jubatus::core::classifier::classify_result classifier::classify(
const fv_converter::datum& data) const {
common::sfv_t v;
converter_->convert(data, v);
const fv_converter::datum& data) const {
common::sfv_t v;
converter_->convert(data, v);

jubatus::core::classifier::classify_result scores;
classifier_->classify_with_scores(v, scores);
return scores;
jubatus::core::classifier::classify_result scores;
classifier_->classify_with_scores(v, scores);
return scores;
}

void classifier::get_status(std::map<string, string>& status) const {
classifier_->get_status(status);
classifier_->get_status(status);
}

bool classifier::delete_label(const std::string& label) {
return classifier_->delete_label(label);
return classifier_->delete_label(label);
}

void classifier::clear() {
classifier_->clear();
converter_->clear_weights();
classifier_->clear();
converter_->clear_weights();
}

std::vector<std::string> classifier::get_labels() const {
return classifier_->get_labels();
return classifier_->get_labels();
}
bool classifier::set_label(const std::string& label) {
return classifier_->set_label(label);
return classifier_->set_label(label);
}

void classifier::pack(framework::packer& pk) const {
Expand All @@ -109,6 +113,56 @@ void classifier::unpack(msgpack::object o) {
wm_.get_model()->unpack(o.via.array.ptr[1]);
}

} // namespace driver
} // namespace core
} // namespace jubatus
struct versioned_model {
// this struct is version compatible
std::string version;
common::byte_buffer buffer;
versioned_model(const std::string& v, const common::byte_buffer& b)
: version(v), buffer(b) {
}
MSGPACK_DEFINE(version, buffer);
};

// Decodes a model envelope from `from` and validates it.
//
// The envelope must be a msgpack array of exactly two elements: a version
// string and a byte buffer holding the serialized model.
//
// Throws msgpack::type_error when the envelope is not a two-element array,
// and common::invalid_parameter when the version is not "0.0.1".
//
// NOTE(review): the decoded model is not applied to this classifier yet;
// for an accepted version the function only parses the payload and returns.
void classifier::import_model(common::byte_buffer& from) const {
  msgpack::unpacked envelope;
  msgpack::unpack(&envelope, from.ptr(), from.size());
  msgpack::object root = envelope.get();

  const bool is_pair =
      root.type == msgpack::type::ARRAY && root.via.array.size == 2;
  if (!is_pair) {
    throw msgpack::type_error();
  }

  const std::string version = root.via.array.ptr[0].as<std::string>();
  common::byte_buffer payload =
      root.via.array.ptr[1].as<common::byte_buffer>();

  // The payload is itself msgpack-encoded; decode it before the version check
  // so a malformed payload is reported first, as before.
  msgpack::unpacked decoded;
  msgpack::unpack(&decoded, payload.ptr(), payload.size());
  msgpack::object model = decoded.get();

  if (version != "0.0.1") {
    throw JUBATUS_EXCEPTION(
        common::invalid_parameter("unknown version number: " + version));
  }
}

// Serializes the classifier and the weight manager model into a byte buffer,
// then wraps that buffer together with JUBATUS_CORE_VERSION in a
// versioned_model envelope and returns the packed envelope.
common::byte_buffer classifier::export_model() const {
  // Stage 1: pack the classifier, then the weight manager, into `payload`.
  common::byte_buffer payload;
  framework::stream_writer<common::byte_buffer> payload_writer(payload);
  core::framework::jubatus_packer raw_packer(payload_writer);
  core::framework::packer model_packer(raw_packer);
  classifier_->pack(model_packer);
  wm_.get_model()->pack(model_packer);

  // Stage 2: wrap the payload with the version string.
  common::byte_buffer envelope;
  msgpack::pack(envelope, versioned_model(JUBATUS_CORE_VERSION, payload));
  return envelope;
}


} // namespace driver
} // namespace core
} // namespace jubatus
4 changes: 4 additions & 0 deletions jubatus/core/driver/classifier.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include <string>
#include <vector>
#include "jubatus/util/lang/shared_ptr.h"
#include "../common/byte_buffer.hpp"
#include "../classifier/classifier_type.hpp"
#include "../framework/mixable.hpp"
#include "../fv_converter/mixable_weight_manager.hpp"
Expand Down Expand Up @@ -60,6 +61,9 @@ class classifier : public driver_base {
void pack(framework::packer& pk) const;
void unpack(msgpack::object o);

void import_model(common::byte_buffer& from) const;
common::byte_buffer export_model() const;

std::vector<std::string> get_labels() const;
bool set_label(const std::string& label);

Expand Down
6 changes: 6 additions & 0 deletions jubatus/core/driver/recommender.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,12 @@ void recommender::unpack(msgpack::object o) {
wm_.get_model()->unpack(o.via.array.ptr[1]);
}

// Imports a model from `from`.
// Not implemented yet: the body is empty, so `from` is ignored and the
// recommender is left unchanged.
void recommender::import_model(common::byte_buffer& from) const {

}
// Exports the recommender model as a byte buffer.
// Not implemented yet: always returns a default-constructed (empty) buffer.
common::byte_buffer recommender::export_model() const {
  // Flowing off the end of a non-void function is undefined behavior, so an
  // explicit value is returned until the real export is implemented.
  return common::byte_buffer();
}

} // namespace driver
} // namespace core
} // namespace jubatus
3 changes: 3 additions & 0 deletions jubatus/core/driver/recommender.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,9 @@ class recommender : public driver_base {
fv_converter::datum decode_row(const std::string& id);
std::vector<std::string> get_all_rows();

void import_model(common::byte_buffer& from) const;
common::byte_buffer export_model() const;

private:
jubatus::util::lang::shared_ptr<fv_converter::datum_to_fv_converter>
converter_;
Expand Down
3 changes: 3 additions & 0 deletions jubatus/core/framework/linear_function_mixer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@ class linear_function_mixer : public linear_mixable {
model_ptr get_model() const {
return model_;
}
// Replaces the model held by this mixer with `model`; subsequent calls to
// get_model() return the new pointer.
void set_model(model_ptr model) {
model_ = model;
}

void mix(const diffv& lhs, diffv& mixed) const;
void get_diff(diffv&) const;
Expand Down
1 change: 1 addition & 0 deletions jubatus/core/fv_converter/datum_to_fv_converter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -658,6 +658,7 @@ void datum_to_fv_converter::clear_weights() {
pimpl_->clear_weights();
}


} // namespace fv_converter
} // namespace core
} // namespace jubatus
1 change: 1 addition & 0 deletions jubatus/core/fv_converter/datum_to_fv_converter.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include "jubatus/util/lang/scoped_ptr.h"
#include "../common/type.hpp"
#include "../framework/mixable.hpp"
#include "../framework/packer.hpp"

namespace jubatus {
namespace core {
Expand Down
Loading

0 comments on commit 3c1a09c

Please sign in to comment.