From c54824e435681bdbd59b90dd46f1504af7090d26 Mon Sep 17 00:00:00 2001 From: tj-solergibert Date: Wed, 24 Apr 2024 20:38:00 +0000 Subject: [PATCH 01/12] refractored distributed jobs --- docs/multinode.md | 173 ++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 145 insertions(+), 28 deletions(-) diff --git a/docs/multinode.md b/docs/multinode.md index 55dbc4d..6b3f431 100644 --- a/docs/multinode.md +++ b/docs/multinode.md @@ -1,4 +1,4 @@ -# Multi-Node Training with RunAI +# Distributed PyTorch with RunAI > [!NOTE] > Multi-Node scheduling needs to be enabled on the cluster and you should be using a RunAI CLI which > supports multi-node jobs. @@ -8,17 +8,15 @@ Jobs can be submitted either through RunAI as documented in RunAI's website (https://docs.run.ai/v2.13/Researcher/cli-reference/runai-submit-dist-pytorch/). -As an example, the following command launches 3 pods, each with 4 GPUs. Note that the number of pods is one more than the number of workers as the master node is not counted as a worker. +To execute jobs in RCP, we will use the RunAI CLI, more specifically the `submit-dist pytorch` function, which will be responsible for launching the specified command on each pod. There are two ways to execute distributed applications: +1. Interactive sessions. To force interactive sessions, we will have to launch the command `sleep infinity` on each pod. This way, we can connect to each pod, but we will have to manually execute the jobs on each one. This is useful for short sessions for debugging applications and checking that everything works correctly before launching a longer job. + > TIP + > Keep in mind that as soon as you disconnect from the pod, you will lose the current job you are executing. +2. Batched execution. In this mode, we will specify to the `submit-dist` function to execute a script, and it will defer execution until the requested resources are available. This is the recommended way to launch longer jobs such as model training. -```bash -runai submit-dist pytorch \ - --name distributed-job-readme \ - --workers=2 -g 4 -i ic-registry.epfl.ch/mlo/lauzhack:v1 \ - --annotation k8s.v1.cni.cncf.io/networks=kube-system/roce \ - --extended-resource rdma/rdma=1 \ - -- "sleep infinity" -``` -Note that it is not possbile to control how these pods are scheduled so these two pods can be either on the same node or on different nodes. For best performance, local GPUs should be maximized, which would mean asking for pods of 8 GPUs each (taking a full node). +To configure the number of nodes and GPUs, we will use the following flags of the `submit-dist` function: +1. `--workers`: The total number of nodes will be `n_workers` + 1, as RunAI adds a master node by default. +2. `--gpu`: The number of GPUs per node. Unless debugging applications, set this value as the number of GPUs per node. Otherwise, it would be possible to orchestrate 2 pods on the same node, which would not make sense. RunAI handles scheduling the pods and also creates the necessary communication (rendezvous) backend (most likely c10d) between them. The following environment variables are set: @@ -27,24 +25,9 @@ RunAI handles scheduling the pods and also creates the necessary communication ( * `MASTER_ADDR`: IP Address of the master node. * `MASTER_PORT`: Port on which master node is listening - -For running a training job, torchrun accepts the above variables as arguments and automatically schedules the job. For example the following command can be used to schedule a training job on the 3 pods we launched before. Note that the command needs to be run on each of the pods separately. - -```bash -torchrun \ - --nproc-per-node 4 \ - --nnodes ${WORLD_SIZE} \ - --node_rank ${RANK} \ - --master_addr ${MASTER_ADDR} \ - --master_port ${MASTER_PORT} \ - main.py -``` - -torchrun automatically launches a separate process for each GPU and assigns the correct global rank. As such, for basic usage (e.g. FSDP), no changes to python code is necessary. - ## Using RDMA for efficient inter-node communication -While the above should get a job running, additional setup is necessary for efficient communication, in particular, using RDMA. We have already specified the following flags when running our pods to ensure RDMA support: +Additional setup is necessary for efficient communication, in particular, using RDMA. We have already specified the following flags when running our pods to ensure RDMA support: ```--annotation k8s.v1.cni.cncf.io/networks=kube-system/roce --extended-resource rdma/rdma=1```. However, the communication backend requires additional configuration to use RDMA. In particular, the following steps are needed when using NCCL. The necessary steps may vary for different OS distributions or versions as well as when alternative drivers for Inifiniband/RDMA are installed. @@ -81,4 +64,138 @@ However, the communication backend requires additional configuration to use RDMA export NCCL_NSOCKS_PERTHREAD=8 ``` -4. You should run torchrun with the above environment variables set. This should usually be enough to get NCCL to correctly use RDMA. To verify this, you can use tools such as ifstats. These tools monitor network traffic that goes through CPU. When using RDMA, no such traffic should be visible (assuming you are not using the network interface for other things). +4. You should run `torchrun` with the above environment variables set. This should usually be enough to get NCCL to correctly use RDMA. To verify this, you can use tools such as ifstats. These tools monitor network traffic that goes through CPU. When using RDMA, no such traffic should be visible (assuming you are not using the network interface for other things). + +## Running your first distributed application +In [`/utils/distributed_pytorch/my_first_distributed_app/`](/utils/distributed_pytorch/my_first_distributed_app/), you will find everything necessary to run a distributed application in PyTorch. This application simply computes number PI using the trapezoid rule by distributing the integral among the total number of processes. + +To launch our first application, we will use the batched execution format from the `submit-dist pytorch` function. We will launch the job as follows to distribute the work across two nodes ([`/utils/distributed_pytorch/my_first_distributed_app/RUNAI_run_app.sh`](/utils/distributed_pytorch/my_first_distributed_app/RUNAI_run_app.sh)): + +``` +./runai-2.13.49 submit-dist pytorch \ + --name my_first_distributed_app \ + --image registry.rcp.epfl.ch/meditron-ddx/basic:latest-solergib \ + --workers 1 \ + --gpu 0 \ + --pvc mlo-scratch:/mloscratch \ + --annotation k8s.v1.cni.cncf.io/networks=kube-system/roce \ + --extended-resource rdma/rdma=1 \ + -e PATH_TO_ROOT_REPO=/mloscratch/homes/solergib/getting-started \ + --large-shm \ + -- bash -c '"source \${PATH_TO_ROOT_REPO}/utils/distributed_pytorch/my_first_distributed_app/RCP_run_app.sh &> \${PATH_TO_ROOT_REPO}/utils/distributed_pytorch/my_first_distributed_app/reports/Output_\${JOB_UUID}.txt"' +``` + +Note the following: +1. We aren't requesting any GPU, as the application doesn't needs any. +2. We include the annotations to use RDMA. +3. The environment variable `PATH_TO_ROOT_REPO` contains the path to this repository within the PVC `mlo-scratch` mounted at `/mlo-scratch`. +4. We launch the job with `bash -c "..."` to: + 1. Allow for the delayed interpolation of environment variables to work (e.g., `PATH_TO_ROOT_REPO`). + 2. Store the output of the job in a file. It can also be checked with `runai logs --name`, but after some time, it will be inaccessible. + > [!WARNING] + > Don't forget the double double quotes in `bash -c` (`'"..."'`). + +The script to be executed on each node is as follows ([`/utils/distributed_pytorch/my_first_distributed_app/RCP_run_app.sh`](/utils/distributed_pytorch/my_first_distributed_app/RCP_run_app.sh)): + +``` +#!/bin/bash + +echo "START TIME: $(date)" + +export NCCL_IB_GID_INDEX=$(grep 'RoCE v2' $(grep '0000:0000:0000:0000:0000:ffff' /sys/class/infiniband/mlx5_bond_0/ports/1/gids/* | cut -d ':' -f 1 | sed 's/gids/gid_attrs\/types/') | sed -e 's/.*\/\([0-9]*\):.*/\1/') +export NCCL_IB_HCA=mlx5 +export NCCL_SOCKET_NTHREADS=4 +export NCCL_NSOCKS_PERTHREAD=8 + +# MASTER_ADDR -> The IP of the master node +# MASTER_PORT -> The port that of the master node +# WORLD_SIZE -> Number of nodes in total, NOT Numer of nodes X GPUs per node +PROCESSES_PER_NODE=20 + +LAUNCHER="torchrun \ + --nproc_per_node $PROCESSES_PER_NODE \ + --nnodes $WORLD_SIZE \ + --node_rank $RANK \ + --rdzv_endpoint $MASTER_ADDR:$MASTER_PORT \ + --rdzv_backend c10d \ + --max_restarts 0 \ + --role \$(hostname -s|tr -dc '0-9'): \ + --tee 3 \ + " + +PYTHON_FILE=/mloscratch/homes/solergib/getting-started/utils/distributed_pytorch/my_first_distributed_app/my_first_distributed_app.py + +export CMD="$LAUNCHER $PYTHON_FILE" +bash -c "$CMD" + +echo "END TIME: $(date)" +``` + +Note the following: +1. At the beginning, we set both the network and environment configurations (Activate conda environment, set environment variables, etc.). +2. To launch the distributed applications, we will use `torchrun`. In short, `torchrun` spawns `--nproc-per-node` processes on each node by executing the specified script. Additionally, it also handles communications between nodes before launching the script. For this, it is necessary to specify `MASTER_ADDR` and `MASTER_PORT`, variables that are automatically defined by RunAI when using `submit-dist pytorch`. `--nodes` will be the number of pods launched (`WORLD_SIZE`), and we will use `--node-rank` to specify the rank of each node; otherwise, `torchrun` will assign a value to each `--node-rank`. In this example, for which we will not use GPUs, we will launch 20 processes on each of the two nodes, dividing the work among a total of 40 processes. + > [!WARNING] + > Do not confuse the variables `WORLD_SIZE` and `RANK` produced by RunAI with the `submit-dist function` with the same variables generated by `torchrun` when launching the scripts. In the case of RunAI, they are configured based on the **number of pods**, while in `torchrun`, they are configured based on the **number of spawned processes**, which is defined by `--nnodes` x `--nproc-per-node`. + +## Inter-Node communication benchmark +We conducted a benchmark to determine the bandwidth between nodes (In Gbps). As can be seen, the benefit of RDMA is significant, so it is advisable to ensure that it is enabled. + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
RDMANO RDMA
GPUsbusbwalgbwbusbwalgbw
21687.11687.1--
41621.51081.0--
81662.4949.9--
16122.365.229.115.5
+ +Pay attention to the `busbw` result (not `algbw`) as explained [here](https://github.com/NVIDIA/nccl-tests/blob/master/doc/PERFORMANCE.md#bandwidth). For intra-node communications (GPUs on the same node), RDMA is disabled, so the data shown reflects the performance achieved with NVLINK. Keep in mind that to shard big models using DeepSpeed/FSDP, it is recommended to have at least 400 Gbps, so it is advisable to restrict training to a single node whenever possible. + +Both the benchmark and the script to launch the job in RCP are located in [`/utils/distributed_pytorch/benchmark/`](/utils/distributed_pytorch/benchmark/). This benchmark is a reduced version of `nccl-test` in Python developed by [Stas Bekman](https://github.com/stas00/ml-engineering/blob/master/network/benchmarks/all_reduce_bench.py). \ No newline at end of file From ff4624fd4240cd14f4d66ba796eb4cad78c66dd5 Mon Sep 17 00:00:00 2001 From: tj-solergibert Date: Wed, 24 Apr 2024 20:38:34 +0000 Subject: [PATCH 02/12] added bw benchmark --- .../benchmark/RCP_run_benchmark.sh | 35 +++++ .../benchmark/RCP_run_benchmark_no_RDMA.sh | 30 ++++ .../benchmark/RUNAI_run_benchmark.sh | 11 ++ .../benchmark/RUNAI_run_benchmark_no_RDMA.sh | 9 ++ .../benchmark/all_reduce_bench.py | 148 ++++++++++++++++++ ...t_16f84e3c-379e-44ef-a4d0-ea595e1782a1.txt | 17 ++ ...t_92f3398e-bcce-4edc-ac95-f40c9b70135d.txt | 17 ++ 7 files changed, 267 insertions(+) create mode 100644 utils/distributed_pytorch/benchmark/RCP_run_benchmark.sh create mode 100644 utils/distributed_pytorch/benchmark/RCP_run_benchmark_no_RDMA.sh create mode 100644 utils/distributed_pytorch/benchmark/RUNAI_run_benchmark.sh create mode 100644 utils/distributed_pytorch/benchmark/RUNAI_run_benchmark_no_RDMA.sh create mode 100644 utils/distributed_pytorch/benchmark/all_reduce_bench.py create mode 100644 utils/distributed_pytorch/benchmark/reports/Output_16f84e3c-379e-44ef-a4d0-ea595e1782a1.txt create mode 100644 utils/distributed_pytorch/benchmark/reports/Output_92f3398e-bcce-4edc-ac95-f40c9b70135d.txt diff --git a/utils/distributed_pytorch/benchmark/RCP_run_benchmark.sh b/utils/distributed_pytorch/benchmark/RCP_run_benchmark.sh new file mode 100644 index 0000000..10749ff --- /dev/null +++ b/utils/distributed_pytorch/benchmark/RCP_run_benchmark.sh @@ -0,0 +1,35 @@ +#!/bin/bash + +echo "START TIME: $(date)" + +export NCCL_IB_GID_INDEX=$(grep 'RoCE v2' $(grep '0000:0000:0000:0000:0000:ffff' /sys/class/infiniband/mlx5_bond_0/ports/1/gids/* | cut -d ':' -f 1 | sed 's/gids/gid_attrs\/types/') | sed -e 's/.*\/\([0-9]*\):.*/\1/') +export NCCL_IB_HCA=mlx5 +export NCCL_SOCKET_NTHREADS=4 +export NCCL_NSOCKS_PERTHREAD=8 + +# MASTER_ADDR -> The IP of the master node +# MASTER_PORT -> The port that of the master node +# WORLD_SIZE -> Number of nodes in total, NOT Numer of nodes X GPUs per node +GPUS_PER_NODE=8 + +LAUNCHER="torchrun \ + --nproc_per_node $GPUS_PER_NODE \ + --nnodes $WORLD_SIZE \ + --node_rank $RANK \ + --rdzv_endpoint $MASTER_ADDR:$MASTER_PORT \ + --rdzv_backend c10d \ + --max_restarts 0 \ + --role \$(hostname -s|tr -dc '0-9'): \ + --tee 3 \ + " + +PYTHON_FILE=/mloscratch/homes/solergib/getting-started/utils/distributed_pytorch/benchmark/all_reduce_bench.py +PYTHON_ARGS=" \ + --batch_size 16 \ + --model Llama3 \ + " + +export CMD="$LAUNCHER $PYTHON_FILE $PYTHON_ARGS" +bash -c "$CMD" + +echo "END TIME: $(date)" \ No newline at end of file diff --git a/utils/distributed_pytorch/benchmark/RCP_run_benchmark_no_RDMA.sh b/utils/distributed_pytorch/benchmark/RCP_run_benchmark_no_RDMA.sh new file mode 100644 index 0000000..f8e3cd7 --- /dev/null +++ b/utils/distributed_pytorch/benchmark/RCP_run_benchmark_no_RDMA.sh @@ -0,0 +1,30 @@ +#!/bin/bash + +echo "START TIME: $(date)" + +# MASTER_ADDR -> The IP of the master node +# MASTER_PORT -> The port that of the master node +# WORLD_SIZE -> Number of nodes in total, NOT Numer of nodes X GPUs per node +GPUS_PER_NODE=8 + +LAUNCHER="torchrun \ + --nproc_per_node $GPUS_PER_NODE \ + --nnodes $WORLD_SIZE \ + --node_rank $RANK \ + --rdzv_endpoint $MASTER_ADDR:$MASTER_PORT \ + --rdzv_backend c10d \ + --max_restarts 0 \ + --role \$(hostname -s|tr -dc '0-9'): \ + --tee 3 \ + " + +PYTHON_FILE=/mloscratch/homes/solergib/getting-started/utils/distributed_pytorch/benchmark/all_reduce_bench.py +PYTHON_ARGS=" \ + --batch_size 16 \ + --model Llama3 \ + " + +export CMD="$LAUNCHER $PYTHON_FILE $PYTHON_ARGS" +bash -c "$CMD" + +echo "END TIME: $(date)" diff --git a/utils/distributed_pytorch/benchmark/RUNAI_run_benchmark.sh b/utils/distributed_pytorch/benchmark/RUNAI_run_benchmark.sh new file mode 100644 index 0000000..230974b --- /dev/null +++ b/utils/distributed_pytorch/benchmark/RUNAI_run_benchmark.sh @@ -0,0 +1,11 @@ +./runai-2.13.49 submit-dist pytorch \ + --name all-reduce-bench \ + --image registry.rcp.epfl.ch/meditron-ddx/basic:latest-solergib \ + --workers 3 \ + --gpu 8 \ + --pvc mlo-scratch:/mloscratch \ + --annotation k8s.v1.cni.cncf.io/networks=kube-system/roce \ + --extended-resource rdma/rdma=1 \ + -e PATH_TO_ROOT_REPO=/mloscratch/homes/solergib/getting-started \ + --large-shm \ + -- bash -c '"source \${PATH_TO_ROOT_REPO}/utils/distributed_pytorch/benchmark/RCP_run_benchmark.sh &> \${PATH_TO_ROOT_REPO}/utils/distributed_pytorch/benchmark/reports/Output_\${JOB_UUID}.txt"' \ No newline at end of file diff --git a/utils/distributed_pytorch/benchmark/RUNAI_run_benchmark_no_RDMA.sh b/utils/distributed_pytorch/benchmark/RUNAI_run_benchmark_no_RDMA.sh new file mode 100644 index 0000000..948aca2 --- /dev/null +++ b/utils/distributed_pytorch/benchmark/RUNAI_run_benchmark_no_RDMA.sh @@ -0,0 +1,9 @@ +./runai-2.13.49 submit-dist pytorch \ + --name all-reduce-bench \ + --image registry.rcp.epfl.ch/meditron-ddx/basic:latest-solergib \ + --workers 3 \ + --gpu 8 \ + --pvc mlo-scratch:/mloscratch \ + -e PATH_TO_ROOT_REPO=/mloscratch/homes/solergib/getting-started \ + --large-shm \ + -- bash -c '"source \${PATH_TO_ROOT_REPO}/utils/distributed_pytorch/benchmark/RCP_run_benchmark_no_RDMA.sh &> \${PATH_TO_ROOT_REPO}/utils/distributed_pytorch/benchmark/reports/Output_\${JOB_UUID}.txt"' \ No newline at end of file diff --git a/utils/distributed_pytorch/benchmark/all_reduce_bench.py b/utils/distributed_pytorch/benchmark/all_reduce_bench.py new file mode 100644 index 0000000..6b81796 --- /dev/null +++ b/utils/distributed_pytorch/benchmark/all_reduce_bench.py @@ -0,0 +1,148 @@ +#!/usr/bin/env python + +""" + +The latest version of this program can be found at https://github.com/stas00/ml-engineering + +This benchmark is very similar to https://github.com/NVIDIA/nccl-tests but it's much easier to set +up as it only requires PyTorch to be installed + +This version: +- has been derived from @jeffra's gist: https://gist.github.com/jeffra/b5e80466b4c86be00ea3b6f130fb7a36 +- which in turn is derived from the logic in https://github.com/NVIDIA/nccl-tests +- with contributions from: + * Indu Thangakrishnan https://github.com/indhub to handle timing correctly using cuda events + + +Important notes: + +- when you finished running this benchmark you want to pay attention to the busbw result (not + algbw) as explained here https://github.com/NVIDIA/nccl-tests/blob/master/doc/PERFORMANCE.md#bandwidth + +- similar to NVIDIA/nccl-tests this benchmark measures a unidirectional bandwidth - so compare the + outcome against the advertised unidirectional peak throughput and not bi-directional (duplex) + +- currently this benchmark tests a payload of 4GB (M * N * 4). If your target application uses a + much smaller payload you want to modify M*N*4 to match the target payload. To calculate the + payload use the number of parameters sent in each reduction multiplied by 2 (bf16/fp16) or 4 + (fp32). e.g., if a reduction is of a single layer of 1B params, and you use bf16 grads it'd be + 2GB of payload. depending on the framework you use (DDP, FSDP, DeepSpeed ZeRO) they all use + different logic to how much of a message size they send. + +- if you are wondering whether you need to also run https://github.com/NVIDIA/nccl-tests - I + already validated that I got very similar results with ./build/all_reduce_perf -b 4G -e 4G + (tested with mpirun on 4 nodes). It should be either on par or slightly slower because it uses a + blocking approach - that is it wait for each new all_reduce to finish before firing the next + one, whereas nccl-tests fires them all in an async fashion (you can add `-z` to nccl-tests to + emulate blocking) + +- to benchmark other collectives use nccl-tests. It's also useful if you want to test a range of + payloads, e.g. there you'd set -b 8 -e 4G -f 2 and it will test many sizes automatically. + +To run on 4 nodes: + +GPUS_PER_NODE=8 +NNODES=4 +MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1) +MASTER_PORT=6000 +python -u -m torch.distributed.run \ + --nproc_per_node $GPUS_PER_NODE \ + --nnodes $NNODES \ + --rdzv_endpoint $MASTER_ADDR:$MASTER_PORT \ + --rdzv_backend c10d \ + --max_restarts 0 \ + --role `hostname -s`: \ + --tee 3 \ + all_reduce_bench.py + +note: adapt MASTER_ADDR to node rank 0's hostname if it's not a SLURM environment where it's derived automatically + +e.g. example to run with salloc+srun: + +salloc --partition=mypartition --nodes=4 --ntasks-per-node=1 --cpus-per-task=48 --gres=gpu:8 --time=1:00:00 bash + +srun --gres=gpu:8 --nodes=4 --tasks-per-node=1 python -u -m torch.distributed.run --nproc_per_node=8 \ +--nnodes 4 --rdzv_endpoint $(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1):6000 --rdzv_backend \ +c10d all_reduce_bench.py + +To do a quick test on 2 gpus: + +python -u -m torch.distributed.run --nproc_per_node=2 --rdzv_endpoint localhost:6000 --rdzv_backend c10d \ +all_reduce_bench.py + +""" + +import os +import socket +import torch +import torch.distributed as dist + +TRIALS = 10 + +# these emulate the payload which will become a M * N * 4-sized tensor below +N = 500000 +M = 2000 + +def timed_allreduce(mat, start_event, end_event): + dist.barrier() + start_event.record() + dist.all_reduce(mat) + end_event.record() + + torch.cuda.synchronize() + duration = start_event.elapsed_time(end_event) / 1000 + + n = dist.get_world_size() + size = M * N * 4 # 4 is 4 bytes in fp32 + # note that this is following the same math as NVIDIA/nccl-tests + algbw = torch.tensor([size / duration]).cuda(local_rank) + + # calculate mean across all ranks + dist.reduce(algbw, dst=0, op=dist.ReduceOp.SUM) + algbw /= n + + return algbw + +def run(local_rank): + hostname = socket.gethostname() + is_global_rank_0 = dist.get_rank() == 0 + + mat = torch.rand(N, M, dtype=torch.float32).cuda(local_rank) + + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + # do a few warm up iterations + for i in range(2): + timed_allreduce(mat, start_event, end_event) + + # real benchmark + algbw_gather = [] + for i in range(TRIALS): + if is_global_rank_0: + print(i+1) + algbw_gather += timed_allreduce(mat, start_event, end_event) + + algbw = torch.mean(torch.stack(algbw_gather)) + + # the 2*(n-1)/n busbw correction factor specific to all-reduce is explained here: + # https://github.com/NVIDIA/nccl-tests/blob/master/doc/PERFORMANCE.md#allreduce + # busbw reflects how optimally the hardware is used + n = dist.get_world_size() + busbw = algbw * (2*(n - 1) / n) + + if is_global_rank_0: + print(f"The average bandwidth of all_reduce with a {M*N*4/1e9}GB payload ({TRIALS} trials, {n} ranks):\n", + f"algbw: {algbw/1e9:.3f} GBps ({algbw*8/1e9:.1f} Gbps)\n", + f"busbw: {busbw/1e9:.3f} GBps ({busbw*8/1e9:.1f} Gbps)\n", + ) + +def init_processes(local_rank, fn, backend='nccl'): + torch.cuda.set_device(local_rank) + dist.init_process_group(backend) + fn(local_rank) + + +if __name__ == "__main__": + local_rank = int(os.environ["LOCAL_RANK"]) + init_processes(local_rank=local_rank, fn=run) \ No newline at end of file diff --git a/utils/distributed_pytorch/benchmark/reports/Output_16f84e3c-379e-44ef-a4d0-ea595e1782a1.txt b/utils/distributed_pytorch/benchmark/reports/Output_16f84e3c-379e-44ef-a4d0-ea595e1782a1.txt new file mode 100644 index 0000000..791751f --- /dev/null +++ b/utils/distributed_pytorch/benchmark/reports/Output_16f84e3c-379e-44ef-a4d0-ea595e1782a1.txt @@ -0,0 +1,17 @@ +START TIME: Wed Apr 24 17:12:50 UTC 2024 +[2024-04-24 17:12:51,631] torch.distributed.run: [WARNING] +[2024-04-24 17:12:51,631] torch.distributed.run: [WARNING] ***************************************** +[2024-04-24 17:12:51,631] torch.distributed.run: [WARNING] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. +[2024-04-24 17:12:51,631] torch.distributed.run: [WARNING] ***************************************** +END TIME: Wed Apr 24 17:13:37 UTC 2024 + +[0:0]:6 +[0:0]:7 +[0:0]:8 +[0:0]:9 +[0:0]:10 +[0:0]:The average bandwidth of all_reduce with a 4.0GB payload (10 trials, 16 ranks): +[0:0]: algbw: 1.939 GBps (15.5 Gbps) +[0:0]: busbw: 3.635 GBps (29.1 Gbps) +[0:0]: +END TIME: Wed Apr 24 17:13:37 UTC 2024 diff --git a/utils/distributed_pytorch/benchmark/reports/Output_92f3398e-bcce-4edc-ac95-f40c9b70135d.txt b/utils/distributed_pytorch/benchmark/reports/Output_92f3398e-bcce-4edc-ac95-f40c9b70135d.txt new file mode 100644 index 0000000..98a3530 --- /dev/null +++ b/utils/distributed_pytorch/benchmark/reports/Output_92f3398e-bcce-4edc-ac95-f40c9b70135d.txt @@ -0,0 +1,17 @@ +START TIME: Wed Apr 24 17:08:28 UTC 2024 +[2024-04-24 17:08:29,989] torch.distributed.run: [WARNING] +[2024-04-24 17:08:29,989] torch.distributed.run: [WARNING] ***************************************** +[2024-04-24 17:08:29,989] torch.distributed.run: [WARNING] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. +[2024-04-24 17:08:29,989] torch.distributed.run: [WARNING] ***************************************** +END TIME: Wed Apr 24 17:08:55 UTC 2024 + +[0:0]:6 +[0:0]:7 +[0:0]:8 +[0:0]:9 +[0:0]:10 +[0:0]:The average bandwidth of all_reduce with a 4.0GB payload (10 trials, 16 ranks): +[0:0]: algbw: 8.154 GBps (65.2 Gbps) +[0:0]: busbw: 15.289 GBps (122.3 Gbps) +[0:0]: +END TIME: Wed Apr 24 17:08:55 UTC 2024 From 6032aad7c33742b2d2d1fa3e60d05cab20b278b3 Mon Sep 17 00:00:00 2001 From: tj-solergibert Date: Wed, 24 Apr 2024 20:38:46 +0000 Subject: [PATCH 03/12] added my_first_distributed_app --- .../my_first_distributed_app/RCP_run_app.sh | 31 +++++++ .../my_first_distributed_app/RUNAI_run_app.sh | 11 +++ .../my_first_distributed_app.py | 82 +++++++++++++++++++ ...t_0daa767f-355c-4ab9-aa1c-094016e39a24.txt | 9 ++ 4 files changed, 133 insertions(+) create mode 100644 utils/distributed_pytorch/my_first_distributed_app/RCP_run_app.sh create mode 100644 utils/distributed_pytorch/my_first_distributed_app/RUNAI_run_app.sh create mode 100644 utils/distributed_pytorch/my_first_distributed_app/my_first_distributed_app.py create mode 100644 utils/distributed_pytorch/my_first_distributed_app/reports/Output_0daa767f-355c-4ab9-aa1c-094016e39a24.txt diff --git a/utils/distributed_pytorch/my_first_distributed_app/RCP_run_app.sh b/utils/distributed_pytorch/my_first_distributed_app/RCP_run_app.sh new file mode 100644 index 0000000..3be35ad --- /dev/null +++ b/utils/distributed_pytorch/my_first_distributed_app/RCP_run_app.sh @@ -0,0 +1,31 @@ +#!/bin/bash + +echo "START TIME: $(date)" + +export NCCL_IB_GID_INDEX=$(grep 'RoCE v2' $(grep '0000:0000:0000:0000:0000:ffff' /sys/class/infiniband/mlx5_bond_0/ports/1/gids/* | cut -d ':' -f 1 | sed 's/gids/gid_attrs\/types/') | sed -e 's/.*\/\([0-9]*\):.*/\1/') +export NCCL_IB_HCA=mlx5 +export NCCL_SOCKET_NTHREADS=4 +export NCCL_NSOCKS_PERTHREAD=8 + +# MASTER_ADDR -> The IP of the master node +# MASTER_PORT -> The port that of the master node +# WORLD_SIZE -> Number of nodes in total, NOT Numer of nodes X GPUs per node +PROCESSES_PER_NODE=20 + +LAUNCHER="torchrun \ + --nproc_per_node $PROCESSES_PER_NODE \ + --nnodes $WORLD_SIZE \ + --node_rank $RANK \ + --rdzv_endpoint $MASTER_ADDR:$MASTER_PORT \ + --rdzv_backend c10d \ + --max_restarts 0 \ + --role \$(hostname -s|tr -dc '0-9'): \ + --tee 3 \ + " + +PYTHON_FILE=/mloscratch/homes/solergib/getting-started/utils/distributed_pytorch/my_first_distributed_app/my_first_distributed_app.py + +export CMD="$LAUNCHER $PYTHON_FILE" +bash -c "$CMD" + +echo "END TIME: $(date)" \ No newline at end of file diff --git a/utils/distributed_pytorch/my_first_distributed_app/RUNAI_run_app.sh b/utils/distributed_pytorch/my_first_distributed_app/RUNAI_run_app.sh new file mode 100644 index 0000000..db43a13 --- /dev/null +++ b/utils/distributed_pytorch/my_first_distributed_app/RUNAI_run_app.sh @@ -0,0 +1,11 @@ +./runai-2.13.49 submit-dist pytorch \ + --name my_first_distributed_app \ + --image registry.rcp.epfl.ch/meditron-ddx/basic:latest-solergib \ + --workers 1 \ + --gpu 0 \ + --pvc mlo-scratch:/mloscratch \ + --annotation k8s.v1.cni.cncf.io/networks=kube-system/roce \ + --extended-resource rdma/rdma=1 \ + -e PATH_TO_ROOT_REPO=/mloscratch/homes/solergib/getting-started \ + --large-shm \ + -- bash -c '"source \${PATH_TO_ROOT_REPO}/utils/distributed_pytorch/my_first_distributed_app/RCP_run_app.sh &> \${PATH_TO_ROOT_REPO}/utils/distributed_pytorch/my_first_distributed_app/reports/Output_\${JOB_UUID}.txt"' \ No newline at end of file diff --git a/utils/distributed_pytorch/my_first_distributed_app/my_first_distributed_app.py b/utils/distributed_pytorch/my_first_distributed_app/my_first_distributed_app.py new file mode 100644 index 0000000..68f0734 --- /dev/null +++ b/utils/distributed_pytorch/my_first_distributed_app/my_first_distributed_app.py @@ -0,0 +1,82 @@ +import os + +from numba import jit +from numpy import pi +from numpy.testing import assert_almost_equal +import torch +import torch.distributed as dist + +TRIALS = 5 + +a = 0 # Lower bound +b = 1 # Upper bound +n = 1000000000 # Number of trapezoids + +@jit(nopython=True) # Set "nopython" mode for best performance, equivalent to @njit +def compute_partial_pi(local_a, local_b, local_n, h): + estimate = (local_a**2 + local_b**2) / 2.0 + + for i in range(local_n): + x = local_a + i * h + estimate += 4.0 / (1.0 + x*x) + + return estimate * h + +def compute_pi(local_a, local_b, local_n, h): + estimate = compute_partial_pi(local_a, local_b, local_n, h) + estimate = torch.tensor([estimate]) + dist.all_reduce(estimate, op=dist.ReduceOp.SUM) + return estimate + +def timed_compute_pi(local_a, local_b, local_n, h, start_event, end_event): + dist.barrier() + start_event.record() + # Compute partial result of pi + pi_estimate = compute_pi(local_a, local_b, local_n, h) + end_event.record() + + torch.cuda.synchronize() + assert_almost_equal(pi_estimate.item(), pi, decimal=5) + duration = start_event.elapsed_time(end_event) / 1000 + duration = torch.tensor([duration]) + + # Compute mean across all ranks + dist.reduce(duration, dst=0, op=dist.ReduceOp.SUM) + + return duration + + +def run(): + rank, world_size = int(os.environ["RANK"]), int(os.environ["WORLD_SIZE"]) # Variables set up by torchrun + + # Note: h and local_n are the same for all processes + h = (b-a)/n # Length of each trapezoid + local_n = int(n/world_size) # Number of trapezoids per process + + # Length of each process interval of integration = local_n*h. + local_a = a + rank * local_n * h + local_b = local_a + local_n * h + + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + # Do a few warm up iterations + for _ in range(2): + timed_compute_pi(local_a, local_b, local_n, h, start_event, end_event) + + # Real benchmark + times = [] + for _ in range(TRIALS): + times += timed_compute_pi(local_a, local_b, local_n, h, start_event, end_event) + + avg_times = torch.mean(torch.stack(times)) + + if rank == 0: + print(f"Total number of processes: {world_size}") + print(f"AVG Time: {avg_times}\n") + + +if __name__ == "__main__": + # Init PyTorch process group with "gloo" backend for CPU comms (NOT NVIDIA NCCL) + dist.init_process_group(backend="gloo") + run() \ No newline at end of file diff --git a/utils/distributed_pytorch/my_first_distributed_app/reports/Output_0daa767f-355c-4ab9-aa1c-094016e39a24.txt b/utils/distributed_pytorch/my_first_distributed_app/reports/Output_0daa767f-355c-4ab9-aa1c-094016e39a24.txt new file mode 100644 index 0000000..8a33bbe --- /dev/null +++ b/utils/distributed_pytorch/my_first_distributed_app/reports/Output_0daa767f-355c-4ab9-aa1c-094016e39a24.txt @@ -0,0 +1,9 @@ +START TIME: Wed Apr 24 17:32:07 UTC 2024 +[2024-04-24 17:32:08,348] torch.distributed.run: [WARNING] +[2024-04-24 17:32:08,348] torch.distributed.run: [WARNING] ***************************************** +[2024-04-24 17:32:08,348] torch.distributed.run: [WARNING] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. +[2024-04-24 17:32:08,348] torch.distributed.run: [WARNING] ***************************************** +END TIME: Wed Apr 24 17:32:39 UTC 2024 +Total number of processes: 40 +AVG Time: 2.187171459197998 +END TIME: Wed Apr 24 17:32:39 UTC 2024 From 4de372176eb1db16aeaf5440b6e9d781dfb282fe Mon Sep 17 00:00:00 2001 From: tj-solergibert Date: Fri, 26 Apr 2024 08:46:42 +0000 Subject: [PATCH 04/12] deleted table metadata --- docs/multinode.md | 9 --------- 1 file changed, 9 deletions(-) diff --git a/docs/multinode.md b/docs/multinode.md index 6b3f431..1df5a74 100644 --- a/docs/multinode.md +++ b/docs/multinode.md @@ -140,15 +140,6 @@ Note the following: ## Inter-Node communication benchmark We conducted a benchmark to determine the bandwidth between nodes (In Gbps). As can be seen, the benefit of RDMA is significant, so it is advisable to ensure that it is enabled. - From e85569ee9f1d50904b578e03dbe6d9521187fc71 Mon Sep 17 00:00:00 2001 From: tj-solergibert Date: Fri, 26 Apr 2024 08:49:16 +0000 Subject: [PATCH 05/12] fixed warning messages --- docs/multinode.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/multinode.md b/docs/multinode.md index 1df5a74..f4871ef 100644 --- a/docs/multinode.md +++ b/docs/multinode.md @@ -10,8 +10,8 @@ Jobs can be submitted either through RunAI as documented in RunAI's website (htt To execute jobs in RCP, we will use the RunAI CLI, more specifically the `submit-dist pytorch` function, which will be responsible for launching the specified command on each pod. There are two ways to execute distributed applications: 1. Interactive sessions. To force interactive sessions, we will have to launch the command `sleep infinity` on each pod. This way, we can connect to each pod, but we will have to manually execute the jobs on each one. This is useful for short sessions for debugging applications and checking that everything works correctly before launching a longer job. - > TIP - > Keep in mind that as soon as you disconnect from the pod, you will lose the current job you are executing. + > [!TIP] + > Keep in mind that as soon as you disconnect from the pod, you will lose the current job you are executing. 2. Batched execution. In this mode, we will specify to the `submit-dist` function to execute a script, and it will defer execution until the requested resources are available. This is the recommended way to launch longer jobs such as model training. To configure the number of nodes and GPUs, we will use the following flags of the `submit-dist` function: @@ -92,8 +92,8 @@ Note the following: 4. We launch the job with `bash -c "..."` to: 1. Allow for the delayed interpolation of environment variables to work (e.g., `PATH_TO_ROOT_REPO`). 2. Store the output of the job in a file. It can also be checked with `runai logs --name`, but after some time, it will be inaccessible. - > [!WARNING] - > Don't forget the double double quotes in `bash -c` (`'"..."'`). + > [!WARNING] + > Don't forget the double double quotes in `bash -c` (`'"..."'`). The script to be executed on each node is as follows ([`/utils/distributed_pytorch/my_first_distributed_app/RCP_run_app.sh`](/utils/distributed_pytorch/my_first_distributed_app/RCP_run_app.sh)): From 6c029d7b30fbed6f7e8fc497806a5d6d4f5cf0ce Mon Sep 17 00:00:00 2001 From: tj-solergibert Date: Fri, 26 Apr 2024 08:51:12 +0000 Subject: [PATCH 06/12] fixed warning messages --- docs/multinode.md | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/docs/multinode.md b/docs/multinode.md index f4871ef..d111162 100644 --- a/docs/multinode.md +++ b/docs/multinode.md @@ -10,8 +10,8 @@ Jobs can be submitted either through RunAI as documented in RunAI's website (htt To execute jobs in RCP, we will use the RunAI CLI, more specifically the `submit-dist pytorch` function, which will be responsible for launching the specified command on each pod. There are two ways to execute distributed applications: 1. Interactive sessions. To force interactive sessions, we will have to launch the command `sleep infinity` on each pod. This way, we can connect to each pod, but we will have to manually execute the jobs on each one. This is useful for short sessions for debugging applications and checking that everything works correctly before launching a longer job. - > [!TIP] - > Keep in mind that as soon as you disconnect from the pod, you will lose the current job you are executing. +> [!TIP] +> Keep in mind that as soon as you disconnect from the pod, you will lose the current job you are executing. 2. Batched execution. In this mode, we will specify to the `submit-dist` function to execute a script, and it will defer execution until the requested resources are available. This is the recommended way to launch longer jobs such as model training. To configure the number of nodes and GPUs, we will use the following flags of the `submit-dist` function: @@ -92,8 +92,8 @@ Note the following: 4. We launch the job with `bash -c "..."` to: 1. Allow for the delayed interpolation of environment variables to work (e.g., `PATH_TO_ROOT_REPO`). 2. Store the output of the job in a file. It can also be checked with `runai logs --name`, but after some time, it will be inaccessible. - > [!WARNING] - > Don't forget the double double quotes in `bash -c` (`'"..."'`). +> [!WARNING] +> Don't forget the double double quotes in `bash -c` (`'"..."'`). The script to be executed on each node is as follows ([`/utils/distributed_pytorch/my_first_distributed_app/RCP_run_app.sh`](/utils/distributed_pytorch/my_first_distributed_app/RCP_run_app.sh)): @@ -134,8 +134,8 @@ echo "END TIME: $(date)" Note the following: 1. At the beginning, we set both the network and environment configurations (Activate conda environment, set environment variables, etc.). 2. To launch the distributed applications, we will use `torchrun`. In short, `torchrun` spawns `--nproc-per-node` processes on each node by executing the specified script. Additionally, it also handles communications between nodes before launching the script. For this, it is necessary to specify `MASTER_ADDR` and `MASTER_PORT`, variables that are automatically defined by RunAI when using `submit-dist pytorch`. `--nodes` will be the number of pods launched (`WORLD_SIZE`), and we will use `--node-rank` to specify the rank of each node; otherwise, `torchrun` will assign a value to each `--node-rank`. In this example, for which we will not use GPUs, we will launch 20 processes on each of the two nodes, dividing the work among a total of 40 processes. - > [!WARNING] - > Do not confuse the variables `WORLD_SIZE` and `RANK` produced by RunAI with the `submit-dist function` with the same variables generated by `torchrun` when launching the scripts. In the case of RunAI, they are configured based on the **number of pods**, while in `torchrun`, they are configured based on the **number of spawned processes**, which is defined by `--nnodes` x `--nproc-per-node`. +> [!WARNING] +> Do not confuse the variables `WORLD_SIZE` and `RANK` produced by RunAI with the `submit-dist function` with the same variables generated by `torchrun` when launching the scripts. In the case of RunAI, they are configured based on the **number of pods**, while in `torchrun`, they are configured based on the **number of spawned processes**, which is defined by `--nnodes` x `--nproc-per-node`. ## Inter-Node communication benchmark We conducted a benchmark to determine the bandwidth between nodes (In Gbps). As can be seen, the benefit of RDMA is significant, so it is advisable to ensure that it is enabled. From 85238714454e12f1159d92a6a446111c470a4c58 Mon Sep 17 00:00:00 2001 From: tj-solergibert Date: Fri, 26 Apr 2024 14:47:46 +0000 Subject: [PATCH 07/12] added 4 nodes benchmark --- docs/multinode.md | 7 +++++++ ...t_1a12b610-62c8-4a66-9cc8-a540d37cc186.txt | 20 +++++++++++++++++++ ...t_e1181af9-caec-40a0-8104-50eb12f6663d.txt | 20 +++++++++++++++++++ 3 files changed, 47 insertions(+) create mode 100644 utils/distributed_pytorch/benchmark/reports/Output_1a12b610-62c8-4a66-9cc8-a540d37cc186.txt create mode 100644 utils/distributed_pytorch/benchmark/reports/Output_e1181af9-caec-40a0-8104-50eb12f6663d.txt diff --git a/docs/multinode.md b/docs/multinode.md index d111162..8f84a83 100644 --- a/docs/multinode.md +++ b/docs/multinode.md @@ -184,6 +184,13 @@ We conducted a benchmark to determine the bandwidth between nodes (In Gbps). As + + + + + + +
29.1 15.5
3276.239.330.115.6
diff --git a/utils/distributed_pytorch/benchmark/reports/Output_1a12b610-62c8-4a66-9cc8-a540d37cc186.txt b/utils/distributed_pytorch/benchmark/reports/Output_1a12b610-62c8-4a66-9cc8-a540d37cc186.txt new file mode 100644 index 0000000..8ecd55a --- /dev/null +++ b/utils/distributed_pytorch/benchmark/reports/Output_1a12b610-62c8-4a66-9cc8-a540d37cc186.txt @@ -0,0 +1,20 @@ +START TIME: Fri Apr 26 12:14:59 UTC 2024 +[2024-04-26 12:15:01,012] torch.distributed.run: [WARNING] +[2024-04-26 12:15:01,012] torch.distributed.run: [WARNING] ***************************************** +[2024-04-26 12:15:01,012] torch.distributed.run: [WARNING] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. +[2024-04-26 12:15:01,012] torch.distributed.run: [WARNING] ***************************************** +[0:0]:1 +[0:0]:2 +[0:0]:3 +[0:0]:4 +[0:0]:5 +[0:0]:6 +[0:0]:7 +[0:0]:8 +[0:0]:9 +[0:0]:10 +[0:0]:The average bandwidth of all_reduce with a 4.0GB payload (10 trials, 32 ranks): +[0:0]: algbw: 1.944 GBps (15.6 Gbps) +[0:0]: busbw: 3.767 GBps (30.1 Gbps) +[0:0]: +END TIME: Fri Apr 26 12:15:59 UTC 2024 diff --git a/utils/distributed_pytorch/benchmark/reports/Output_e1181af9-caec-40a0-8104-50eb12f6663d.txt b/utils/distributed_pytorch/benchmark/reports/Output_e1181af9-caec-40a0-8104-50eb12f6663d.txt new file mode 100644 index 0000000..04371a3 --- /dev/null +++ b/utils/distributed_pytorch/benchmark/reports/Output_e1181af9-caec-40a0-8104-50eb12f6663d.txt @@ -0,0 +1,20 @@ +START TIME: Fri Apr 26 12:06:03 UTC 2024 +[2024-04-26 12:06:04,716] torch.distributed.run: [WARNING] +[2024-04-26 12:06:04,716] torch.distributed.run: [WARNING] ***************************************** +[2024-04-26 12:06:04,716] torch.distributed.run: [WARNING] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. +[2024-04-26 12:06:04,716] torch.distributed.run: [WARNING] ***************************************** +[0:0]:1 +[0:0]:2 +[0:0]:3 +[0:0]:4 +[0:0]:5 +[0:0]:6 +[0:0]:7 +[0:0]:8 +[0:0]:9 +[0:0]:10 +[0:0]:The average bandwidth of all_reduce with a 4.0GB payload (10 trials, 32 ranks): +[0:0]: algbw: 4.914 GBps (39.3 Gbps) +[0:0]: busbw: 9.522 GBps (76.2 Gbps) +[0:0]: +END TIME: Fri Apr 26 12:06:47 UTC 2024 From 1ea65a058ec8482a3be17a4e596bfe81ef738864 Mon Sep 17 00:00:00 2001 From: tj-solergibert Date: Mon, 29 Apr 2024 19:19:06 +0200 Subject: [PATCH 08/12] deleted reports --- ...t_16f84e3c-379e-44ef-a4d0-ea595e1782a1.txt | 17 ---------------- ...t_1a12b610-62c8-4a66-9cc8-a540d37cc186.txt | 20 ------------------- ...t_92f3398e-bcce-4edc-ac95-f40c9b70135d.txt | 17 ---------------- ...t_e1181af9-caec-40a0-8104-50eb12f6663d.txt | 20 ------------------- ...t_0daa767f-355c-4ab9-aa1c-094016e39a24.txt | 9 --------- 5 files changed, 83 deletions(-) delete mode 100644 utils/distributed_pytorch/benchmark/reports/Output_16f84e3c-379e-44ef-a4d0-ea595e1782a1.txt delete mode 100644 utils/distributed_pytorch/benchmark/reports/Output_1a12b610-62c8-4a66-9cc8-a540d37cc186.txt delete mode 100644 utils/distributed_pytorch/benchmark/reports/Output_92f3398e-bcce-4edc-ac95-f40c9b70135d.txt delete mode 100644 utils/distributed_pytorch/benchmark/reports/Output_e1181af9-caec-40a0-8104-50eb12f6663d.txt delete mode 100644 utils/distributed_pytorch/my_first_distributed_app/reports/Output_0daa767f-355c-4ab9-aa1c-094016e39a24.txt diff --git a/utils/distributed_pytorch/benchmark/reports/Output_16f84e3c-379e-44ef-a4d0-ea595e1782a1.txt b/utils/distributed_pytorch/benchmark/reports/Output_16f84e3c-379e-44ef-a4d0-ea595e1782a1.txt deleted file mode 100644 index 791751f..0000000 --- a/utils/distributed_pytorch/benchmark/reports/Output_16f84e3c-379e-44ef-a4d0-ea595e1782a1.txt +++ /dev/null @@ -1,17 +0,0 @@ -START TIME: Wed Apr 24 17:12:50 UTC 2024 -[2024-04-24 17:12:51,631] torch.distributed.run: [WARNING] -[2024-04-24 17:12:51,631] torch.distributed.run: [WARNING] ***************************************** -[2024-04-24 17:12:51,631] torch.distributed.run: [WARNING] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. -[2024-04-24 17:12:51,631] torch.distributed.run: [WARNING] ***************************************** -END TIME: Wed Apr 24 17:13:37 UTC 2024 - -[0:0]:6 -[0:0]:7 -[0:0]:8 -[0:0]:9 -[0:0]:10 -[0:0]:The average bandwidth of all_reduce with a 4.0GB payload (10 trials, 16 ranks): -[0:0]: algbw: 1.939 GBps (15.5 Gbps) -[0:0]: busbw: 3.635 GBps (29.1 Gbps) -[0:0]: -END TIME: Wed Apr 24 17:13:37 UTC 2024 diff --git a/utils/distributed_pytorch/benchmark/reports/Output_1a12b610-62c8-4a66-9cc8-a540d37cc186.txt b/utils/distributed_pytorch/benchmark/reports/Output_1a12b610-62c8-4a66-9cc8-a540d37cc186.txt deleted file mode 100644 index 8ecd55a..0000000 --- a/utils/distributed_pytorch/benchmark/reports/Output_1a12b610-62c8-4a66-9cc8-a540d37cc186.txt +++ /dev/null @@ -1,20 +0,0 @@ -START TIME: Fri Apr 26 12:14:59 UTC 2024 -[2024-04-26 12:15:01,012] torch.distributed.run: [WARNING] -[2024-04-26 12:15:01,012] torch.distributed.run: [WARNING] ***************************************** -[2024-04-26 12:15:01,012] torch.distributed.run: [WARNING] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. -[2024-04-26 12:15:01,012] torch.distributed.run: [WARNING] ***************************************** -[0:0]:1 -[0:0]:2 -[0:0]:3 -[0:0]:4 -[0:0]:5 -[0:0]:6 -[0:0]:7 -[0:0]:8 -[0:0]:9 -[0:0]:10 -[0:0]:The average bandwidth of all_reduce with a 4.0GB payload (10 trials, 32 ranks): -[0:0]: algbw: 1.944 GBps (15.6 Gbps) -[0:0]: busbw: 3.767 GBps (30.1 Gbps) -[0:0]: -END TIME: Fri Apr 26 12:15:59 UTC 2024 diff --git a/utils/distributed_pytorch/benchmark/reports/Output_92f3398e-bcce-4edc-ac95-f40c9b70135d.txt b/utils/distributed_pytorch/benchmark/reports/Output_92f3398e-bcce-4edc-ac95-f40c9b70135d.txt deleted file mode 100644 index 98a3530..0000000 --- a/utils/distributed_pytorch/benchmark/reports/Output_92f3398e-bcce-4edc-ac95-f40c9b70135d.txt +++ /dev/null @@ -1,17 +0,0 @@ -START TIME: Wed Apr 24 17:08:28 UTC 2024 -[2024-04-24 17:08:29,989] torch.distributed.run: [WARNING] -[2024-04-24 17:08:29,989] torch.distributed.run: [WARNING] ***************************************** -[2024-04-24 17:08:29,989] torch.distributed.run: [WARNING] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. -[2024-04-24 17:08:29,989] torch.distributed.run: [WARNING] ***************************************** -END TIME: Wed Apr 24 17:08:55 UTC 2024 - -[0:0]:6 -[0:0]:7 -[0:0]:8 -[0:0]:9 -[0:0]:10 -[0:0]:The average bandwidth of all_reduce with a 4.0GB payload (10 trials, 16 ranks): -[0:0]: algbw: 8.154 GBps (65.2 Gbps) -[0:0]: busbw: 15.289 GBps (122.3 Gbps) -[0:0]: -END TIME: Wed Apr 24 17:08:55 UTC 2024 diff --git a/utils/distributed_pytorch/benchmark/reports/Output_e1181af9-caec-40a0-8104-50eb12f6663d.txt b/utils/distributed_pytorch/benchmark/reports/Output_e1181af9-caec-40a0-8104-50eb12f6663d.txt deleted file mode 100644 index 04371a3..0000000 --- a/utils/distributed_pytorch/benchmark/reports/Output_e1181af9-caec-40a0-8104-50eb12f6663d.txt +++ /dev/null @@ -1,20 +0,0 @@ -START TIME: Fri Apr 26 12:06:03 UTC 2024 -[2024-04-26 12:06:04,716] torch.distributed.run: [WARNING] -[2024-04-26 12:06:04,716] torch.distributed.run: [WARNING] ***************************************** -[2024-04-26 12:06:04,716] torch.distributed.run: [WARNING] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. -[2024-04-26 12:06:04,716] torch.distributed.run: [WARNING] ***************************************** -[0:0]:1 -[0:0]:2 -[0:0]:3 -[0:0]:4 -[0:0]:5 -[0:0]:6 -[0:0]:7 -[0:0]:8 -[0:0]:9 -[0:0]:10 -[0:0]:The average bandwidth of all_reduce with a 4.0GB payload (10 trials, 32 ranks): -[0:0]: algbw: 4.914 GBps (39.3 Gbps) -[0:0]: busbw: 9.522 GBps (76.2 Gbps) -[0:0]: -END TIME: Fri Apr 26 12:06:47 UTC 2024 diff --git a/utils/distributed_pytorch/my_first_distributed_app/reports/Output_0daa767f-355c-4ab9-aa1c-094016e39a24.txt b/utils/distributed_pytorch/my_first_distributed_app/reports/Output_0daa767f-355c-4ab9-aa1c-094016e39a24.txt deleted file mode 100644 index 8a33bbe..0000000 --- a/utils/distributed_pytorch/my_first_distributed_app/reports/Output_0daa767f-355c-4ab9-aa1c-094016e39a24.txt +++ /dev/null @@ -1,9 +0,0 @@ -START TIME: Wed Apr 24 17:32:07 UTC 2024 -[2024-04-24 17:32:08,348] torch.distributed.run: [WARNING] -[2024-04-24 17:32:08,348] torch.distributed.run: [WARNING] ***************************************** -[2024-04-24 17:32:08,348] torch.distributed.run: [WARNING] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. -[2024-04-24 17:32:08,348] torch.distributed.run: [WARNING] ***************************************** -END TIME: Wed Apr 24 17:32:39 UTC 2024 -Total number of processes: 40 -AVG Time: 2.187171459197998 -END TIME: Wed Apr 24 17:32:39 UTC 2024 From e33820a42e1ad4d1f86ae13f2278f42d77158fc1 Mon Sep 17 00:00:00 2001 From: tj-solergibert Date: Mon, 29 Apr 2024 19:19:26 +0200 Subject: [PATCH 09/12] refractor runai command --- utils/distributed_pytorch/benchmark/RUNAI_run_benchmark.sh | 2 +- .../benchmark/RUNAI_run_benchmark_no_RDMA.sh | 2 +- .../my_first_distributed_app/RUNAI_run_app.sh | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/utils/distributed_pytorch/benchmark/RUNAI_run_benchmark.sh b/utils/distributed_pytorch/benchmark/RUNAI_run_benchmark.sh index 230974b..78e5bcd 100644 --- a/utils/distributed_pytorch/benchmark/RUNAI_run_benchmark.sh +++ b/utils/distributed_pytorch/benchmark/RUNAI_run_benchmark.sh @@ -1,4 +1,4 @@ -./runai-2.13.49 submit-dist pytorch \ +runai submit-dist pytorch \ --name all-reduce-bench \ --image registry.rcp.epfl.ch/meditron-ddx/basic:latest-solergib \ --workers 3 \ diff --git a/utils/distributed_pytorch/benchmark/RUNAI_run_benchmark_no_RDMA.sh b/utils/distributed_pytorch/benchmark/RUNAI_run_benchmark_no_RDMA.sh index 948aca2..08097de 100644 --- a/utils/distributed_pytorch/benchmark/RUNAI_run_benchmark_no_RDMA.sh +++ b/utils/distributed_pytorch/benchmark/RUNAI_run_benchmark_no_RDMA.sh @@ -1,4 +1,4 @@ -./runai-2.13.49 submit-dist pytorch \ +runai submit-dist pytorch \ --name all-reduce-bench \ --image registry.rcp.epfl.ch/meditron-ddx/basic:latest-solergib \ --workers 3 \ diff --git a/utils/distributed_pytorch/my_first_distributed_app/RUNAI_run_app.sh b/utils/distributed_pytorch/my_first_distributed_app/RUNAI_run_app.sh index db43a13..cee74ed 100644 --- a/utils/distributed_pytorch/my_first_distributed_app/RUNAI_run_app.sh +++ b/utils/distributed_pytorch/my_first_distributed_app/RUNAI_run_app.sh @@ -1,4 +1,4 @@ -./runai-2.13.49 submit-dist pytorch \ +runai submit-dist pytorch \ --name my_first_distributed_app \ --image registry.rcp.epfl.ch/meditron-ddx/basic:latest-solergib \ --workers 1 \ From 2d5de3ac6fc8832d1e12ccc94af8d0efb6d1caa4 Mon Sep 17 00:00:00 2001 From: tj-solergibert Date: Mon, 29 Apr 2024 19:33:00 +0200 Subject: [PATCH 10/12] Updated docs --- docs/multinode.md | 27 ++++++++++++++----- .../my_first_distributed_app.py | 2 +- 2 files changed, 22 insertions(+), 7 deletions(-) diff --git a/docs/multinode.md b/docs/multinode.md index 8f84a83..d0884c3 100644 --- a/docs/multinode.md +++ b/docs/multinode.md @@ -10,14 +10,26 @@ Jobs can be submitted either through RunAI as documented in RunAI's website (htt To execute jobs in RCP, we will use the RunAI CLI, more specifically the `submit-dist pytorch` function, which will be responsible for launching the specified command on each pod. There are two ways to execute distributed applications: 1. Interactive sessions. To force interactive sessions, we will have to launch the command `sleep infinity` on each pod. This way, we can connect to each pod, but we will have to manually execute the jobs on each one. This is useful for short sessions for debugging applications and checking that everything works correctly before launching a longer job. -> [!TIP] -> Keep in mind that as soon as you disconnect from the pod, you will lose the current job you are executing. + > [!TIP] + > Keep in mind that as soon as you disconnect from the pod, you will lose the current job you are executing unless you are using applications such as tmux or screen. 2. Batched execution. In this mode, we will specify to the `submit-dist` function to execute a script, and it will defer execution until the requested resources are available. This is the recommended way to launch longer jobs such as model training. To configure the number of nodes and GPUs, we will use the following flags of the `submit-dist` function: 1. `--workers`: The total number of nodes will be `n_workers` + 1, as RunAI adds a master node by default. 2. `--gpu`: The number of GPUs per node. Unless debugging applications, set this value as the number of GPUs per node. Otherwise, it would be possible to orchestrate 2 pods on the same node, which would not make sense. +As an example, the following command launches 4 pods, each with 4 GPUs: +``` +runai submit-dist pytorch \ + --name my_first_distributed_app \ + --image registry.rcp.epfl.ch/meditron-ddx/basic:latest-solergib \ + --workers 3 \ + --gpu 0 \ + --annotation k8s.v1.cni.cncf.io/networks=kube-system/roce \ + --extended-resource rdma/rdma=1 \ + -- "sleep infinity" +``` + RunAI handles scheduling the pods and also creates the necessary communication (rendezvous) backend (most likely c10d) between them. The following environment variables are set: * `WORLD_SIZE`: Number of pods (number of GPUs in each pod does not matter.) @@ -72,7 +84,7 @@ In [`/utils/distributed_pytorch/my_first_distributed_app/`](/utils/distributed_p To launch our first application, we will use the batched execution format from the `submit-dist pytorch` function. We will launch the job as follows to distribute the work across two nodes ([`/utils/distributed_pytorch/my_first_distributed_app/RUNAI_run_app.sh`](/utils/distributed_pytorch/my_first_distributed_app/RUNAI_run_app.sh)): ``` -./runai-2.13.49 submit-dist pytorch \ +runai submit-dist pytorch \ --name my_first_distributed_app \ --image registry.rcp.epfl.ch/meditron-ddx/basic:latest-solergib \ --workers 1 \ @@ -133,9 +145,12 @@ echo "END TIME: $(date)" Note the following: 1. At the beginning, we set both the network and environment configurations (Activate conda environment, set environment variables, etc.). -2. To launch the distributed applications, we will use `torchrun`. In short, `torchrun` spawns `--nproc-per-node` processes on each node by executing the specified script. Additionally, it also handles communications between nodes before launching the script. For this, it is necessary to specify `MASTER_ADDR` and `MASTER_PORT`, variables that are automatically defined by RunAI when using `submit-dist pytorch`. `--nodes` will be the number of pods launched (`WORLD_SIZE`), and we will use `--node-rank` to specify the rank of each node; otherwise, `torchrun` will assign a value to each `--node-rank`. In this example, for which we will not use GPUs, we will launch 20 processes on each of the two nodes, dividing the work among a total of 40 processes. -> [!WARNING] -> Do not confuse the variables `WORLD_SIZE` and `RANK` produced by RunAI with the `submit-dist function` with the same variables generated by `torchrun` when launching the scripts. In the case of RunAI, they are configured based on the **number of pods**, while in `torchrun`, they are configured based on the **number of spawned processes**, which is defined by `--nnodes` x `--nproc-per-node`. +2. To launch the distributed applications, we will use `torchrun`. In short, `torchrun` will spawn `--nproc-per-node` processes running the specified python script, setting for each process the `WORLD_SIZE` and `RANK` environment variables. Additionally, it also handles communications between nodes before launching the script. For this, it is necessary to specify `MASTER_ADDR` and `MASTER_PORT`, variables that are automatically defined by RunAI when using `submit-dist pytorch`. `--nodes` will be the number of pods launched (`WORLD_SIZE`), and we will use `--node-rank` to specify the rank of each node; otherwise, `torchrun` will assign a value to each `--node-rank`. In this example, for which we will not use GPUs, we will launch 20 processes on each of the two nodes, dividing the work among a total of 40 processes. + > [!TIP] + > For applications that use GPUs, this value should be equal to the **number of GPUs per node** to maintain the 1 process per GPU relationship. + + > [!WARNING] + > Do not confuse the variables `WORLD_SIZE` and `RANK` produced by RunAI with the `submit-dist function` with the same variables generated by `torchrun` when launching the scripts. In the case of RunAI, they are configured based on the **number of pods**, while in `torchrun`, they are configured based on the **number of spawned processes**, which is defined by `--nnodes` x `--nproc-per-node`. ## Inter-Node communication benchmark We conducted a benchmark to determine the bandwidth between nodes (In Gbps). As can be seen, the benefit of RDMA is significant, so it is advisable to ensure that it is enabled. diff --git a/utils/distributed_pytorch/my_first_distributed_app/my_first_distributed_app.py b/utils/distributed_pytorch/my_first_distributed_app/my_first_distributed_app.py index 68f0734..7d4f4d8 100644 --- a/utils/distributed_pytorch/my_first_distributed_app/my_first_distributed_app.py +++ b/utils/distributed_pytorch/my_first_distributed_app/my_first_distributed_app.py @@ -40,7 +40,7 @@ def timed_compute_pi(local_a, local_b, local_n, h, start_event, end_event): duration = start_event.elapsed_time(end_event) / 1000 duration = torch.tensor([duration]) - # Compute mean across all ranks + # Compute mean ONLY in rank 0 dist.reduce(duration, dst=0, op=dist.ReduceOp.SUM) return duration From c067b9ec3aa2f4d1008c58a464785e5e6ff496a0 Mon Sep 17 00:00:00 2001 From: tj-solergibert Date: Mon, 29 Apr 2024 19:52:20 +0200 Subject: [PATCH 11/12] added relative paths --- utils/distributed_pytorch/benchmark/RCP_run_benchmark.sh | 8 +++----- .../benchmark/RCP_run_benchmark_no_RDMA.sh | 8 +++----- .../my_first_distributed_app/RCP_run_app.sh | 4 +++- 3 files changed, 9 insertions(+), 11 deletions(-) diff --git a/utils/distributed_pytorch/benchmark/RCP_run_benchmark.sh b/utils/distributed_pytorch/benchmark/RCP_run_benchmark.sh index 10749ff..1764710 100644 --- a/utils/distributed_pytorch/benchmark/RCP_run_benchmark.sh +++ b/utils/distributed_pytorch/benchmark/RCP_run_benchmark.sh @@ -1,5 +1,7 @@ #!/bin/bash +cd $PATH_TO_ROOT_REPO + echo "START TIME: $(date)" export NCCL_IB_GID_INDEX=$(grep 'RoCE v2' $(grep '0000:0000:0000:0000:0000:ffff' /sys/class/infiniband/mlx5_bond_0/ports/1/gids/* | cut -d ':' -f 1 | sed 's/gids/gid_attrs\/types/') | sed -e 's/.*\/\([0-9]*\):.*/\1/') @@ -23,11 +25,7 @@ LAUNCHER="torchrun \ --tee 3 \ " -PYTHON_FILE=/mloscratch/homes/solergib/getting-started/utils/distributed_pytorch/benchmark/all_reduce_bench.py -PYTHON_ARGS=" \ - --batch_size 16 \ - --model Llama3 \ - " +PYTHON_FILE=utils/distributed_pytorch/benchmark/all_reduce_bench.py export CMD="$LAUNCHER $PYTHON_FILE $PYTHON_ARGS" bash -c "$CMD" diff --git a/utils/distributed_pytorch/benchmark/RCP_run_benchmark_no_RDMA.sh b/utils/distributed_pytorch/benchmark/RCP_run_benchmark_no_RDMA.sh index f8e3cd7..3b805a0 100644 --- a/utils/distributed_pytorch/benchmark/RCP_run_benchmark_no_RDMA.sh +++ b/utils/distributed_pytorch/benchmark/RCP_run_benchmark_no_RDMA.sh @@ -1,5 +1,7 @@ #!/bin/bash +cd $PATH_TO_ROOT_REPO + echo "START TIME: $(date)" # MASTER_ADDR -> The IP of the master node @@ -18,11 +20,7 @@ LAUNCHER="torchrun \ --tee 3 \ " -PYTHON_FILE=/mloscratch/homes/solergib/getting-started/utils/distributed_pytorch/benchmark/all_reduce_bench.py -PYTHON_ARGS=" \ - --batch_size 16 \ - --model Llama3 \ - " +PYTHON_FILE=utils/distributed_pytorch/benchmark/all_reduce_bench.py export CMD="$LAUNCHER $PYTHON_FILE $PYTHON_ARGS" bash -c "$CMD" diff --git a/utils/distributed_pytorch/my_first_distributed_app/RCP_run_app.sh b/utils/distributed_pytorch/my_first_distributed_app/RCP_run_app.sh index 3be35ad..b256204 100644 --- a/utils/distributed_pytorch/my_first_distributed_app/RCP_run_app.sh +++ b/utils/distributed_pytorch/my_first_distributed_app/RCP_run_app.sh @@ -1,5 +1,7 @@ #!/bin/bash +cd $PATH_TO_ROOT_REPO + echo "START TIME: $(date)" export NCCL_IB_GID_INDEX=$(grep 'RoCE v2' $(grep '0000:0000:0000:0000:0000:ffff' /sys/class/infiniband/mlx5_bond_0/ports/1/gids/* | cut -d ':' -f 1 | sed 's/gids/gid_attrs\/types/') | sed -e 's/.*\/\([0-9]*\):.*/\1/') @@ -23,7 +25,7 @@ LAUNCHER="torchrun \ --tee 3 \ " -PYTHON_FILE=/mloscratch/homes/solergib/getting-started/utils/distributed_pytorch/my_first_distributed_app/my_first_distributed_app.py +PYTHON_FILE=utils/distributed_pytorch/my_first_distributed_app/my_first_distributed_app.py export CMD="$LAUNCHER $PYTHON_FILE" bash -c "$CMD" From f712e38852e03c429c0c68dd57a49a5c0de3e355 Mon Sep 17 00:00:00 2001 From: tj-solergibert Date: Tue, 30 Apr 2024 00:55:12 +0200 Subject: [PATCH 12/12] doc typo --- docs/multinode.md | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/docs/multinode.md b/docs/multinode.md index d0884c3..156a747 100644 --- a/docs/multinode.md +++ b/docs/multinode.md @@ -10,8 +10,8 @@ Jobs can be submitted either through RunAI as documented in RunAI's website (htt To execute jobs in RCP, we will use the RunAI CLI, more specifically the `submit-dist pytorch` function, which will be responsible for launching the specified command on each pod. There are two ways to execute distributed applications: 1. Interactive sessions. To force interactive sessions, we will have to launch the command `sleep infinity` on each pod. This way, we can connect to each pod, but we will have to manually execute the jobs on each one. This is useful for short sessions for debugging applications and checking that everything works correctly before launching a longer job. - > [!TIP] - > Keep in mind that as soon as you disconnect from the pod, you will lose the current job you are executing unless you are using applications such as tmux or screen. +> [!TIP] +> Keep in mind that as soon as you disconnect from the pod, you will lose the current job you are executing unless you are using applications such as tmux or screen. 2. Batched execution. In this mode, we will specify to the `submit-dist` function to execute a script, and it will defer execution until the requested resources are available. This is the recommended way to launch longer jobs such as model training. To configure the number of nodes and GPUs, we will use the following flags of the `submit-dist` function: @@ -146,11 +146,11 @@ echo "END TIME: $(date)" Note the following: 1. At the beginning, we set both the network and environment configurations (Activate conda environment, set environment variables, etc.). 2. To launch the distributed applications, we will use `torchrun`. In short, `torchrun` will spawn `--nproc-per-node` processes running the specified python script, setting for each process the `WORLD_SIZE` and `RANK` environment variables. Additionally, it also handles communications between nodes before launching the script. For this, it is necessary to specify `MASTER_ADDR` and `MASTER_PORT`, variables that are automatically defined by RunAI when using `submit-dist pytorch`. `--nodes` will be the number of pods launched (`WORLD_SIZE`), and we will use `--node-rank` to specify the rank of each node; otherwise, `torchrun` will assign a value to each `--node-rank`. In this example, for which we will not use GPUs, we will launch 20 processes on each of the two nodes, dividing the work among a total of 40 processes. - > [!TIP] - > For applications that use GPUs, this value should be equal to the **number of GPUs per node** to maintain the 1 process per GPU relationship. +> [!TIP] +> For applications that use GPUs, this value should be equal to the **number of GPUs per node** to maintain the 1 process per GPU relationship. - > [!WARNING] - > Do not confuse the variables `WORLD_SIZE` and `RANK` produced by RunAI with the `submit-dist function` with the same variables generated by `torchrun` when launching the scripts. In the case of RunAI, they are configured based on the **number of pods**, while in `torchrun`, they are configured based on the **number of spawned processes**, which is defined by `--nnodes` x `--nproc-per-node`. +> [!WARNING] +> Do not confuse the variables `WORLD_SIZE` and `RANK` produced by RunAI with the `submit-dist function` with the same variables generated by `torchrun` when launching the scripts. In the case of RunAI, they are configured based on the **number of pods**, while in `torchrun`, they are configured based on the **number of spawned processes**, which is defined by `--nnodes` x `--nproc-per-node`. ## Inter-Node communication benchmark We conducted a benchmark to determine the bandwidth between nodes (In Gbps). As can be seen, the benefit of RDMA is significant, so it is advisable to ensure that it is enabled.