Skip to content

Commit

Permalink
more fine grained config options
Browse files Browse the repository at this point in the history
  • Loading branch information
Roger Waleffe authored and Roger Waleffe committed Nov 21, 2023
1 parent ae0c85c commit 72ce035
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 4 deletions.
4 changes: 4 additions & 0 deletions src/cpp/include/configuration/config.h
Original file line number Diff line number Diff line change
Expand Up @@ -109,10 +109,14 @@ struct PipelineConfig {
int gradients_device_queue_size;
int gradients_host_queue_size;
int batch_loader_threads;
int remote_loader_threads;
int batch_transfer_threads;
int remote_transfer_threads;
int compute_threads;
int gradient_transfer_threads;
int remote_gradient_transfer_threads;
int gradient_update_threads;
int remote_listen_threads;
};

struct CheckpointConfig {
Expand Down
4 changes: 4 additions & 0 deletions src/cpp/src/configuration/config.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -456,10 +456,14 @@ shared_ptr<PipelineConfig> initPipelineConfig(pyobj python_config) {
ret_config->gradients_device_queue_size = cast_helper<int>(python_config.attr("gradients_device_queue_size"));
ret_config->gradients_host_queue_size = cast_helper<int>(python_config.attr("gradients_host_queue_size"));
ret_config->batch_loader_threads = cast_helper<int>(python_config.attr("batch_loader_threads"));
ret_config->remote_loader_threads = cast_helper<int>(python_config.attr("remote_loader_threads"));
ret_config->batch_transfer_threads = cast_helper<int>(python_config.attr("batch_transfer_threads"));
ret_config->remote_transfer_threads = cast_helper<int>(python_config.attr("remote_transfer_threads"));
ret_config->compute_threads = cast_helper<int>(python_config.attr("compute_threads"));
ret_config->gradient_transfer_threads = cast_helper<int>(python_config.attr("gradient_transfer_threads"));
ret_config->remote_gradient_transfer_threads = cast_helper<int>(python_config.attr("remote_gradient_transfer_threads"));
ret_config->gradient_update_threads = cast_helper<int>(python_config.attr("gradient_update_threads"));
ret_config->remote_listen_threads = cast_helper<int>(python_config.attr("remote_listen_threads"));
}

return ret_config;
Expand Down
8 changes: 4 additions & 4 deletions src/cpp/src/pipeline/pipeline_gpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -565,12 +565,12 @@ void PipelineGPU::initialize() {
addWorkersToPool(0, LOAD_BATCH_ID, pipeline_options_->batch_loader_threads);

if (batch_worker_ and batch_worker_needs_remote_)
addWorkersToPool(5, REMOTE_TO_DEVICE_ID, pipeline_options_->batch_transfer_threads);
addWorkersToPool(5, REMOTE_TO_DEVICE_ID, pipeline_options_->remote_transfer_threads);
else if (compute_worker_)
addWorkersToPool(1, H2D_TRANSFER_ID, pipeline_options_->batch_transfer_threads);

if (compute_worker_ and compute_worker_needs_remote_)
addWorkersToPool(6, REMOTE_LOADER_ID, pipeline_options_->batch_loader_threads);
addWorkersToPool(6, REMOTE_LOADER_ID, pipeline_options_->remote_loader_threads);

if (compute_worker_)
addWorkersToPool(2, GPU_COMPUTE_ID, 1, model_->device_models_.size()); // Only one std::thread manages GPU
Expand All @@ -579,12 +579,12 @@ void PipelineGPU::initialize() {
addWorkersToPool(3, D2H_TRANSFER_ID, pipeline_options_->gradient_transfer_threads);

if ((compute_worker_ and compute_worker_needs_remote_) or (batch_worker_ and batch_worker_needs_remote_))
addWorkersToPool(8, REMOTE_TO_HOST_ID, pipeline_options_->gradient_transfer_threads);
addWorkersToPool(8, REMOTE_TO_HOST_ID, pipeline_options_->remote_gradient_transfer_threads);
else if (model_->has_embeddings() and batch_worker_)
addWorkersToPool(4, UPDATE_BATCH_ID, pipeline_options_->gradient_update_threads);

if (batch_worker_ and batch_worker_needs_remote_)
addWorkersToPool(7, REMOTE_LISTEN_FOR_UPDATES_ID, pipeline_options_->gradient_update_threads);
addWorkersToPool(7, REMOTE_LISTEN_FOR_UPDATES_ID, pipeline_options_->remote_listen_threads);

} else {
if (batch_worker_)
Expand Down
4 changes: 4 additions & 0 deletions src/python/tools/configuration/marius_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -753,10 +753,14 @@ class PipelineConfig:
gradients_device_queue_size: int = 4
gradients_host_queue_size: int = 4
batch_loader_threads: int = 4
remote_loader_threads: int = 4
batch_transfer_threads: int = 2
remote_transfer_threads: int = 2
compute_threads: int = 1
gradient_transfer_threads: int = 2
remote_gradient_transfer_threads: int = 2
gradient_update_threads: int = 4
remote_listen_threads: int = 4

def __post_init__(self):
# for the sync setting, pipeline values can be ignored
Expand Down

0 comments on commit 72ce035

Please sign in to comment.