Implemented data parallelism with MPI
brodyh committed Feb 12, 2015
1 parent ae89c1d commit 9384dd8
Showing 15 changed files with 459 additions and 146 deletions.
1 change: 1 addition & 0 deletions CMakeLists.txt
@@ -5,6 +5,7 @@ project( Caffe )

option(CPU_ONLY "Build Caffe without GPU support" OFF)
option(USE_CUDNN "Build Caffe with cuDNN support" ON)
option(USE_MPI "Build Caffe with MPI support" ON)
option(BUILD_PYTHON "Build Python wrapper" ON)
option(BUILD_MATLAB "Build Matlab wrapper" OFF)
option(BUILD_EXAMPLES "Build examples" ON)
19 changes: 19 additions & 0 deletions include/caffe/blob.hpp
@@ -128,6 +128,25 @@ class Blob {
*/
void ShareDiff(const Blob& other);

/**
* @brief Averages the data buffers of all the nodes using MPI.
* Performs an Allreduce sum followed by division by the node count.
*/
void SyncData();
/**
* @brief Averages the diff buffers of all the nodes using MPI.
* Performs an Allreduce sum followed by division by the node count.
*/
void SyncDiff();

/**
* @brief Broadcasts a single rank's data buffer to all the other
* nodes. Only the specified rank keeps its data; every other rank's
* data is replaced. Currently CPU-only, since its primary use case
* is in the fillers.
*/
void BcastData(const int rank = 0);

protected:
shared_ptr<SyncedMemory> data_;
shared_ptr<SyncedMemory> diff_;
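A hedged sketch of these hooks in use (illustrative only; the blob shape and call site below are assumptions, not part of this commit):

// Each rank computes a rank-local gradient, then averages it with the others.
Blob<float> weights(64, 32, 3, 3);   // hypothetical parameter blob
// ... forward/backward fills weights.mutable_cpu_diff() on every rank ...
weights.SyncDiff();   // Allreduce sum, then scale by 1/size(): diff is now the mean
weights.BcastData();  // alternatively, force every rank to adopt rank 0's data

After SyncDiff() all ranks hold the identical averaged gradient, so applying the same update rule keeps the model replicas in lock-step.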
21 changes: 19 additions & 2 deletions include/caffe/common.hpp
@@ -16,6 +16,7 @@
#include <vector>

#include "caffe/util/device_alternate.hpp"
#include "caffe/util/mpi.hpp"

// gflags 2.1 issue: namespace google was changed to gflags without warning.
// Luckily we will be able to use GFLAGS_GFLAGS_H_ to detect if it is version 2.1.
@@ -62,9 +63,13 @@ using std::stringstream;
using std::vector;

// A global initialization function that you should call in your main function.
// Currently it initializes google flags and google logging.
// Currently it initializes google flags, google logging, and MPI.
void GlobalInit(int* pargc, char*** pargv);

// A global finalization function that you should call at the end of
// your main function.
void GlobalFinalize();

// A singleton class to hold common caffe stuff, such as the handler that
// caffe is going to use for cublas, curand, etc.
class Caffe {
@@ -78,7 +83,7 @@ class Caffe {
}
enum Brew { CPU, GPU };
enum Phase { TRAIN, TEST };

enum DeviceState { FIXED, MUTABLE };

// This random number generator facade hides boost and CUDA rng
// implementation from one another (for cross-platform compatibility).
@@ -112,6 +117,10 @@ class Caffe {
inline static Brew mode() { return Get().mode_; }
// Returns the phase: TRAIN or TEST.
inline static Phase phase() { return Get().phase_; }
// Returns the MPI implementation
inline static shared_ptr<MPI> mpi() { return Get().mpi_; }
// Returns the DeviceState: MUTABLE or FIXED
inline static DeviceState device_state() { return Get().device_state_; }
// The setters for the variables
// Sets the mode. It is recommended that you don't change the mode halfway
// into the program since that may cause allocation of pinned memory being
@@ -120,23 +129,31 @@
inline static void set_mode(Brew mode) { Get().mode_ = mode; }
// Sets the phase.
inline static void set_phase(Phase phase) { Get().phase_ = phase; }
inline static void set_mpi(shared_ptr<MPI> mpi) { Get().mpi_ = mpi; }
inline static void set_device_state(DeviceState device_state) {
Get().device_state_ = device_state;
}
// Sets the random seed of both boost and curand
static void set_random_seed(const unsigned int seed);
// Sets the device. Since we have cublas and curand stuff, set device also
// requires us to reset those values.
static void SetDevice(const int device_id);
// Prints the current GPU status.
static void DeviceQuery();
// Returns the number of free bytes on GPU
static size_t DeviceMemoryFree();

protected:
#ifndef CPU_ONLY
cublasHandle_t cublas_handle_;
curandGenerator_t curand_generator_;
#endif
shared_ptr<RNG> random_generator_;
shared_ptr<MPI> mpi_;

Brew mode_;
Phase phase_;
DeviceState device_state_;
static shared_ptr<Caffe> singleton_;

private:
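A minimal lifecycle sketch, assuming a standalone tool (GlobalInit, GlobalFinalize, and Caffe::mpi() come from this diff; everything else is illustrative):

#include "caffe/common.hpp"

int main(int argc, char** argv) {
  caffe::GlobalInit(&argc, &argv);  // gflags, glog, and MPI startup
  const int rank = caffe::Caffe::mpi()->rank();
  const int size = caffe::Caffe::mpi()->size();
  LOG(INFO) << "rank " << rank << " of " << size;
  // ... construct a Solver and run it; every rank executes the same program ...
  caffe::GlobalFinalize();          // MPI teardown before exit
  return 0;
}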
4 changes: 4 additions & 0 deletions include/caffe/filler.hpp
@@ -58,6 +58,7 @@ class UniformFiller : public Filler<Dtype> {
Dtype(this->filler_param_.max()), blob->mutable_cpu_data());
CHECK_EQ(this->filler_param_.sparse(), -1)
<< "Sparsity not supported by this Filler.";
blob->BcastData();
}
};

@@ -90,6 +91,7 @@ class GaussianFiller : public Filler<Dtype> {
data[i] *= mask[i];
}
}
blob->BcastData();
}

protected:
@@ -123,6 +125,7 @@ class PositiveUnitballFiller : public Filler<Dtype> {
}
CHECK_EQ(this->filler_param_.sparse(), -1)
<< "Sparsity not supported by this Filler.";
blob->BcastData();
}
};

@@ -154,6 +157,7 @@ class XavierFiller : public Filler<Dtype> {
blob->mutable_cpu_data());
CHECK_EQ(this->filler_param_.sparse(), -1)
<< "Sparsity not supported by this Filler.";
blob->BcastData();
}
};

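The change is the same in every filler: draw the weights with the rank-local RNG, then call BcastData() so all ranks start from rank 0's values. A custom filler would presumably follow the same convention; this sketch is illustrative and not part of the commit:

template <typename Dtype>
class NoisyConstantFiller : public Filler<Dtype> {  // hypothetical filler
 public:
  explicit NoisyConstantFiller(const FillerParameter& param)
      : Filler<Dtype>(param) {}
  virtual void Fill(Blob<Dtype>* blob) {
    caffe_rng_gaussian<Dtype>(blob->count(), Dtype(0.01), Dtype(0.001),
                              blob->mutable_cpu_data());
    // Without this broadcast each rank would start from different random
    // weights, and the replicas would never agree.
    blob->BcastData();
  }
};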
6 changes: 6 additions & 0 deletions include/caffe/solver.hpp
@@ -36,6 +36,10 @@ class Solver {
// PreSolve is run before any solving iteration starts, allowing one to
// put up some scaffold.
virtual void PreSolve() {}
// Sync net parameters & any solver data
virtual void SyncData() = 0;
// Sync gradients
virtual void SyncDiff() = 0;
// Get the update value for the current iteration.
virtual void ComputeUpdateValue() = 0;
// The Solver::Snapshot function implements the basic snapshotting utility
@@ -80,6 +84,8 @@ class SGDSolver : public Solver<Dtype> {
protected:
virtual void PreSolve();
Dtype GetLearningRate();
virtual void SyncDiff();
virtual void SyncData();
virtual void ComputeUpdateValue();
virtual void SnapshotSolverState(SolverState * state);
virtual void RestoreSolverState(const SolverState& state);
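A hedged sketch of where these hooks plausibly fire during training; the actual Solver::Solve call sites are not shown in this excerpt:

// Illustrative ordering only, inside the main training loop.
while (iter_ < param_.max_iter()) {
  Dtype loss = net_->ForwardBackward(bottom_vec);  // rank-local pass
  SyncDiff();             // average the gradients across all ranks
  ComputeUpdateValue();   // now identical on every rank
  net_->Update();
  ++iter_;
}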
77 changes: 77 additions & 0 deletions include/caffe/util/mpi.hpp
@@ -0,0 +1,77 @@
#ifndef CAFFE_MPI_HPP_
#define CAFFE_MPI_HPP_

#include "caffe/common.hpp"

namespace caffe {


void caffe_init_mpi(int* pargc, char*** pargv);
void caffe_finalize_mpi();


class MPI {
public:
MPI() : rank_(0), size_(1) {}
virtual inline int rank() const { return rank_; }
virtual inline int size() const { return size_; }
// Can't make a virtual template function :(
virtual void Allreduce(const int count, float *sendrecv_buf) = 0;
virtual void Allreduce(const int count, double *sendrecv_buf) = 0;
virtual void Allreduce(const int count, int *sendrecv_buf) = 0;
virtual void Allreduce(const int count, unsigned int *sendrecv_buf) = 0;
virtual void Bcast(const int count, float *buffer, const int root = 0) = 0;
virtual void Bcast(const int count, double *buffer, const int root = 0) = 0;
virtual void Bcast(const int count, int *buffer, const int root = 0) = 0;
virtual void Bcast(const int count, unsigned int *buffer, const int root = 0) = 0;

protected:
int rank_, size_;

// DISABLE_COPY_AND_ASSIGN(MPI);
private:
MPI(const MPI&);
MPI& operator=(const MPI&);
};

class MPILocal : public MPI {
public:
MPILocal() : MPI() {}
virtual void Allreduce(const int count, float *sendrecv_buf) {}
virtual void Allreduce(const int count, double *sendrecv_buf) {}
virtual void Allreduce(const int count, int *sendrecv_buf) {}
virtual void Allreduce(const int count, unsigned int *sendrecv_buf) {}
virtual void Bcast(const int count, float *buffer, const int root = 0) {}
virtual void Bcast(const int count, double *buffer, const int root = 0) {}
virtual void Bcast(const int count, int *buffer, const int root = 0) {}
virtual void Bcast(const int count, unsigned int *buffer, const int root = 0) {}

// DISABLE_COPY_AND_ASSIGN(MPILocal);
private:
MPILocal(const MPILocal&);
MPILocal& operator=(const MPILocal&);
};

class MPIDist : public MPI {
public:
MPIDist();
virtual void Allreduce(const int count, float *sendrecv_buf);
virtual void Allreduce(const int count, double *sendrecv_buf);
virtual void Allreduce(const int count, int *sendrecv_buf);
virtual void Allreduce(const int count, unsigned int *sendrecv_buf);
virtual void Bcast(const int count, float *buffer, const int root = 0);
virtual void Bcast(const int count, double *buffer, const int root = 0);
virtual void Bcast(const int count, int *buffer, const int root = 0);
virtual void Bcast(const int count, unsigned int *buffer, const int root = 0);

// DISABLE_COPY_AND_ASSIGN(MPIDist);
private:
MPIDist(const MPIDist&);
MPIDist& operator=(const MPIDist&);
};


} // namespace caffe


#endif // CAFFE_MPI_HPP_
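MPILocal is the no-op fallback used for single-process runs, so the rest of the code can call Allreduce/Bcast unconditionally. MPIDist presumably wraps the standard MPI C API; a plausible sketch of two of its overloads (the actual src/caffe/util/mpi.cpp is not shown in this excerpt):

#include <mpi.h>

#include "caffe/util/mpi.hpp"

namespace caffe {

void MPIDist::Allreduce(const int count, float* sendrecv_buf) {
  // In-place sum over all ranks; callers such as Blob::SyncData divide by
  // size() afterwards to turn the sum into a mean.
  MPI_Allreduce(MPI_IN_PLACE, sendrecv_buf, count, MPI_FLOAT, MPI_SUM,
                MPI_COMM_WORLD);
}

void MPIDist::Bcast(const int count, float* buffer, const int root) {
  // Replace every rank's buffer with root's copy.
  MPI_Bcast(buffer, count, MPI_FLOAT, root, MPI_COMM_WORLD);
}

}  // namespace caffe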
4 changes: 2 additions & 2 deletions models/brody/train_val_driving_normalization.prototxt
@@ -9,7 +9,7 @@ layers {
data_param {
source: "/deep/u/willsong/caffe/driving_train_rgb"
backend: LMDB
batch_size: 10
batch_size: 24
}
transform_param {
mean_file: "/deep/u/willsong/caffe/driving_mean_rgb.binaryproto"
@@ -26,7 +26,7 @@ layers {
data_param {
source: "/deep/u/willsong/caffe/driving_test_rgb"
backend: LMDB
batch_size: 10
batch_size: 2
}
transform_param {
mean_file: "/deep/u/willsong/caffe/driving_mean_rgb.binaryproto"
4 changes: 2 additions & 2 deletions models/bvlc_reference_caffenet/train_val.prototxt
@@ -5,7 +5,7 @@ layers {
top: "data"
top: "label"
data_param {
source: "/deep/u/willsong/data/ilsvrc12_train_lmdb"
source: "examples/imagenet/ilsvrc12_train_lmdb"
backend: LMDB
batch_size: 256
}
@@ -22,7 +22,7 @@ layers {
top: "data"
top: "label"
data_param {
source: "/deep/u/willsong/data/ilsvrc12_val_lmdb"
source: "examples/imagenet/ilsvrc12_val_lmdb"
backend: LMDB
batch_size: 50
}
12 changes: 10 additions & 2 deletions src/caffe/CMakeLists.txt
@@ -58,8 +58,15 @@ include_directories(${LMDB_INCLUDE_DIR})

# Boost
find_package(Boost 1.46 COMPONENTS system thread REQUIRED)
include_directories( ${Boost_INCLUDE_DIR} )
link_directories( ${Boost_LIBRARY_DIRS} )
include_directories(${Boost_INCLUDE_DIR})
link_directories(${Boost_LIBRARY_DIRS})

# MPI
if (USE_MPI)
add_definitions(-DUSE_MPI)
find_package(MPI REQUIRED)
include_directories(${MPI_INCLUDE_PATH})
endif()

add_subdirectory(proto)

@@ -119,6 +126,7 @@ target_link_libraries(caffe proto
${LEVELDB_LIBS}
${LMDB_LIBRARIES}
${OpenCV_LIBS}
${MPI_LIBRARIES}
)

#set output directory
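With USE_MPI enabled, the caffe library links against whatever find_package(MPI) discovers, and the resulting binaries would normally be launched under the MPI runtime (for example, mpirun -np 4 <caffe binary> train ...); the launch command is an assumption, since this commit only covers the build configuration.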
65 changes: 64 additions & 1 deletion src/caffe/blob.cpp
@@ -2,6 +2,7 @@
#include "caffe/common.hpp"
#include "caffe/syncedmem.hpp"
#include "caffe/util/math_functions.hpp"
#include "caffe/util/mpi.hpp"

namespace caffe {

@@ -103,6 +104,69 @@ void Blob<Dtype>::ShareDiff(const Blob& other) {
diff_ = other.diff();
}


// The "SyncData" method is used for parameter blobs in a Net, which are stored
// as Blob<float> or Blob<double> -- hence we do not define it for
// Blob<int> or Blob<unsigned int>.
template <> void Blob<unsigned int>::SyncData() { NOT_IMPLEMENTED; }
template <> void Blob<int>::SyncData() { NOT_IMPLEMENTED; }

template <typename Dtype>
void Blob<Dtype>::SyncData() {
shared_ptr<MPI> mpi = Caffe::mpi();
switch (Caffe::mode()) {
case Caffe::CPU:
mpi->Allreduce(count_, mutable_cpu_data());
caffe_scal(count_, 1/Dtype(mpi->size()), mutable_cpu_data());
break;
case Caffe::GPU:
#ifndef CPU_ONLY
mpi->Allreduce(count_, mutable_gpu_data());
caffe_gpu_scal(count_, 1/Dtype(mpi->size()), mutable_gpu_data());
#else
NO_GPU;
#endif
break;
default:
LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode();
}
}

// The "SyncDiff" method is used for parameter blobs in a Net, which are stored
// as Blob<float> or Blob<double> -- hence we do not define it for
// Blob<int> or Blob<unsigned int>.
template <> void Blob<unsigned int>::SyncDiff() { NOT_IMPLEMENTED; }
template <> void Blob<int>::SyncDiff() { NOT_IMPLEMENTED; }

template <typename Dtype>
void Blob<Dtype>::SyncDiff() {
shared_ptr<MPI> mpi = Caffe::mpi();
switch (Caffe::mode()) {
case Caffe::CPU:
mpi->Allreduce(count_, mutable_cpu_diff());
caffe_scal(count_, 1/Dtype(mpi->size()), mutable_cpu_diff());
break;
case Caffe::GPU:
#ifndef CPU_ONLY
mpi->Allreduce(count_, mutable_gpu_diff());
caffe_gpu_scal(count_, 1/Dtype(mpi->size()), mutable_gpu_diff());
#else
NO_GPU;
#endif
break;
default:
LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode();
}
}


template <typename Dtype>
void Blob<Dtype>::BcastData(const int rank) {
shared_ptr<MPI> mpi = Caffe::mpi();
mpi->Bcast(count_, mutable_cpu_data(), rank);
}


// The "update" method is used for parameter blobs in a Net, which are stored
// as Blob<float> or Blob<double> -- hence we do not define it for
// Blob<int> or Blob<unsigned int>.
@@ -299,4 +363,3 @@ template class Blob<int>;
template class Blob<unsigned int>;

} // namespace caffe
