Skip to content

Commit

Permalink
temp
Browse files Browse the repository at this point in the history
  • Loading branch information
KUMAZAKI Hiroki committed Mar 19, 2015
1 parent a424712 commit 39a3fb2
Show file tree
Hide file tree
Showing 12 changed files with 55 additions and 21 deletions.
4 changes: 2 additions & 2 deletions jubatus/core/classifier/classifier_base.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ class classifier_base {
virtual void set_label_unlearner(
jubatus::util::lang::shared_ptr<unlearner::unlearner_base>
label_unlearner) = 0;
virtual jubatus::util::lang::shared_ptr<unlearner::unlearner_base>
get_label_unlearner() const = 0;

virtual bool delete_label(const std::string& label) = 0;
virtual std::vector<std::string> get_labels() const = 0;
Expand All @@ -59,8 +61,6 @@ class classifier_base {

virtual void pack(framework::packer& pk) const = 0;
virtual void unpack(msgpack::object o) = 0;
virtual void export_model(framework::packer& pk) const = 0;
virtual void import_model(msgpack::object o) = 0;
virtual void clear() = 0;

virtual framework::mixable* get_mixable() = 0;
Expand Down
6 changes: 6 additions & 0 deletions jubatus/core/classifier/linear_classifier.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,12 @@ void linear_classifier::set_label_unlearner(
unlearner_ = label_unlearner;
}

jubatus::util::lang::shared_ptr<unlearner::unlearner_base>
linear_classifier::get_label_unlearner() const {
return unlearner_;
}


void linear_classifier::classify_with_scores(
const common::sfv_t& sfv,
classify_result& scores) const {
Expand Down
4 changes: 1 addition & 3 deletions jubatus/core/classifier/linear_classifier.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,7 @@ class linear_classifier : public classifier_base {
label_unlearner);

jubatus::util::lang::shared_ptr<unlearner::unlearner_base>
label_unlearner() const {
return unlearner_;
}
get_label_unlearner() const;

std::string classify(const common::sfv_t& fv) const;
void classify_with_scores(const common::sfv_t& fv,
Expand Down
5 changes: 5 additions & 0 deletions jubatus/core/classifier/nearest_neighbor_classifier.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,11 @@ void nearest_neighbor_classifier::set_label_unlearner(
unlearner_ = label_unlearner;
}

shared_ptr<unlearner::unlearner_base>
nearest_neighbor_classifier::get_label_unlearner() const {
return unlearner_;
}

std::string nearest_neighbor_classifier::classify(
const common::sfv_t& fv) const {
classify_result result;
Expand Down
2 changes: 2 additions & 0 deletions jubatus/core/classifier/nearest_neighbor_classifier.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ class nearest_neighbor_classifier : public classifier_base {
void set_label_unlearner(
jubatus::util::lang::shared_ptr<unlearner::unlearner_base>
label_unlearner);
jubatus::util::lang::shared_ptr<unlearner::unlearner_base>
get_label_unlearner() const;

std::string classify(const common::sfv_t& fv) const;
void classify_with_scores(const common::sfv_t& fv,
Expand Down
8 changes: 8 additions & 0 deletions jubatus/core/common/export_model.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,5 +24,13 @@
void export_model(framework::packer& pk) const { \
msgpack::type::make_define(__VA_ARGS__).msgpack_pack(pk); \
}
#define JUBATUS_IMPORT_MODEL(...) \
void import_model(msgpack::object o) { \
this->clear(); \
msgpack::type::make_define(__VA_ARGS__).msgpack_unpack(o); \
}

#define JUBATUS_PORTING_MODEL(...) \
JUBATUS_IMPORT_MODEL(__VA_ARGS__) \
JUBATUS_EXPORT_MODEL(__VA_ARGS__)
#endif // JUBATUS_CORE_COMMON_EXPORT_MODEL_HPP_
6 changes: 6 additions & 0 deletions jubatus/core/common/key_manager.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,12 @@ class key_manager {
key2id_.swap(km.key2id_);
id2key_.swap(km.id2key_);
}
key_manager& operator=(const key_manager& rhs) {
key2id_ = rhs.key2id_;
id2key_ = rhs.id2key_;
next_id_ = rhs.next_id_;
return *this;
}

size_t size() const {
return key2id_.size();
Expand Down
24 changes: 21 additions & 3 deletions jubatus/core/driver/classifier.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,25 @@ struct versioned_model {
};

void classifier::import_model(common::byte_buffer& from) const {
// not implemented yet
msgpack::unpacked packed;
msgpack::unpack(&packed, from.ptr(), from.size());
msgpack::object o = packed.get();
if (o.type != msgpack::type::ARRAY || o.via.array.size != 2) {
throw msgpack::type_error();
}
const std::string& version = o.via.array.ptr[0].as<std::string>();
common::byte_buffer serialized_model =
o.via.array.ptr[1].as<common::byte_buffer>();

msgpack::unpacked unpacked_model;
msgpack::unpack(&unpacked_model, serialized_model.ptr(), serialized_model.size());
msgpack::object model = unpacked_model.get();

if (version == "0.0.1") {
} else {
throw JUBATUS_EXCEPTION(
common::invalid_parameter("unknown version number: " + version));
}
}

common::byte_buffer classifier::export_model() const {
Expand All @@ -133,8 +151,8 @@ common::byte_buffer classifier::export_model() const {
core::framework::jubatus_packer jp(model_writer);
core::framework::packer pk(jp);
// msgpack::packer<common::byte_buffer> pk(model);
classifier_->export_model(pk);
wm_->export_model(pk);
classifier_->pack(pk);
wm_.get_model()->pack(pk);
{
common::byte_buffer ret;
msgpack::pack(ret,
Expand Down
3 changes: 3 additions & 0 deletions jubatus/core/framework/linear_function_mixer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@ class linear_function_mixer : public linear_mixable {
model_ptr get_model() const {
return model_;
}
void set_model(model_ptr model) {
model_ = model;
}

void mix(const diffv& lhs, diffv& mixed) const;
void get_diff(diffv&) const;
Expand Down
8 changes: 0 additions & 8 deletions jubatus/core/fv_converter/datum_to_fv_converter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -343,10 +343,6 @@ class datum_to_fv_converter_impl {
}
}

void export_model(framework::packer& pk) {

}

private:
void filter_strings(
const datum::sv_t& string_values,
Expand Down Expand Up @@ -662,10 +658,6 @@ void datum_to_fv_converter::clear_weights() {
pimpl_->clear_weights();
}

void datum_to_fv_converter::export_model(framework::packer& pk) {
pimpl_->export_model(pk);
}


} // namespace fv_converter
} // namespace core
Expand Down
3 changes: 0 additions & 3 deletions jubatus/core/fv_converter/datum_to_fv_converter.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,9 +121,6 @@ class datum_to_fv_converter {
void set_weight_manager(jubatus::util::lang::shared_ptr<weight_manager> wm);
void clear_weights();

void import_model(framework::packer& pk);
void export_model(framework::packer& pk);

private:
jubatus::util::lang::scoped_ptr<datum_to_fv_converter_impl> pimpl_;
};
Expand Down
3 changes: 1 addition & 2 deletions jubatus/core/storage/local_storage_mixture.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,7 @@ class local_storage_mixture : public storage_base {
std::string type() const;

MSGPACK_DEFINE(tbl_, class2id_, tbl_diff_, model_version_);
JUBATUS_EXPORT_MODEL(tbl_, class2id_, model_version_);
void import_model(msgpack::object o);
JUBATUS_PORTING_MODEL(tbl_, class2id_, model_version_);

private:
bool get_internal(const std::string& feature, id_feature_val3_t& ret) const;
Expand Down

0 comments on commit 39a3fb2

Please sign in to comment.