From 2e6dbfe0a8ea687486c26b175ab7b71bf57b7f7b Mon Sep 17 00:00:00 2001
From: Emmanuel Benazera
Date: Mon, 10 Aug 2015 20:21:20 +0200
Subject: [PATCH] simplified words processing in text input connector + options
 for min_count and min_word_length, ref #13

---
 src/caffeinputconns.h   |  2 +-
 src/txtinputfileconn.cc | 61 +++++++++++++++++++++++++++++------------
 src/txtinputfileconn.h  | 30 ++++++++++----------
 3 files changed, 61 insertions(+), 32 deletions(-)

diff --git a/src/caffeinputconns.h b/src/caffeinputconns.h
index f52cb96c8..3b6bacda5 100644
--- a/src/caffeinputconns.h
+++ b/src/caffeinputconns.h
@@ -596,7 +596,7 @@ namespace dd
       auto hit = tbe._v.cbegin();
       while(hit!=tbe._v.cend())
         {
-          datum.set_float_data((*hit).first,static_cast<float>((*hit).second));
+          datum.set_float_data(_vocab[(*hit).first]._pos,static_cast<float>((*hit).second));
           ++hit;
         }
       return datum;
diff --git a/src/txtinputfileconn.cc b/src/txtinputfileconn.cc
index a4ee660a3..e7aa4250b 100644
--- a/src/txtinputfileconn.cc
+++ b/src/txtinputfileconn.cc
@@ -110,20 +110,51 @@ namespace dd
       }

     // post-processing
-    if (_ctfc->_tfidf)
+    size_t initial_vocab_size = _ctfc->_vocab.size();
+    auto vhit = _ctfc->_vocab.begin();
+    while(vhit!=_ctfc->_vocab.end())
       {
-        //std::unordered_map<std::string,Word>::const_iterator vhit;
-        //std::unordered_map<int,std::string>::const_iterator rvhit;
+        if ((*vhit).second._total_count < _ctfc->_min_count)
+          vhit = _ctfc->_vocab.erase(vhit);
+        else ++vhit;
+      }
+    if (initial_vocab_size != _ctfc->_vocab.size())
+      {
+        // update pos
+        int pos = 0;
+        vhit = _ctfc->_vocab.begin();
+        while(vhit!=_ctfc->_vocab.end())
+          {
+            (*vhit).second._pos = pos;
+            ++pos;
+            ++vhit;
+          }
+      }
+
+    if (initial_vocab_size != _ctfc->_vocab.size() || _ctfc->_tfidf)
+      {
+        // clearing up the corpus + tfidf
+        std::unordered_map<std::string,Word>::iterator whit;
         for (TxtBowEntry &tbe: _ctfc->_txt)
           {
             auto hit = tbe._v.begin();
             while(hit!=tbe._v.end())
               {
-                std::string ws = _ctfc->_rvocab[(*hit).first];
-                Word w = _ctfc->_vocab[ws];
-                (*hit).second = (std::log(1.0+(*hit).second / static_cast<double>(w._total_count))) * std::log(_ctfc->_txt.size() / static_cast<double>(w._total_docs) + 1.0);
-                //std::cerr << "tfidf feature w=" << ws << " / val=" << (*hit).second << std::endl;
-                ++hit;
+                if ((whit=_ctfc->_vocab.find((*hit).first))!=_ctfc->_vocab.end())
+                  {
+                    if (_ctfc->_tfidf)
+                      {
+                        Word w = (*whit).second;
+                        (*hit).second = (std::log(1.0+(*hit).second / static_cast<double>(w._total_count))) * std::log(_ctfc->_txt.size() / static_cast<double>(w._total_docs) + 1.0);
+                        //std::cerr << "tfidf feature w=" << (*hit).first << " / val=" << (*hit).second << std::endl;
+                      }
+                    ++hit;
+                  }
+                else
+                  {
+                    //std::cerr << "removing ws=" << (*hit).first << std::endl;
+                    hit = tbe._v.erase(hit);
+                  }
               }
           }
       }
@@ -139,8 +170,7 @@ namespace dd
     correspf.close();

     LOG(INFO) << "vocabulary size=" << _ctfc->_vocab.size() << std::endl;
-    //_ctfc->_vocab.clear(); //TODO: serialize to disk / db
-
+
     return 0;
   }
@@ -160,7 +190,8 @@ namespace dd
     boost::tokenizer<boost::char_separator<char>> tokens(ct,sep);
     for (std::string w : tokens)
       {
-        //std::cout << w << std::endl;
+        if (static_cast<int>(w.length()) < _min_word_length)
+          continue;

         // check and fillup vocab.
         int pos = -1;
@@ -170,22 +201,18 @@ namespace dd
           {
             pos = _vocab.size();
             _vocab.emplace(std::make_pair(w,Word(pos)));
-            if (_tfidf)
-              _rvocab.insert(std::pair<int,std::string>(pos,w));
           }
       }
     else
       {
-        pos = (*vhit).second._pos;
         if (_train)
           {
             (*vhit).second._total_count++;
-            if (!tbe.has_word(pos))
+            if (!tbe.has_word(w))
               (*vhit).second._total_docs++;
           }
       }
-    if (pos >= 0)
-      tbe.add_word(pos,1.0,_count);
+    tbe.add_word(w,1.0,_count);
   }
 _txt.push_back(tbe);
 }
diff --git a/src/txtinputfileconn.h b/src/txtinputfileconn.h
index 43d7bf4b5..e1de49597 100644
--- a/src/txtinputfileconn.h
+++ b/src/txtinputfileconn.h
@@ -66,25 +66,25 @@ namespace dd
     TxtBowEntry(const float &target):_target(target) {}
     ~TxtBowEntry() {}

-    void add_word(const int &pos,
+    void add_word(const std::string &str,
                   const double &v,
                   const bool &count)
     {
-      std::unordered_map<int,double>::iterator hit;
-      if ((hit=_v.find(pos))!=_v.end())
+      std::unordered_map<std::string,double>::iterator hit;
+      if ((hit=_v.find(str))!=_v.end())
        {
          if (count)
            (*hit).second += v;
        }
-      else _v.insert(std::pair<int,double>(pos,v));
+      else _v.insert(std::pair<std::string,double>(str,v));
     }

-    bool has_word(const int &pos)
+    bool has_word(const std::string &str)
     {
-      return _v.count(pos);
+      return _v.count(str);
     }

-    std::unordered_map<int,double> _v; /**< words as (<pos,count>). */
+    std::unordered_map<std::string,double> _v; /**< words as (<word,count>). */
     float _target = -1; /**< class target in training mode. */
   };
@@ -92,10 +92,7 @@ namespace dd
   {
   public:
     TxtInputFileConn()
-      :InputConnectorStrategy()
-    {
-      //_vocab = std::unordered_map<std::string,Word>(1e5);
-    }
+      :InputConnectorStrategy() {}
     ~TxtInputFileConn() {}
@@ -114,6 +111,10 @@ namespace dd
          _count = ad_input.get("count").get<bool>();
        if (ad_input.has("tfidf"))
          _tfidf = ad_input.get("tfidf").get<bool>();
+       if (ad_input.has("min_count"))
+         _min_count = ad_input.get("min_count").get<int>();
+       if (ad_input.has("min_word_length"))
+         _min_word_length = ad_input.get("min_word_length").get<int>();
      }

    int feature_size() const
@@ -148,7 +149,6 @@ namespace dd
        if (ad.has("model_repo"))
          _model_repo = ad.get("model_repo").get<std::string>();

-       std::cerr << "train=" << _train << " / vocab size=" << _vocab.size() << std::endl;
        if (!_train && _vocab.empty())
          deserialize_vocab();
@@ -185,6 +185,7 @@ namespace dd
              }
            _txt.erase(dchit,_txt.end());
            std::cout << "data split test size=" << _test_txt.size() << " / remaining data size=" << _txt.size() << std::endl;
+           std::cout << "vocab size=" << _vocab.size() << std::endl;
          }
        if (_txt.empty())
          throw InputConnectorBadParamException("no text could be found");
@@ -205,10 +206,11 @@ namespace dd
     double _test_split = 0.0;
     bool _count = true; /**< whether to add up word counters */
     bool _tfidf = false; /**< whether to use TF/IDF */
+    int _min_count = 5; /**< min word occurrence. */
+    int _min_word_length = 5; /**< min word length. */

     // internals
-    std::unordered_map<std::string,Word> _vocab; /**< string to word stats, including pos */
-    std::unordered_map<int,std::string> _rvocab; /**< pos to string */
+    std::unordered_map<std::string,Word> _vocab; /**< string to word stats, including word pos */
     std::string _vocabfname = "vocab.dat";
     std::string _model_repo;
     std::string _correspname = "corresp.txt";
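
Note for readers following the change: after the corpus has been read, the new post-processing in txtinputfileconn.cc (i) erases vocabulary entries seen fewer than _min_count times, (ii) re-assigns dense positions to the surviving words so that the Datum feature indices written in caffeinputconns.h stay contiguous, and (iii) walks every bag-of-words entry to drop pruned words and, when tfidf is set, replaces each raw count tf with log(1 + tf/total_count) * log(N/total_docs + 1). Below is a minimal, self-contained sketch of that logic; Word, BowEntry and post_process here are simplified stand-ins for illustration, not the actual DeepDetect classes (compiles with -std=c++14).

#include <cmath>
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

struct Word
{
  int _pos = -1;         // position in the feature vector
  int _total_count = 0;  // occurrences across the whole corpus
  int _total_docs = 0;   // number of documents containing the word
};

using Vocab = std::unordered_map<std::string,Word>;
using BowEntry = std::unordered_map<std::string,double>; // word -> count

void post_process(Vocab &vocab, std::vector<BowEntry> &txt,
                  const int min_count, const bool tfidf)
{
  const size_t initial_vocab_size = vocab.size();

  // 1. prune words occurring fewer than min_count times
  for (auto vhit = vocab.begin(); vhit != vocab.end();)
    {
      if (vhit->second._total_count < min_count)
        vhit = vocab.erase(vhit);
      else ++vhit;
    }

  // 2. re-assign dense positions so feature indices stay contiguous
  if (initial_vocab_size != vocab.size())
    {
      int pos = 0;
      for (auto &kv : vocab)
        kv.second._pos = pos++;
    }

  // 3. clean up every document and optionally apply TF-IDF weighting
  if (initial_vocab_size != vocab.size() || tfidf)
    for (BowEntry &tbe : txt)
      for (auto hit = tbe.begin(); hit != tbe.end();)
        {
          const auto whit = vocab.find(hit->first);
          if (whit == vocab.end()) // word was pruned -> drop it
            {
              hit = tbe.erase(hit);
              continue;
            }
          if (tfidf) // same weighting as the patch
            hit->second = std::log(1.0 + hit->second / static_cast<double>(whit->second._total_count))
              * std::log(txt.size() / static_cast<double>(whit->second._total_docs) + 1.0);
          ++hit;
        }
}

int main()
{
  Vocab vocab = {{"the",{0,10,2}},{"rare",{1,1,1}}};
  std::vector<BowEntry> txt = {{{"the",4.0},{"rare",1.0}},{{"the",6.0}}};
  post_process(vocab,txt,/*min_count=*/5,/*tfidf=*/true);
  // "rare" falls below min_count: it leaves the vocabulary and both documents
  std::cout << "vocab size=" << vocab.size()
            << " / doc0 words=" << txt[0].size() << std::endl;
}

Both new thresholds default to 5 and can be overridden per call through the input connector parameters min_count and min_word_length parsed in the txtinputfileconn.h hunk above.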