Skip to content

Commit

Permalink
simplified words processing in text input connector + options for min…
Browse files Browse the repository at this point in the history
…_count and min_word_length, ref jolibrain#13
  • Loading branch information
beniz committed Aug 10, 2015
1 parent 61f4e8c commit 2e6dbfe
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 32 deletions.
2 changes: 1 addition & 1 deletion src/caffeinputconns.h
Original file line number Diff line number Diff line change
Expand Up @@ -596,7 +596,7 @@ namespace dd
auto hit = tbe._v.cbegin();
while(hit!=tbe._v.cend())
{
datum.set_float_data((*hit).first,static_cast<float>((*hit).second));
datum.set_float_data(_vocab[(*hit).first]._pos,static_cast<float>((*hit).second));
++hit;
}
return datum;
Expand Down
61 changes: 44 additions & 17 deletions src/txtinputfileconn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -110,20 +110,51 @@ namespace dd
}

// post-processing
if (_ctfc->_tfidf)
size_t initial_vocab_size = _ctfc->_vocab.size();
auto vhit = _ctfc->_vocab.begin();
while(vhit!=_ctfc->_vocab.end())
{
//std::unordered_map<std::string,Word>::const_iterator vhit;
//std::unordered_map<int,std::string>::const_iterator rvhit;
if ((*vhit).second._total_count < _ctfc->_min_count)
vhit = _ctfc->_vocab.erase(vhit);
else ++vhit;
}
if (initial_vocab_size != _ctfc->_vocab.size())
{
// update pos
int pos = 0;
vhit = _ctfc->_vocab.begin();
while(vhit!=_ctfc->_vocab.end())
{
(*vhit).second._pos = pos;
++pos;
++vhit;
}
}

if (initial_vocab_size != _ctfc->_vocab.size() || _ctfc->_tfidf)
{
// clearing up the corpus + tfidf
std::unordered_map<std::string,Word>::iterator whit;
for (TxtBowEntry &tbe: _ctfc->_txt)
{
auto hit = tbe._v.begin();
while(hit!=tbe._v.end())
{
std::string ws = _ctfc->_rvocab[(*hit).first];
Word w = _ctfc->_vocab[ws];
(*hit).second = (std::log(1.0+(*hit).second / static_cast<double>(w._total_count))) * std::log(_ctfc->_txt.size() / static_cast<double>(w._total_docs) + 1.0);
//std::cerr << "tfidf feature w=" << ws << " / val=" << (*hit).second << std::endl;
++hit;
if ((whit=_ctfc->_vocab.find((*hit).first))!=_ctfc->_vocab.end())
{
if (_ctfc->_tfidf)
{
Word w = (*whit).second;
(*hit).second = (std::log(1.0+(*hit).second / static_cast<double>(w._total_count))) * std::log(_ctfc->_txt.size() / static_cast<double>(w._total_docs) + 1.0);
//std::cerr << "tfidf feature w=" << (*hit).first << " / val=" << (*hit).second << std::endl;
}
++hit;
}
else
{
//std::cerr << "removing ws=" << (*hit).first << std::endl;
hit = tbe._v.erase(hit);
}
}
}
}
Expand All @@ -139,8 +170,7 @@ namespace dd
correspf.close();

LOG(INFO) << "vocabulary size=" << _ctfc->_vocab.size() << std::endl;
//_ctfc->_vocab.clear(); //TODO: serialize to disk / db


return 0;
}

Expand All @@ -160,7 +190,8 @@ namespace dd
boost::tokenizer<boost::char_separator<char>> tokens(ct,sep);
for (std::string w : tokens)
{
//std::cout << w << std::endl;
if (static_cast<int>(w.length()) < _min_word_length)
continue;

// check and fillup vocab.
int pos = -1;
Expand All @@ -170,22 +201,18 @@ namespace dd
{
pos = _vocab.size();
_vocab.emplace(std::make_pair(w,Word(pos)));
if (_tfidf)
_rvocab.insert(std::pair<int,std::string>(pos,w));
}
}
else
{
pos = (*vhit).second._pos;
if (_train)
{
(*vhit).second._total_count++;
if (!tbe.has_word(pos))
if (!tbe.has_word(w))
(*vhit).second._total_docs++;
}
}
if (pos >= 0)
tbe.add_word(pos,1.0,_count);
tbe.add_word(w,1.0,_count);
}
_txt.push_back(tbe);
}
Expand Down
30 changes: 16 additions & 14 deletions src/txtinputfileconn.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,36 +66,33 @@ namespace dd
TxtBowEntry(const float &target):_target(target) {}
~TxtBowEntry() {}

void add_word(const int &pos,
void add_word(const std::string &str,
const double &v,
const bool &count)
{
std::unordered_map<int,double>::iterator hit;
if ((hit=_v.find(pos))!=_v.end())
std::unordered_map<std::string,double>::iterator hit;
if ((hit=_v.find(str))!=_v.end())
{
if (count)
(*hit).second += v;
}
else _v.insert(std::pair<int,double>(pos,v));
else _v.insert(std::pair<std::string,double>(str,v));
}

bool has_word(const int &pos)
bool has_word(const std::string &str)
{
return _v.count(pos);
return _v.count(str);
}

std::unordered_map<int,double> _v; /**< words as (<pos,val>). */
std::unordered_map<std::string,double> _v; /**< words as (<pos,val>). */
float _target = -1; /**< class target in training mode. */
};

class TxtInputFileConn : public InputConnectorStrategy
{
public:
TxtInputFileConn()
:InputConnectorStrategy()
{
//_vocab = std::unordered_map<std::string,Word>(1e5);
}
:InputConnectorStrategy() {}

~TxtInputFileConn() {}

Expand All @@ -114,6 +111,10 @@ namespace dd
_count = ad_input.get("count").get<bool>();
if (ad_input.has("tfidf"))
_tfidf = ad_input.get("tfidf").get<bool>();
if (ad_input.has("min_count"))
_min_count = ad_input.get("min_count").get<int>();
if (ad_input.has("min_word_length"))
_min_word_length = ad_input.get("min_word_length").get<int>();
}

int feature_size() const
Expand Down Expand Up @@ -148,7 +149,6 @@ namespace dd
if (ad.has("model_repo"))
_model_repo = ad.get("model_repo").get<std::string>();

std::cerr << "train=" << _train << " / vocab size=" << _vocab.size() << std::endl;
if (!_train && _vocab.empty())
deserialize_vocab();

Expand Down Expand Up @@ -185,6 +185,7 @@ namespace dd
}
_txt.erase(dchit,_txt.end());
std::cout << "data split test size=" << _test_txt.size() << " / remaining data size=" << _txt.size() << std::endl;
std::cout << "vocab size=" << _vocab.size() << std::endl;
}
if (_txt.empty())
throw InputConnectorBadParamException("no text could be found");
Expand All @@ -205,10 +206,11 @@ namespace dd
double _test_split = 0.0;
bool _count = true; /**< whether to add up word counters */
bool _tfidf = false; /**< whether to use TF/IDF */
int _min_count = 5; /**< min word occurence. */
int _min_word_length = 5; /**< min word length. */

// internals
std::unordered_map<std::string,Word> _vocab; /**< string to word stats, including pos */
std::unordered_map<int,std::string> _rvocab; /**< pos to string */
std::unordered_map<std::string,Word> _vocab; /**< string to word stats, including word */
std::string _vocabfname = "vocab.dat";
std::string _model_repo;
std::string _correspname = "corresp.txt";
Expand Down

0 comments on commit 2e6dbfe

Please sign in to comment.