Skip to content

Commit

Permalink
move things into src/reader/
Browse files Browse the repository at this point in the history
  • Loading branch information
mli committed Feb 4, 2016
1 parent b59ba83 commit 58d525a
Show file tree
Hide file tree
Showing 21 changed files with 145 additions and 190 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ bcd/bcd_learner.o \
store/store.o \
tracker/tracker.o \
reporter/reporter.o \
data/localizer.o data/batch_iter.o )
data/localizer.o reader/batch_reader.o )

DMLC_DEPS = dmlc-core/libdmlc.a

Expand Down
13 changes: 7 additions & 6 deletions src/bcd/bcd_learner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -109,9 +109,10 @@ void BCDLearner::RunScheduler() {
void BCDLearner::PrepareData(const bcd::JobArgs& job,
bcd::PrepDataRets* rets) {
// read train data
ChunkIter train(param_.data_in, param_.data_format,
model_store_->Rank(), model_store_->NumWorkers(),
param_.data_chunk_size);

Reader train(param_.data_in, param_.data_format,
model_store_->Rank(), model_store_->NumWorkers(),
param_.data_chunk_size);
bcd::FeaGroupStats stats(param_.num_feature_group_bits);
tile_builder_ = new bcd::TileBuilder(tile_store_, DEFAULT_NTHREADS);
while (train.Next()) {
Expand All @@ -129,9 +130,9 @@ void BCDLearner::PrepareData(const bcd::JobArgs& job,

// read validation data if any
if (param_.data_val.size()) {
ChunkIter val(param_.data_val, param_.data_format,
model_store_->Rank(), model_store_->NumWorkers(),
param_.data_chunk_size);
Reader val(param_.data_val, param_.data_format,
model_store_->Rank(), model_store_->NumWorkers(),
param_.data_chunk_size);
while (val.Next()) {
auto rowblk = val.Value();
tile_builder_->Add(rowblk, false);
Expand Down
2 changes: 1 addition & 1 deletion src/bcd/bcd_learner.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
#include "difacto/learner.h"
#include "difacto/node_id.h"
#include "dmlc/data.h"
#include "data/chunk_iter.h"
#include "reader/reader.h"
#include "data/data_store.h"
#include "./bcd_param.h"
#include "./bcd_job.h"
Expand Down
19 changes: 18 additions & 1 deletion src/bcd/tile_store.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,27 @@ namespace difacto {
namespace bcd {

/**
* \brief a row and column block X, stored in column-major order, namely X'
* \brief a sliced block of a large matrix
*
* assume that we evenly partition the following 4x4 matrix into four 2x2 tiles
* \code
* 1..4
* ....
* .3..
* 2..1
* \endcode
*
* then the tile at row=2 and col=2 is
*
* \code
* ..
* .1
* \endcode
*/
struct Tile {
/** \brief the map to the column id on the original matrix */
SArray<int> colmap;
/** \brief the transposed data, stored this way to make slicing efficient */
SharedRowBlockContainer<unsigned> data;
};

Expand Down
4 changes: 2 additions & 2 deletions src/data/converter.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
#ifndef DIFACTO_DATA_CONVERTER_H_
// Include-guard macro: must be spelled exactly like the name tested by the
// preceding #ifndef, otherwise the guard never takes effect.
#define DIFACTO_DATA_CONVERTER_H_
#include "dmlc/parameter.h"
#include "data/chunk_iter.h"
#include "reader/reader.h"
#include "dmlc/io.h"
namespace difacto {

Expand Down Expand Up @@ -44,7 +44,7 @@ class Converter {
void Run() {
using namespace dmlc;
using namespace dmlc::data;
ChunkIter in(param_.data_in, param_.data_format, 0, 1, 8);
Reader in(param_.data_in, param_.data_format, 0, 1, 8);

LOG(INFO) << "reading data from " << param_.data_in
<< " in " << param_.data_format << " format";
Expand Down
File renamed without changes.
42 changes: 9 additions & 33 deletions src/data/batch_iter.cc → src/reader/batch_reader.cc
Original file line number Diff line number Diff line change
@@ -1,14 +1,10 @@
/**
* Copyright (c) 2015 by Contributors
*/
#include "./batch_iter.h"
#include "data/libsvm_parser.h"
#include "./adfea_parser.h"
#include "./crb_parser.h"
#include "./criteo_parser.h"
#include "./batch_reader.h"
namespace difacto {

BatchIter::BatchIter(
BatchReader::BatchReader(
const std::string& uri, const std::string& format,
unsigned part_index, unsigned num_parts,
unsigned batch_size, unsigned shuffle_buf_size,
Expand All @@ -21,43 +17,23 @@ BatchIter::BatchIter(
seed_ = 0;
if (shuf_buf_) {
CHECK_GE(shuf_buf_, batch_size_);
buf_reader_ = new BatchIter(
buf_reader_ = new BatchReader(
uri, format, part_index, num_parts, shuf_buf_);
parser_ = NULL;
reader_ = NULL;
} else {
buf_reader_ = NULL;
// create parser
char const* c_uri = uri.c_str();
if (format == "libsvm") {
parser_ = new dmlc::data::LibSVMParser<feaid_t>(
dmlc::InputSplit::Create(c_uri, part_index, num_parts, "text"), 1);
} else if (format == "criteo") {
parser_ = new CriteoParser(
dmlc::InputSplit::Create(c_uri, part_index, num_parts, "text"), true);
} else if (format == "criteo_test") {
parser_ = new CriteoParser(
dmlc::InputSplit::Create(c_uri, part_index, num_parts, "text"), false);
} else if (format == "adfea") {
parser_ = new AdfeaParser(
dmlc::InputSplit::Create(c_uri, part_index, num_parts, "text"));
} else if (format == "rec") {
parser_ = new CRBParser(
dmlc::InputSplit::Create(c_uri, part_index, num_parts, "recordio"));
} else {
LOG(FATAL) << "unknown format " << format;
}
parser_ = new dmlc::data::ThreadedParser<feaid_t>(parser_);
reader_ = new Reader(uri, format, part_index, num_parts, 1<<26);
}
}

bool BatchIter::Next() {
bool BatchReader::Next() {
batch_.Clear();
while (batch_.offset.size() < batch_size_ + 1) {
if (start_ == end_) {
if (shuf_buf_ == 0) {
// no random shuffle
if (!parser_->Next()) break;
in_blk_ = parser_->Value();
if (!reader_->Next()) break;
in_blk_ = reader_->Value();
} else {
// do random shuffle
if (!buf_reader_->Next()) break;
Expand Down Expand Up @@ -101,7 +77,7 @@ bool BatchIter::Next() {
return out_blk_.size > 0;
}

void BatchIter::Push(size_t pos, size_t len) {
void BatchReader::Push(size_t pos, size_t len) {
if (!len) return;
CHECK_LE(pos + len, in_blk_.size);
dmlc::RowBlock<feaid_t> slice;
Expand Down
31 changes: 12 additions & 19 deletions src/data/batch_iter.h → src/reader/batch_reader.h
Original file line number Diff line number Diff line change
@@ -1,20 +1,20 @@
/**
* Copyright (c) 2015 by Contributors
*/
#ifndef DIFACTO_DATA_BATCH_ITER_H_
#define DIFACTO_DATA_BATCH_ITER_H_
#ifndef DIFACTO_DATA_BATCH_READER_H_
#define DIFACTO_DATA_BATCH_READER_H_
#include <string>
#include <vector>
#include "difacto/base.h"
#include "dmlc/data.h"
#include "data/parser.h"
#include "./reader.h"
namespace difacto {

/**
* \brief an iterator reads a batch with a given number of examples
* \brief a reader that reads a batch with a given number of examples
* each time.
*/
class BatchIter {
class BatchReader {
public:
/**
* \brief create a batch iterator
Expand All @@ -28,16 +28,16 @@ class BatchIter {
* shuffle_buf_size examples
* @param neg_sampling the probability to pickup a negative sample (label <= 0)
*/
BatchIter(const std::string& uri,
BatchReader(const std::string& uri,
const std::string& format,
unsigned part_index,
unsigned num_parts,
unsigned batch_size,
unsigned shuffle_buf_size = 0,
float neg_sampling = 1.0);

~BatchIter() {
delete parser_;
~BatchReader() {
delete reader_;
delete buf_reader_;
}

Expand All @@ -54,22 +54,16 @@ class BatchIter {
return out_blk_;
}

/**
* \brief reset to the file beginning
*/
void Reset() {
if (parser_) parser_->BeforeFirst();
if (buf_reader_) buf_reader_->Reset();
}

private:
/**
* \brief batch_.push(in_blk_(pos:pos+len))
*/
void Push(size_t pos, size_t len);

unsigned batch_size_, shuf_buf_;
dmlc::data::ParserImpl<feaid_t> *parser_;

Reader *reader_;
BatchReader* buf_reader_;

float neg_sampling_;
size_t start_, end_;
Expand All @@ -78,9 +72,8 @@ class BatchIter {

// random perturbation
std::vector<unsigned> rdp_;
BatchIter* buf_reader_;
unsigned int seed_;
};

} // namespace difacto
#endif // DIFACTO_DATA_BATCH_ITER_H_
#endif // DIFACTO_DATA_BATCH_READER_H_
2 changes: 1 addition & 1 deletion src/data/crb_parser.h → src/reader/crb_parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
#include <vector>
#include "data/parser.h"
#include "dmlc/recordio.h"
#include "./compressed_row_block.h"
#include "data/compressed_row_block.h"
namespace difacto {
/**
* \brief compressed row block parser
Expand Down
File renamed without changes.
47 changes: 16 additions & 31 deletions src/data/chunk_iter.h → src/reader/reader.h
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
/**
* Copyright (c) 2015 by Contributors
*/
#ifndef DIFACTO_DATA_CHUNK_ITER_H_
#define DIFACTO_DATA_CHUNK_ITER_H_
#ifndef DIFACTO_READER_READER_H_
#define DIFACTO_READER_READER_H_
#include <string>
#include <vector>
#include "difacto/base.h"
#include "dmlc/data.h"
#include "data/parser.h"
Expand All @@ -14,28 +13,19 @@
#include "./criteo_parser.h"
namespace difacto {
/**
* \brief an iterator reads a chunk with a hint data size
* \brief a reader that reads one chunk of data of roughly the same size at a time
*/
class ChunkIter {
class Reader {
public:
/**
* \brief create a chunk iterator
*
* @param uri filename
* @param format the data format, support libsvm, crb, ...
* @param part_index the i-th part to read
* @param num_parts partition the file into several parts
* @param chunk_size the chunk size.
*/
ChunkIter(const std::string& uri,
const std::string& format,
unsigned part_index,
unsigned num_parts,
unsigned chunk_size) {
Reader(const std::string& uri,
const std::string& format,
int part_index,
int num_parts,
int chunk_size_hint) {
char const* c_uri = uri.c_str();
dmlc::InputSplit* input = dmlc::InputSplit::Create(
c_uri, part_index, num_parts, format == "cb" ? "recordio" : "text");
input->HintChunkSize(chunk_size);
input->HintChunkSize(chunk_size_hint);

if (format == "libsvm") {
parser_ = new dmlc::data::LibSVMParser<feaid_t>(input, 1);
Expand All @@ -53,20 +43,15 @@ class ChunkIter {
parser_ = new dmlc::data::ThreadedParser<feaid_t>(parser_);
}

~ChunkIter() {
delete parser_;
}
~Reader() { delete parser_; }

bool Next() {
return parser_->Next();
}
bool Next() { return parser_->Next(); }

const dmlc::RowBlock<feaid_t>& Value() const { return parser_->Value(); }

const dmlc::RowBlock<feaid_t>& Value() const {
return parser_->Value();
}
private:
dmlc::data::ParserImpl<feaid_t> *parser_;
dmlc::data::ParserImpl<feaid_t>* parser_;
};

} // namespace difacto
#endif // DIFACTO_DATA_CHUNK_ITER_H_
#endif // DIFACTO_READER_READER_H_
4 changes: 2 additions & 2 deletions src/sgd/sgd_learner.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
#include <thread>
#include "dmlc/data.h"
#include "difacto/learner.h"
#include "data/batch_iter.h"
#include "reader/batch_reader.h"
#include "data/row_block.h"
#include "data/localizer.h"
#include "dmlc/timer.h"
Expand Down Expand Up @@ -185,7 +185,7 @@ class SGDLearner : public Learner {
int batch_size = 100;
int shuffle = 0;
float neg_sampling = 1;
BatchIter reader(
BatchReader reader(
job.filename, param_.data_format, job.part_idx, job.num_parts,
batch_size, shuffle, neg_sampling);
while (reader.Next()) {
Expand Down
Loading

0 comments on commit 58d525a

Please sign in to comment.