diff --git a/CMakeLists.txt b/CMakeLists.txt index a8478fc5ea0..2b5ae1faf17 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -5,6 +5,7 @@ project( Caffe ) option(CPU_ONLY "Build Caffe without GPU support" OFF) option(USE_CUDNN "Build Caffe with cuDNN support" ON) +option(USE_MPI "Build Caffe with MPI support" ON) option(BUILD_PYTHON "Build Python wrapper" ON) option(BUILD_MATLAB "Build Matlab wrapper" OFF) option(BUILD_EXAMPLES "Build examples" ON) diff --git a/include/caffe/blob.hpp b/include/caffe/blob.hpp index 72efde06228..0afca7cabc4 100644 --- a/include/caffe/blob.hpp +++ b/include/caffe/blob.hpp @@ -128,6 +128,25 @@ class Blob { */ void ShareDiff(const Blob& other); + /** + * @brief Averages the data buffers of all the nodes using + * MPI. Performs an AllReduce followed by an averaging. + */ + void SyncData(); + /** + * @brief Averages the diff buffers of all the nodes using + * MPI. Performs an AllReduce followed by an averaging. + */ + void SyncDiff(); + + /** + * @brief Uses a single rank to broadcast their data to all the + * other nodes. Only the rank specifed keeps their data every other + * rank's data is replaced. Currently just uses cpu due to its usage + * case being in filler. + */ + void BcastData(const int rank = 0); + protected: shared_ptr data_; shared_ptr diff_; diff --git a/include/caffe/common.hpp b/include/caffe/common.hpp index 9c6eb4d6834..698de4976b3 100644 --- a/include/caffe/common.hpp +++ b/include/caffe/common.hpp @@ -16,6 +16,7 @@ #include #include "caffe/util/device_alternate.hpp" +#include "caffe/util/mpi.hpp" // gflags 2.1 issue: namespace google was changed to gflags without warning. // Luckily we will be able to use GFLAGS_GFAGS_H_ to detect if it is version @@ -62,9 +63,13 @@ using std::stringstream; using std::vector; // A global initialization function that you should call in your main function. -// Currently it initializes google flags and google logging. +// Currently it initializes google flags, google logging, and MPI. void GlobalInit(int* pargc, char*** pargv); +// A global finalization function that you should call at the end of +// your main function. +void GlobalFinalize(); + // A singleton class to hold common caffe stuff, such as the handler that // caffe is going to use for cublas, curand, etc. class Caffe { @@ -78,7 +83,7 @@ class Caffe { } enum Brew { CPU, GPU }; enum Phase { TRAIN, TEST }; - + enum DeviceState { FIXED, MUTABLE }; // This random number generator facade hides boost and CUDA rng // implementation from one another (for cross-platform compatibility). @@ -112,6 +117,10 @@ class Caffe { inline static Brew mode() { return Get().mode_; } // Returns the phase: TRAIN or TEST. inline static Phase phase() { return Get().phase_; } + // Returns the MPI implementation + inline static shared_ptr mpi() { return Get().mpi_; } + // Returns the DeviceState: MUTABLE or FIXED + inline static DeviceState device_state() { return Get().device_state_; } // The setters for the variables // Sets the mode. It is recommended that you don't change the mode halfway // into the program since that may cause allocation of pinned memory being @@ -120,6 +129,10 @@ class Caffe { inline static void set_mode(Brew mode) { Get().mode_ = mode; } // Sets the phase. inline static void set_phase(Phase phase) { Get().phase_ = phase; } + inline static void set_mpi(shared_ptr mpi) { Get().mpi_ = mpi; } + inline static void set_device_state(DeviceState device_state) { + Get().device_state_ = device_state; + } // Sets the random seed of both boost and curand static void set_random_seed(const unsigned int seed); // Sets the device. Since we have cublas and curand stuff, set device also @@ -127,6 +140,8 @@ class Caffe { static void SetDevice(const int device_id); // Prints the current GPU status. static void DeviceQuery(); + // Returns the number of free bytes on GPU + static size_t DeviceMemoryFree(); protected: #ifndef CPU_ONLY @@ -134,9 +149,11 @@ class Caffe { curandGenerator_t curand_generator_; #endif shared_ptr random_generator_; + shared_ptr mpi_; Brew mode_; Phase phase_; + DeviceState device_state_; static shared_ptr singleton_; private: diff --git a/include/caffe/filler.hpp b/include/caffe/filler.hpp index 136ce958aed..16bbc07697f 100644 --- a/include/caffe/filler.hpp +++ b/include/caffe/filler.hpp @@ -58,6 +58,7 @@ class UniformFiller : public Filler { Dtype(this->filler_param_.max()), blob->mutable_cpu_data()); CHECK_EQ(this->filler_param_.sparse(), -1) << "Sparsity not supported by this Filler."; + blob->BcastData(); } }; @@ -90,6 +91,7 @@ class GaussianFiller : public Filler { data[i] *= mask[i]; } } + blob->BcastData(); } protected: @@ -123,6 +125,7 @@ class PositiveUnitballFiller : public Filler { } CHECK_EQ(this->filler_param_.sparse(), -1) << "Sparsity not supported by this Filler."; + blob->BcastData(); } }; @@ -154,6 +157,7 @@ class XavierFiller : public Filler { blob->mutable_cpu_data()); CHECK_EQ(this->filler_param_.sparse(), -1) << "Sparsity not supported by this Filler."; + blob->BcastData(); } }; diff --git a/include/caffe/solver.hpp b/include/caffe/solver.hpp index 6fd159d0b98..969b5da9be4 100644 --- a/include/caffe/solver.hpp +++ b/include/caffe/solver.hpp @@ -36,6 +36,10 @@ class Solver { // PreSolve is run before any solving iteration starts, allowing one to // put up some scaffold. virtual void PreSolve() {} + // Sync net parameters & any solver data + virtual void SyncData() = 0; + // Sync gradients + virtual void SyncDiff() = 0; // Get the update value for the current iteration. virtual void ComputeUpdateValue() = 0; // The Solver::Snapshot function implements the basic snapshotting utility @@ -80,6 +84,8 @@ class SGDSolver : public Solver { protected: virtual void PreSolve(); Dtype GetLearningRate(); + virtual void SyncDiff(); + virtual void SyncData(); virtual void ComputeUpdateValue(); virtual void SnapshotSolverState(SolverState * state); virtual void RestoreSolverState(const SolverState& state); diff --git a/include/caffe/util/mpi.hpp b/include/caffe/util/mpi.hpp new file mode 100644 index 00000000000..6fecf4ab670 --- /dev/null +++ b/include/caffe/util/mpi.hpp @@ -0,0 +1,77 @@ +#ifndef CAFFE_MPI_HPP_ +#define CAFFE_MPI_HPP_ + +#include "caffe/common.hpp" + +namespace caffe { + + +void caffe_init_mpi(int* pargc, char*** pargv); +void caffe_finalize_mpi(); + + +class MPI { +public: + MPI() : rank_(0), size_(1) {} + virtual inline int rank() const { return rank_; } + virtual inline int size() const { return size_; } + // Can't make a virtual template function :( + virtual void Allreduce(const int count, float *sendrecv_buf) = 0; + virtual void Allreduce(const int count, double *sendrecv_buf) = 0; + virtual void Allreduce(const int count, int *sendrecv_buf) = 0; + virtual void Allreduce(const int count, unsigned int *sendrecv_buf) = 0; + virtual void Bcast(const int count, float *buffer, const int root = 0) = 0; + virtual void Bcast(const int count, double *buffer, const int root = 0) = 0; + virtual void Bcast(const int count, int *buffer, const int root = 0) = 0; + virtual void Bcast(const int count, unsigned int *buffer, const int root = 0) = 0; + + protected: + int rank_, size_; + + // DISABLE_COPY_AND_ASSIGN(MPI); + private: + MPI(const MPI&); + MPI& operator=(const MPI&); +}; + +class MPILocal : public MPI { + public: + MPILocal() : MPI() {} + virtual void Allreduce(const int count, float *sendrecv_buf) {} + virtual void Allreduce(const int count, double *sendrecv_buf) {} + virtual void Allreduce(const int count, int *sendrecv_buf) {} + virtual void Allreduce(const int count, unsigned int *sendrecv_buf) {} + virtual void Bcast(const int count, float *buffer, const int root = 0) {} + virtual void Bcast(const int count, double *buffer, const int root = 0) {} + virtual void Bcast(const int count, int *buffer, const int root = 0) {} + virtual void Bcast(const int count, unsigned int *buffer, const int root = 0) {} + + // DISABLE_COPY_AND_ASSIGN(MPILocal); + private: + MPILocal(const MPILocal&); + MPILocal& operator=(const MPILocal&); +}; + +class MPIDist : public MPI { + public: + MPIDist(); + virtual void Allreduce(const int count, float *sendrecv_buf); + virtual void Allreduce(const int count, double *sendrecv_buf); + virtual void Allreduce(const int count, int *sendrecv_buf); + virtual void Allreduce(const int count, unsigned int *sendrecv_buf); + virtual void Bcast(const int count, float *buffer, const int root = 0); + virtual void Bcast(const int count, double *buffer, const int root = 0); + virtual void Bcast(const int count, int *buffer, const int root = 0); + virtual void Bcast(const int count, unsigned int *buffer, const int root = 0); + + // DISABLE_COPY_AND_ASSIGN(MPIDist); + private: + MPIDist(const MPIDist&); + MPIDist& operator=(const MPIDist&); +}; + + +} // namespace caffe + + +#endif // CAFFE_MPI_HPP_ diff --git a/models/brody/train_val_driving_normalization.prototxt b/models/brody/train_val_driving_normalization.prototxt index b420c0e05cc..2c9a04a23a0 100644 --- a/models/brody/train_val_driving_normalization.prototxt +++ b/models/brody/train_val_driving_normalization.prototxt @@ -9,7 +9,7 @@ layers { data_param { source: "/deep/u/willsong/caffe/driving_train_rgb" backend: LMDB - batch_size: 10 + batch_size: 24 } transform_param { mean_file: "/deep/u/willsong/caffe/driving_mean_rgb.binaryproto" @@ -26,7 +26,7 @@ layers { data_param { source: "/deep/u/willsong/caffe/driving_test_rgb" backend: LMDB - batch_size: 10 + batch_size: 2 } transform_param { mean_file: "/deep/u/willsong/caffe/driving_mean_rgb.binaryproto" diff --git a/models/bvlc_reference_caffenet/train_val.prototxt b/models/bvlc_reference_caffenet/train_val.prototxt index da78e73f44a..073d8aeff4a 100644 --- a/models/bvlc_reference_caffenet/train_val.prototxt +++ b/models/bvlc_reference_caffenet/train_val.prototxt @@ -5,7 +5,7 @@ layers { top: "data" top: "label" data_param { - source: "/deep/u/willsong/data/ilsvrc12_train_lmdb" + source: "examples/imagenet/ilsvrc12_train_lmdb" backend: LMDB batch_size: 256 } @@ -22,7 +22,7 @@ layers { top: "data" top: "label" data_param { - source: "/deep/u/willsong/data/ilsvrc12_val_lmdb" + source: "examples/imagenet/ilsvrc12_val_lmdb" backend: LMDB batch_size: 50 } diff --git a/src/caffe/CMakeLists.txt b/src/caffe/CMakeLists.txt index 0ea38f812cc..a5d1b350964 100644 --- a/src/caffe/CMakeLists.txt +++ b/src/caffe/CMakeLists.txt @@ -58,8 +58,15 @@ include_directories(${LMDB_INCLUDE_DIR}) # Boost find_package(Boost 1.46 COMPONENTS system thread REQUIRED) -include_directories( ${Boost_INCLUDE_DIR} ) -link_directories( ${Boost_LIBRARY_DIRS} ) +include_directories(${Boost_INCLUDE_DIR}) +link_directories(${Boost_LIBRARY_DIRS}) + +# MPI +if (USE_MPI) + add_definitions(-DUSE_MPI) + find_package(MPI REQUIRED) + include_directories(${MPI_INCLUDE_PATH}) +endif() add_subdirectory(proto) @@ -119,6 +126,7 @@ target_link_libraries(caffe proto ${LEVELDB_LIBS} ${LMDB_LIBRARIES} ${OpenCV_LIBS} + ${MPI_LIBRARIES} ) #set output directory diff --git a/src/caffe/blob.cpp b/src/caffe/blob.cpp index ffda477aa4a..0f44a1bd5e2 100644 --- a/src/caffe/blob.cpp +++ b/src/caffe/blob.cpp @@ -2,6 +2,7 @@ #include "caffe/common.hpp" #include "caffe/syncedmem.hpp" #include "caffe/util/math_functions.hpp" +#include "caffe/util/mpi.hpp" namespace caffe { @@ -103,6 +104,69 @@ void Blob::ShareDiff(const Blob& other) { diff_ = other.diff(); } + +// The "SyncData" method is used for parameter blobs in a Net, which are stored +// as Blob or Blob -- hence we do not define it for +// Blob or Blob. +template <> void Blob::SyncData() { NOT_IMPLEMENTED; } +template <> void Blob::SyncData() { NOT_IMPLEMENTED; } + +template +void Blob::SyncData() { + shared_ptr mpi = Caffe::mpi(); + switch (Caffe::mode()) { + case Caffe::CPU: + mpi->Allreduce(count_, mutable_cpu_data()); + caffe_scal(count_, 1/Dtype(mpi->size()), mutable_cpu_data()); + break; + case Caffe::GPU: +#ifndef CPU_ONLY + mpi->Allreduce(count_, mutable_gpu_data()); + caffe_gpu_scal(count_, 1/Dtype(mpi->size()), mutable_gpu_data()); +#else + NO_GPU; +#endif + break; + default: + LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode(); + } +} + +// The "SyncDiff" method is used for parameter blobs in a Net, which are stored +// as Blob or Blob -- hence we do not define it for +// Blob or Blob. +template <> void Blob::SyncDiff() { NOT_IMPLEMENTED; } +template <> void Blob::SyncDiff() { NOT_IMPLEMENTED; } + +template +void Blob::SyncDiff() { + shared_ptr mpi = Caffe::mpi(); + switch (Caffe::mode()) { + case Caffe::CPU: + mpi->Allreduce(count_, mutable_cpu_diff()); + caffe_scal(count_, 1/Dtype(mpi->size()), mutable_cpu_diff()); + break; + case Caffe::GPU: +#ifndef CPU_ONLY + mpi->Allreduce(count_, mutable_gpu_diff()); + caffe_gpu_scal(count_, 1/Dtype(mpi->size()), mutable_gpu_diff()); +#else + NO_GPU; +#endif + break; + default: + LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode(); + } +} + + +template +void Blob::BcastData(const int rank) { + shared_ptr mpi = Caffe::mpi(); + mpi->Bcast(count_, mutable_cpu_data(), rank); +} + + // The "update" method is used for parameter blobs in a Net, which are stored // as Blob or Blob -- hence we do not define it for // Blob or Blob. @@ -299,4 +363,3 @@ template class Blob; template class Blob; } // namespace caffe - diff --git a/src/caffe/common.cpp b/src/caffe/common.cpp index 94fdf924f44..d0abac5ede6 100644 --- a/src/caffe/common.cpp +++ b/src/caffe/common.cpp @@ -4,6 +4,7 @@ #include "caffe/common.hpp" #include "caffe/util/rng.hpp" +#include "caffe/util/mpi.hpp" namespace caffe { @@ -35,12 +36,19 @@ void GlobalInit(int* pargc, char*** pargv) { ::gflags::ParseCommandLineFlags(pargc, pargv, true); // Google logging. ::google::InitGoogleLogging(*(pargv)[0]); + // Setup MPI before logging in case we change log dirs + caffe_init_mpi(pargc, pargv); +} + +void GlobalFinalize() { + caffe_finalize_mpi(); } #ifdef CPU_ONLY // CPU-only Caffe. Caffe::Caffe() - : random_generator_(), mode_(Caffe::CPU), phase_(Caffe::TRAIN) { } + : random_generator_(), mode_(Caffe::CPU), phase_(Caffe::TRAIN), + device_state_(Caffe::MUTABLE) { } Caffe::~Caffe() { } @@ -84,7 +92,7 @@ void* Caffe::RNG::generator() { Caffe::Caffe() : cublas_handle_(NULL), curand_generator_(NULL), random_generator_(), - mode_(Caffe::CPU), phase_(Caffe::TRAIN) { + mode_(Caffe::CPU), phase_(Caffe::TRAIN), device_state_(Caffe::MUTABLE) { // Try to create a cublas handler, and report an error if failed (but we will // keep the program running as one might just want to run CPU code). if (cublasCreate(&cublas_handle_) != CUBLAS_STATUS_SUCCESS) { @@ -127,7 +135,7 @@ void Caffe::set_random_seed(const unsigned int seed) { void Caffe::SetDevice(const int device_id) { int current_device; CUDA_CHECK(cudaGetDevice(¤t_device)); - if (current_device == device_id) { + if (current_device == device_id || Caffe::device_state() == Caffe::FIXED) { return; } // The call to cudaSetDevice must come before any calls to Get, which @@ -180,6 +188,18 @@ void Caffe::DeviceQuery() { } +size_t Caffe::DeviceMemoryFree() { + int device; + if (cudaSuccess != cudaGetDevice(&device)) { + printf("No cuda device present.\n"); + return 0; + } + size_t free, total; + CUDA_CHECK(cudaMemGetInfo(&free, &total)); + return free; +} + + class Caffe::RNG::Generator { public: Generator() : rng_(new caffe::rng_t(cluster_seedgen())) {} diff --git a/src/caffe/layers/driving_data_layer.cpp b/src/caffe/layers/driving_data_layer.cpp index ce0ef68ef85..b815e5771d0 100644 --- a/src/caffe/layers/driving_data_layer.cpp +++ b/src/caffe/layers/driving_data_layer.cpp @@ -3,6 +3,7 @@ #include #include +#include #include "caffe/common.hpp" #include "caffe/data_layers.hpp" @@ -147,8 +148,18 @@ void DrivingDataLayer::DataLayerSetUp( CHECK_EQ(mdb_cursor_open(mdb_txn_, mdb_dbi_, &mdb_cursor_), MDB_SUCCESS) << "mdb_cursor_open failed"; LOG(INFO) << "Opening lmdb " << this->layer_param_.data_param().source(); + // Set cursor at begining CHECK_EQ(mdb_cursor_get(mdb_cursor_, &mdb_key_, &mdb_value_, MDB_FIRST), - MDB_SUCCESS) << "mdb_cursor_get failed"; + MDB_SUCCESS); + // Jump forward to our MPI rank + for (int i = 0; i < Caffe::mpi()->rank(); ++i) { + CHECK_EQ(mdb_cursor_get(mdb_cursor_, &mdb_key_, &mdb_value_, MDB_NEXT), + MDB_SUCCESS); + } + // read our first data point + CHECK_EQ(mdb_cursor_get(mdb_cursor_, &mdb_key_, + &mdb_value_, MDB_GET_CURRENT), + MDB_SUCCESS) << "mdb_cursor_get failed"; break; default: LOG(FATAL) << "Unknown database backend"; @@ -158,6 +169,7 @@ void DrivingDataLayer::DataLayerSetUp( if (this->layer_param_.data_param().rand_skip()) { unsigned int skip = caffe_rng_rand() % this->layer_param_.data_param().rand_skip(); + Caffe::mpi()->Bcast(1, &skip); LOG(INFO) << "Skipping first " << skip << " data points."; while (skip-- > 0) { switch (this->layer_param_.data_param().backend()) { @@ -211,7 +223,8 @@ void DrivingDataLayer::DataLayerSetUp( (*top)[IMAGE]->Reshape( this->layer_param_.data_param().batch_size(), datum.channels(), data.car_cropped_height(), data.car_cropped_width()); - this->prefetch_datas_[IMAGE]->Reshape(this->layer_param_.data_param().batch_size(), + this->prefetch_datas_[IMAGE]->Reshape( + this->layer_param_.data_param().batch_size(), datum.channels(), data.car_cropped_height(), data.car_cropped_width()); LOG(INFO) << "output image data size: " << (*top)[IMAGE]->num() << "," << (*top)[IMAGE]->channels() << "," << (*top)[IMAGE]->height() << "," @@ -248,7 +261,8 @@ void DrivingDataLayer::DataLayerSetUp( this->datum_channels_ = datum.channels(); this->datum_height_ = data.car_cropped_height(); this->datum_width_ = data.car_cropped_width(); - this->datum_size_ = datum.channels() * data.car_cropped_height() * data.car_cropped_width(); + this->datum_size_ = datum.channels() * data.car_cropped_height() \ + * data.car_cropped_width(); const unsigned int rng_seed = caffe_rng_rand(); rng_.reset(new Caffe::RNG(rng_seed)); @@ -271,14 +285,17 @@ bool DrivingDataLayer::ReadBoundingBoxLabelToDatum( const int full_label_width = width * grid_dim; const int full_label_height = height * grid_dim; const float half_shrink_factor = data.car_shrink_factor() / 2; - const float scaling = static_cast(full_label_width) / data.car_cropped_width(); + const float scaling = static_cast(full_label_width) \ + / data.car_cropped_width(); // 1 pixel label, 4 bounding box coordinates, 3 normalization labels. const int num_total_labels = kNumRegressionMasks; vector labels; for (int i = 0; i < num_total_labels; ++i) { labels.push_back( - new cv::Mat(full_label_height, full_label_width, CV_32F, cv::Scalar(0.0))); + new cv::Mat(full_label_height, + full_label_width, CV_32F, + cv::Scalar(0.0))); } for (int i = 0; i < data.car_boxes_size(); ++i) { @@ -368,7 +385,7 @@ bool DrivingDataLayer::ReadBoundingBoxLabelToDatum( datum->set_channels(num_total_labels); datum->set_height(full_label_height); datum->set_width(full_label_width); - datum->set_label(0); // dummy label + datum->set_label(0); // dummy label datum->clear_data(); datum->clear_float_data(); @@ -425,7 +442,7 @@ void DrivingDataLayer::InternalThreadEntry() { break; case DataParameter_DB_LMDB: CHECK_EQ(mdb_cursor_get(mdb_cursor_, &mdb_key_, - &mdb_value_, MDB_GET_CURRENT), MDB_SUCCESS); + &mdb_value_, MDB_GET_CURRENT), MDB_SUCCESS); data.ParseFromArray(mdb_value_.mv_data, mdb_value_.mv_size); break; @@ -434,7 +451,7 @@ void DrivingDataLayer::InternalThreadEntry() { } // Apply data transformations - //this->data_transformer_.Transform(item_id, datum, this->mean_, top_data); + // this->data_transformer_.Transform(item_id, datum, this->mean_, top_data); const Datum& img_datum = data.car_image_datum(); const string& img_datum_data = img_datum.data(); int h_off = img_datum.height() == data.car_cropped_height() ? @@ -444,11 +461,15 @@ void DrivingDataLayer::InternalThreadEntry() { for (int c = 0; c < img_datum.channels(); ++c) { for (int h = 0; h < data.car_cropped_height(); ++h) { for (int w = 0; w < data.car_cropped_width(); ++w) { - int top_index = ((item_id * img_datum.channels() + c) * data.car_cropped_height() + h) + int top_index = ((item_id * img_datum.channels() + c) \ + * data.car_cropped_height() + h) * data.car_cropped_width() + w; - int data_index = (c * img_datum.height() + h + h_off) * img_datum.width() + w + w_off; - Dtype datum_element = - static_cast(static_cast(img_datum_data[data_index])); + int data_index = (c * img_datum.height() + h + h_off) \ + * img_datum.width() + w + w_off; + uint8_t datum_element_ui8 = \ + static_cast(img_datum_data[data_index]); + Dtype datum_element = static_cast(datum_element_ui8); + top_data[top_index] = datum_element - this->mean_[data_index]; } } @@ -457,7 +478,8 @@ void DrivingDataLayer::InternalThreadEntry() { vector label_datums(kNumLabels); if (this->output_labels_) { // Call appropriate functions for genearting each label - ReadBoundingBoxLabelToDatum(data, &label_datums[CAR_MERGED_LABELS], h_off, w_off); + ReadBoundingBoxLabelToDatum(data, &label_datums[CAR_MERGED_LABELS], + h_off, w_off); } for (int i = 0; i < kNumLabels; ++i) { for (int c = 0; c < label_datums[i].channels(); ++c) { @@ -465,8 +487,10 @@ void DrivingDataLayer::InternalThreadEntry() { for (int w = 0; w < label_datums[i].width(); ++w) { const int top_index = ((item_id * label_datums[i].channels() + c) * label_datums[i].height() + h) * label_datums[i].width() + w; - const int data_index = (c * label_datums[i].height() + h) * label_datums[i].width() + w; - top_labels[i][top_index] = static_cast(label_datums[i].float_data(data_index)); + const int data_index = (c * label_datums[i].height() + h) * \ + label_datums[i].width() + w; + float label_datum_elem = label_datums[i].float_data(data_index); + top_labels[i][top_index] = static_cast(label_datum_elem); } } } @@ -483,13 +507,16 @@ void DrivingDataLayer::InternalThreadEntry() { } break; case DataParameter_DB_LMDB: - if (mdb_cursor_get(mdb_cursor_, &mdb_key_, - &mdb_value_, MDB_NEXT) != MDB_SUCCESS) { - // We have reached the end. Restart from the first. - DLOG(INFO) << "Restarting data prefetching from start."; - CHECK_EQ(mdb_cursor_get(mdb_cursor_, &mdb_key_, - &mdb_value_, MDB_FIRST), MDB_SUCCESS); + for (int i = 0; i < Caffe::mpi()->size(); ++i) { + if (mdb_cursor_get(mdb_cursor_, &mdb_key_, + &mdb_value_, MDB_NEXT) != MDB_SUCCESS) { + // We have reached the end. Restart from the first. + DLOG(INFO) << "Restarting data prefetching from start."; + CHECK_EQ(mdb_cursor_get(mdb_cursor_, &mdb_key_, + &mdb_value_, MDB_FIRST), MDB_SUCCESS); + } } + break; default: LOG(FATAL) << "Unknown database backend"; diff --git a/src/caffe/solver.cpp b/src/caffe/solver.cpp index c13b19b3ff9..6d9110747d4 100644 --- a/src/caffe/solver.cpp +++ b/src/caffe/solver.cpp @@ -3,6 +3,8 @@ #include #include #include +#include // std::cout, std::ios +#include #include "caffe/net.hpp" #include "caffe/proto/caffe.pb.h" @@ -11,14 +13,10 @@ #include "caffe/util/math_functions.hpp" #include "caffe/util/upgrade_proto.hpp" - -#include -#include -#include -#include -#include // std::cout, std::ios -#include - +#include // NOLINT +#include // NOLINT +#include // NOLINT +#include // NOLINT namespace caffe { @@ -182,11 +180,13 @@ void Solver::Solve(const char* resume_file) { // For a network that is trained by the solver, no bottom or top vecs // should be given, and we will just provide dummy vecs. vector*> bottom_vec; + SyncData(); for (; iter_ < param_.max_iter(); ++iter_) { // Save a snapshot if needed. if (param_.snapshot() && iter_ > start_iter && iter_ % param_.snapshot() == 0) { - Snapshot(); + SyncData(); + if (Caffe::mpi()->rank() == 0) { Snapshot(); } } if (param_.test_interval() && iter_ % param_.test_interval() == 0 @@ -200,108 +200,6 @@ void Solver::Solve(const char* resume_file) { if (display) { LOG(INFO) << "Iteration " << iter_ << ", loss = " << loss; const vector*>& result = net_->output_blobs(); - - //added by Tao. for debugging purpose only - /* - const vector& blob_names = net_->blob_names(); - const vector > >& blobs = net_->blobs(); - string str1("data"); - string str2("label"); - string str3("pixel-label"); - string str4("bb-label"); - string save_dir("/scr/twangcat/caffenet_results/train/"); - vector save_imgs; - int quad_height; - int quad_width; - int batch_size; - const Dtype* pix_start; - const Dtype* bb_start; - for (int j = 0; j < blobs.size(); ++j) { - if(blob_names[j].compare(str3)==0) //pixel label - { - LOG(INFO) << "pixel-label " << blobs[j]->num()<<" "<channels()<<" "<height()<<" "<width(); - pix_start = blobs[j]->cpu_data(); - quad_height = blobs[j]->height(); - quad_width = blobs[j]->width(); - batch_size = blobs[j]->num(); - } - if(blob_names[j].compare(str4)==0) // bb label - { - LOG(INFO) << "bb-label " << blobs[j]->num()<<" "<channels()<<" "<height()<<" "<width(); - bb_start = blobs[j]->cpu_data(); - } - if(blob_names[j].compare(str1)==0) // actual image - { - LOG(INFO) << "data " << blobs[j]->num()<<" "<channels()<<" "<height()<<" "<width(); - const Dtype* data_start = blobs[j]->cpu_data(); - for(int n=0; nnum(); ++n) - { - cv::Mat curr_img = cv::Mat(blobs[j]->height(), blobs[j]->width(), CV_32FC3, cv::Scalar(0,0,255)); - for(int kk=0; kkchannels();++kk) - { - for(int yy=0; yyheight();++yy) - { - for(int xx=0; xxwidth();++xx) - { - curr_img.at(yy,xx)[kk]=*(data_start+(((n*blobs[j]->channels() + kk) * blobs[j]->height() + yy) * blobs[j]->width() + xx)); - } - } - } - save_imgs.push_back(curr_img); - } - } - } - int grid_dim=4; - int label_count = 0; - int label_height = quad_height*grid_dim; - int label_width = quad_width*grid_dim; - Dtype scaling = 1.0/8; - for(int n=0; niter_*5+n; - cv::Mat save_img = save_imgs[n]; - std::ostringstream stringStream; - stringStream <(y/scaling,x/scaling) = cv::Vec3f(0,255,0); - save_img.at(y/scaling-1,x/scaling-1) = cv::Vec3f(0,255,0); - save_img.at(y/scaling+1,x/scaling-1) = cv::Vec3f(0,255,0); - save_img.at(y/scaling-1,x/scaling+1) = cv::Vec3f(0,255,0); - save_img.at(y/scaling+1,x/scaling+1) = cv::Vec3f(0,255,0); - - float x_adj = (qx*grid_dim + grid_dim / 2) / scaling; - float y_adj = (qy*grid_dim + grid_dim / 2) / scaling; - //std::cout<<*(bb_start+(((n*64+z)*quad_height+qy)*quad_width+qx))<<" "; - //std::cout<<*(bb_start+(((n*64+z+16)*quad_height+qy)*quad_width+qx))<<" "; - //std::cout<<*(bb_start+(((n*64+z+32)*quad_height+qy)*quad_width+qx))<<" "; - //std::cout<<*(bb_start+(((n*64+z+48)*quad_height+qy)*quad_width+qx))<cpu_data(); @@ -325,11 +223,15 @@ void Solver::Solve(const char* resume_file) { } ComputeUpdateValue(); + SyncDiff(); net_->Update(); } // Always save a snapshot after optimization, unless overridden by setting // snapshot_after_train := false. - if (param_.snapshot_after_train()) { Snapshot(); } + if (param_.snapshot_after_train()) { + SyncData(); + if (Caffe::mpi()->rank() == 0) { Snapshot(); } + } // After the optimization is done, run an additional train and test pass to // display the train and test loss/outputs if appropriate (based on the // display and test_interval settings, respectively). Unlike in the rest of @@ -514,6 +416,25 @@ void SGDSolver::PreSolve() { } +template +void SGDSolver::SyncData() { + vector > >& net_params = this->net_->params(); + for (int param_id = 0; param_id < net_params.size(); ++param_id) { + net_params[param_id]->SyncData(); + history_[param_id]->SyncData(); + } +} + + +template +void SGDSolver::SyncDiff() { + vector > >& net_params = this->net_->params(); + for (int param_id = 0; param_id < net_params.size(); ++param_id) { + net_params[param_id]->SyncDiff(); + } +} + + template void SGDSolver::ComputeUpdateValue() { vector > >& net_params = this->net_->params(); diff --git a/src/caffe/util/mpi.cpp b/src/caffe/util/mpi.cpp new file mode 100644 index 00000000000..ac1697421c1 --- /dev/null +++ b/src/caffe/util/mpi.cpp @@ -0,0 +1,147 @@ +#ifdef USE_MPI +#include +#endif +#include + +#include "caffe/common.hpp" +#include "caffe/util/mpi.hpp" + + +#ifdef USE_MPI +// MPI: various checks for different function calls. +#define MPI_CHECK(condition) \ + do { \ + int error = condition; \ + CHECK_EQ(error, MPI_SUCCESS) << " MPI error " << error << " in file '" \ + << __FILE__ << "' at line " << __LINE__; \ + } while (0) +#endif + + +namespace caffe { + + +static bool distributed = false; + + +void caffe_init_mpi(int* pargc, char*** pargv) { + const char* local_rank_env = std::getenv("MV2_COMM_WORLD_LOCAL_RANK"); + shared_ptr mpi; + // We have launched with mpirun_rsh and will use MPI + if (local_rank_env) { +#ifdef USE_MPI + distributed = true; + int local_rank = std::atoi(local_rank_env); + Caffe::SetDevice(local_rank); + int provided, requested = MPI_THREAD_MULTIPLE; + MPI_CHECK(MPI_Init_thread(pargc, pargv, requested, &provided)); + CHECK_EQ(requested, provided) << "Thread level provided is too low"; + mpi.reset(new MPIDist()); + Caffe::set_device_state(Caffe::FIXED); + LOG(INFO) << "Rank: " << mpi->rank() << " set device to: " + << local_rank; +#endif + } else { + // Use the local version of MPI which acts as a dummy class + LOG(INFO) << "Running locally"; + distributed = false; + mpi.reset(new MPILocal()); + } + // Turn off logging for all ranks except 0 + if (mpi->rank() > 0) { + FLAGS_minloglevel = 4; + // ostringstream rank_str; + // rank_str << mpi->rank(); + // FLAGS_log_dir = FLAGS_log_dir + "/" + rank_str.str(); + // FLAGS_stderrthreshold = 4; + } + + Caffe::set_mpi(mpi); +} + + +void caffe_finalize_mpi() { +#ifdef USE_MPI + if (distributed) + MPI_CHECK(MPI_Finalize()); +#endif +} + + +MPIDist::MPIDist() { +#ifdef USE_MPI + MPI_CHECK(MPI_Comm_rank(MPI_COMM_WORLD, &rank_)); + MPI_CHECK(MPI_Comm_size(MPI_COMM_WORLD, &size_)); +#endif +} + + +void MPIDist::Allreduce(const int count, float *sendrecv_buf) { +#ifdef USE_MPI + MPI_CHECK(MPI_Allreduce(MPI_IN_PLACE, sendrecv_buf, + count, MPI_FLOAT, + MPI_SUM, MPI_COMM_WORLD)); +#endif +} + + +void MPIDist::Allreduce(const int count, double *sendrecv_buf) { +#ifdef USE_MPI + MPI_CHECK(MPI_Allreduce(MPI_IN_PLACE, sendrecv_buf, + count, MPI_DOUBLE, + MPI_SUM, MPI_COMM_WORLD)); +#endif +} + + +void MPIDist::Allreduce(const int count, int *sendrecv_buf) { +#ifdef USE_MPI + MPI_CHECK(MPI_Allreduce(MPI_IN_PLACE, sendrecv_buf, + count, MPI_INT, + MPI_SUM, MPI_COMM_WORLD)); +#endif +} + + +void MPIDist::Allreduce(const int count, unsigned int *sendrecv_buf) { +#ifdef USE_MPI + MPI_CHECK(MPI_Allreduce(MPI_IN_PLACE, sendrecv_buf, + count, MPI_UNSIGNED, + MPI_SUM, MPI_COMM_WORLD)); +#endif +} + + +void MPIDist::Bcast(const int count, float *buffer, const int root) { +#ifdef USE_MPI + MPI_CHECK(MPI_Bcast(buffer, count, MPI_FLOAT, + root, MPI_COMM_WORLD)); +#endif +} + + +void MPIDist::Bcast(const int count, double *buffer, const int root) { +#ifdef USE_MPI + MPI_CHECK(MPI_Bcast(buffer, count, MPI_DOUBLE, + root, MPI_COMM_WORLD)); +#endif +} + + +void MPIDist::Bcast(const int count, int *buffer, const int root) { +#ifdef USE_MPI + MPI_CHECK(MPI_Bcast(buffer, count, MPI_INT, + root, MPI_COMM_WORLD)); +#endif +} + + +void MPIDist::Bcast(const int count, unsigned int *buffer, const int root) { +#ifdef USE_MPI + MPI_CHECK(MPI_Bcast(buffer, count, MPI_UNSIGNED, + root, MPI_COMM_WORLD)); +#endif +} + + +} // namespace caffe diff --git a/tools/caffe.cpp b/tools/caffe.cpp index c8c8c1a6b4c..7e4081aabf5 100644 --- a/tools/caffe.cpp +++ b/tools/caffe.cpp @@ -279,9 +279,12 @@ int main(int argc, char** argv) { " time benchmark model execution time"); // Run tool or show usage. caffe::GlobalInit(&argc, &argv); + int result = 0; if (argc == 2) { - return GetBrewFunction(caffe::string(argv[1]))(); + result = GetBrewFunction(caffe::string(argv[1]))(); } else { gflags::ShowUsageWithFlagsRestrict(argv[0], "tools/caffe"); } + caffe::GlobalFinalize(); + return result; }