diff --git a/.github/workflows/cicd-main.yml b/.github/workflows/cicd-main.yml index ad59173d5..3fde40e87 100644 --- a/.github/workflows/cicd-main.yml +++ b/.github/workflows/cicd-main.yml @@ -57,16 +57,22 @@ jobs: uses: ./.github/workflows/_build_container.yml Unit_Tests: + name: ${{ matrix.test_case }} needs: [build-container, pre-flight] uses: ./.github/workflows/_run_test.yml if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'unit') || needs.pre-flight.outputs.all == 'true' + strategy: + matrix: + test_case: + - run_unit.sh + - run_mpi_unit.sh with: RUNNER: self-hosted-azure TIMEOUT: 10 SCRIPT: | nvidia-smi cd ${ALIGNER_REPO_DIR} - bash tests/run_unit.sh + bash tests/${{ matrix.test_case }} Functional_Tests: name: ${{ matrix.test_case }} @@ -76,15 +82,12 @@ jobs: strategy: matrix: test_case: - #- ppo-pp-llama3 + - ppo-llama3-pp2-reshard - dpo-llama3 + with: RUNNER: self-hosted-azure # Fairly aggresive timeout that all functional tests should try to adhere to - TIMEOUT: 10 + TIMEOUT: 8 SCRIPT: | - export PYTHONPATH=${ALIGNER_REPO_DIR}:${PYTHONPATH:-} - nvidia-smi - git config --global --add safe.directory ${ALIGNER_REPO_DIR} - cd ${ALIGNER_REPO_DIR} - bash tests/functional/test_cases/${{ matrix.test_case }}.sh + bash /opt/NeMo-Aligner/tests/functional/test_cases/${{ matrix.test_case }} diff --git a/CHANGELOG.md b/CHANGELOG.md index 564f0593d..29a752e3c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,16 +11,42 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/) ### Breaking Changes ### Bug Fixes + +### Deprecation Notices --> ## [Next Version] ### New Features and Optimizations - Added support for Megatron’s distributed optimizer, which can be configured using `++model.optim.name=mcore_distributed_optim`. +- Introduced `ScopedTimer` as a successor to `SyncedTimer`. `SyncedTimer` is marked for deprecation and will be removed in the next version. + ```python + from nemo_aligner.utils.distributed import ScopedTimer + timer = ScopedTimer() + + # All durations are logged in the timer + with timer("step_time"): + with timer("fwd"): + model.fwd() + with timer("bwd"): + model.bwd() + + # Consume all durations and reset internal store + durations = timer.consume_durations() + ``` ### Breaking Changes +- Upgrade TRTLLM dependency from v0.10.0 to v0.12.0 and migrate from `GPTSession` cpp runtime to `ModelRunner` python runtime. Please use the latest Dockerfile. +- Using latest TransformerEngine versions may require `++model.dist_ckpt_load_strictness=log_all` when loading from a older pre-existing checkpoint to not error out. +- NeMo-Aligner now requires Megatron-LM==0.9.0 for the APIs to calculate the microbatch sizes (API introduced `megatron.core.num_microbatches_calculator.reconfigure_num_microbatch_calculator`). +- NeMo-Aligner now requires a version of NeMo with this change to how the MoE spec is handled: https://github.com/NVIDIA/NeMo/pull/9035 . ### Bug Fixes +- It is now required, for stability, to add `export NCCL_ALGO=...` to scripts launching PPO training loop. Please see the [RLHF docs](./docs/user-guide/rlhf.rst) for information. + +### Deprecation Notices +- `SyncedTimer` is marked for deprecation and will be removed in `0.7.0`. Please switch to `ScopedTimer` +- `broadcast_2d_tensor` and `broadcast_2d_tensor_within_pp` is marked for deprecation and will be removed in `0.7.0`. Please switch to `broadcast_tensor` and `broadcast_tensor_within_pp`. ## NVIDIA NeMo-Aligner 0.5.0 @@ -32,6 +58,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/) ### Bug Fixes - Change `log_prob_forward_micro_batch_size` in DPO to mean the same as the `micro_batch_size`, which is how many samples(chosen and rejected included) that we process at once. +- PPO TensorRT-LLM acceleration now no longer errors if using a tokenizer without a `pad_id`. Examples being llama3 and llama3.1 tokenizers from huggingface. ## NVIDIA NeMo-Aligner 0.4.0 - Implement reward-aware preference optimization. @@ -51,7 +78,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/) ### Breaking Changes - `inference.micro_batch_size` is now renamed to `inference.inference_micro_batch_size` when running reward model inference in `inference_rm.yaml`. This is to stay consistent with the naming scheme of the PPO critic. - It is no longer possible to specify `add_EOS` when running reward model or critic inference. -- NeMo-Aligner now requires Megatron-LM>=0.8.0 for the APIs to calculate the microbatch sizes. +- NeMo-Aligner now requires Megatron-LM==0.8.0 for the APIs to calculate the microbatch sizes (API introduced `megatron.core.num_microbatches_calculator.reconfigure_microbatch_calculator`). ### Bug Fixes - Make `num_workers` for dataloaders 0 by default. This prevents issues when using MPI (with TRT-LLM) or more sophisticated launchers. diff --git a/Dockerfile b/Dockerfile index 9949d4636..c94cadc68 100644 --- a/Dockerfile +++ b/Dockerfile @@ -4,7 +4,7 @@ # # To update NeMo-Aligner from a pre-built NeMo-Framework container: # -# docker buildx build --target=aligner-bump --build-arg=BASE_IMAGE=nvcr.io/nvidia/nemo:24.07 -t aligner:latest . +# docker buildx build --target=aligner-bump -t aligner:latest . # # Number of parallel threads for compute heavy build jobs @@ -13,13 +13,12 @@ ARG MAX_JOBS=8 # Git refs for dependencies ARG TE_TAG=7d576ed25266a17a7b651f2c12e8498f67e0baea ARG PYTRITON_VERSION=0.5.10 -ARG NEMO_TAG=e033481e26e6ae32764d3e2b3f16afed00dc7218 # On: r2.0.0rc1 -ARG MLM_TAG=a3fe0c75df82218901fa2c3a7c9e389aa5f53182 # On: core_r0.8.0 +ARG NEMO_TAG=19668e5320a2e2af0199b6d5e0b841993be3a634 # On: main +ARG MLM_TAG=25059d3bbf68be0751800f3644731df12a88f3f3 # On: main ARG ALIGNER_COMMIT=main -ARG TRTLLM_VERSION=v0.10.0 +ARG TRTLLM_VERSION=v0.13.0 ARG PROTOBUF_VERSION=4.24.4 - -ARG BASE_IMAGE=nvcr.io/nvidia/pytorch:24.03-py3 +ARG BASE_IMAGE=nvcr.io/nvidia/pytorch:24.07-py3 FROM ${BASE_IMAGE} AS aligner-bump ARG ALIGNER_COMMIT @@ -36,13 +35,40 @@ git checkout -f $ALIGNER_COMMIT # case 2: ALIGNER_COMMIT is a commit, so git-pull is expected to fail git pull --rebase || true -pip install --no-deps -e . +pip install --no-cache-dir --no-deps -e . EOF FROM ${BASE_IMAGE} as final WORKDIR /opt # needed in case git complains that it can't detect a valid email, this email is fake but works RUN git config --global user.email "worker@nvidia.com" +# install latest apex +ARG APEX_TAG +RUN pip uninstall -y apex && \ + git clone https://github.com/NVIDIA/apex && \ + cd apex && \ + if [ ! -z $APEX_TAG ]; then \ + git fetch origin $APEX_TAG && \ + git checkout FETCH_HEAD; \ + fi && \ + pip install -v --no-build-isolation --disable-pip-version-check --no-cache-dir --config-settings "--build-option=--cpp_ext --cuda_ext --fast_layer_norm --distributed_adam --deprecated_fused_adam" ./ + +# Git LFS +RUN curl -s https://packagecloud.io/install/repositories/github/git-lfs/script.deb.sh | bash && \ + apt-get install git-lfs && \ + git lfs install && \ + apt-get clean + +# TRTLLM +ARG TRTLLM_VERSION +RUN git clone https://github.com/NVIDIA/TensorRT-LLM.git && \ + cd TensorRT-LLM && \ + git checkout ${TRTLLM_VERSION} && \ + . docker/common/install_tensorrt.sh && \ + python3 ./scripts/build_wheel.py --job_count $(nproc) --trt_root /usr/local/tensorrt --python_bindings --benchmarks && \ + pip install -e . +ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda-12/compat/lib.real/ + # install TransformerEngine ARG MAX_JOBS ARG TE_TAG @@ -56,17 +82,6 @@ RUN pip uninstall -y transformer-engine && \ git submodule init && git submodule update && \ NVTE_FRAMEWORK=pytorch NVTE_WITH_USERBUFFERS=1 MPI_HOME=/usr/local/mpi pip install . -# install latest apex -ARG APEX_TAG -RUN pip uninstall -y apex && \ - git clone https://github.com/NVIDIA/apex && \ - cd apex && \ - if [ ! -z $APEX_TAG ]; then \ - git fetch origin $APEX_TAG && \ - git checkout FETCH_HEAD; \ - fi && \ - pip install -v --no-build-isolation --disable-pip-version-check --no-cache-dir --config-settings "--build-option=--cpp_ext --cuda_ext --fast_layer_norm --distributed_adam --deprecated_fused_adam" ./ - # place any util pkgs here ARG PYTRITON_VERSION RUN pip install --upgrade-strategy only-if-needed nvidia-pytriton==$PYTRITON_VERSION @@ -99,29 +114,32 @@ RUN pip uninstall -y megatron-core && \ fi && \ pip install -e . -# Git LFS -RUN curl -s https://packagecloud.io/install/repositories/github/git-lfs/script.deb.sh | bash && \ - apt-get install git-lfs && \ - git lfs install - COPY --from=aligner-bump /opt/NeMo-Aligner /opt/NeMo-Aligner RUN cd /opt/NeMo-Aligner && \ pip install --no-deps -e . -# TRTLLM -ARG TRTLLM_VERSION -RUN git clone https://github.com/NVIDIA/TensorRT-LLM.git && \ - cd TensorRT-LLM && \ - git checkout ${TRTLLM_VERSION} && \ - patch -p1 < ../NeMo-Aligner/setup/trtllm.patch && \ - . docker/common/install_tensorrt.sh && \ - python3 ./scripts/build_wheel.py --trt_root /usr/local/tensorrt - -RUN cd TensorRT-LLM && \ - pip install ./build/tensorrt_llm*.whl -ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda-12/compat/lib.real/ +RUN cd TensorRT-LLM && patch -p1 < ../NeMo-Aligner/setup/trtllm.patch -# WAR(0.4.0): The pin of NeMo requires a higher nvidia-modelopt version than -# TRT-LLM allows. This installation must follow TRT-LLM and is -# only necessary when NeMo 2.0.0rc1 is installed with TRT-LLM v10. -RUN pip install --upgrade-strategy only-if-needed nvidia-modelopt==0.13.0 +# TODO(terryk): This layer should be deleted ASAP after NeMo is bumped to include all of these PRs +RUN <<"EOF" bash -exu +cd NeMo +# Ensures we don't cherry-pick "future" origin/main commits +git fetch -a +# 0c92fe17df4642ffc33d5d8c0c83fda729e3910c: [fix] Ensures disabling exp_manager with exp_manager=null does not error NeMo#10651 +# 60e677423667c029dd05875da72bf0719774f844: [feat] Update get_model_parallel_src_rank to support tp-pp-dp ordering NeMo#10652 +# 0deaf6716cb4f20766c995ce25d129795f1ae200: fix[export]: update API for disabling device reassignment in TRTLLM for Aligner NeMo#10863 +# (superceded by 10863) 148543d6e9c66ff1f8562e84484448202249811d: feat: Migrate GPTSession refit path in Nemo export to ModelRunner for Aligner NeMo#10654 +for pr_and_commit in \ + "10651 0c92fe17df4642ffc33d5d8c0c83fda729e3910c" \ + "10652 60e677423667c029dd05875da72bf0719774f844" \ + "10863 0deaf6716cb4f20766c995ce25d129795f1ae200" \ +; do + pr=$(cut -f1 -d' ' <<<"$pr_and_commit") + head_pr_commit=$(cut -f2 -d' ' <<<"$pr_and_commit") + git fetch origin $head_pr_commit:PR-${pr} + # cherry-picks all commits between main and the top of the PR + git cherry-pick --allow-empty $(git merge-base origin/main PR-${pr})..PR-${pr} + # Tag cherry-picks to help + git tag cherry-pick-PR-${pr} +done +EOF diff --git a/docs/user-guide/dpo.rst b/docs/user-guide/dpo.rst index d5c39a814..a5d696548 100644 --- a/docs/user-guide/dpo.rst +++ b/docs/user-guide/dpo.rst @@ -184,7 +184,7 @@ For the following parameters, the ``model.dpo.ref_policy_kl_penalty`` correspond ++model.dpo.ref_policy_kl_penalty=0.1 EOF - srun -o $OUTFILE -e $ERRFILE --container-image=$CONTAINER $MOUNTS bash -c "${cmd}" + srun --no-container-mount-home -o $OUTFILE -e $ERRFILE --container-image=$CONTAINER $MOUNTS bash -c "${cmd}" set +x The default DPO training tunes all parameters. To use LoRA, we can set ``model.peft.peft_scheme=lora`` and use different parameters in ``model.peft.lora_tuning``. Please check the parameters in `the config file `__. diff --git a/docs/user-guide/draftp.rst b/docs/user-guide/draftp.rst index d4e2fc9fc..165eae33a 100644 --- a/docs/user-guide/draftp.rst +++ b/docs/user-guide/draftp.rst @@ -164,7 +164,7 @@ To start reward model training, you need checkpoints for both the `UNet > -p <> --job-name <> -t 4:00:00 --exclusive + # To ensure determinism when calculating log probabilities between two forward-passes with identical weights, it is strongly + # recommended to set NCCL_ALGO. See https://github.com/NVIDIA/Megatron-LM/blob/b3375a0e38c10e2300ef4be031f7dcabab52b448/megatron/training/arguments.py#L593-L595 + # for options. + export NCCL_ALGO=Tree + NAME="2p_ppo" # PARAMETERS @@ -305,7 +310,7 @@ You can use Slurm to launch both jobs and coordinate them together in a full RLH pretrained_checkpoint.restore_from_path=${RM_NEMO_FILE} EOF - srun --het-group=0 -o $CRITIC_OUTFILE -e $CRITIC_ERRFILE --container-image=${CONTAINER} $MOUNTS bash -c "${cmd_critic_inference}" & + srun --no-container-mount-home --het-group=0 -o $CRITIC_OUTFILE -e $CRITIC_ERRFILE --container-image=${CONTAINER} $MOUNTS bash -c "${cmd_critic_inference}" & sleep 30 @@ -356,7 +361,7 @@ You can use Slurm to launch both jobs and coordinate them together in a full RLH remote_critic_rm.critic.port=${CRITIC_PORT} EOF - srun --het-group=1 -o $PPO_OUTFILE -e $PPO_ERRFILE --container-image=${CONTAINER} $MOUNTS bash -c "${cmd_ppo}" & + srun --no-container-mount-home --het-group=1 -o $PPO_OUTFILE -e $PPO_ERRFILE --container-image=${CONTAINER} $MOUNTS bash -c "${cmd_ppo}" & wait diff --git a/docs/user-guide/rs.rst b/docs/user-guide/rs.rst index 44862eeaa..0b2c8d079 100644 --- a/docs/user-guide/rs.rst +++ b/docs/user-guide/rs.rst @@ -160,7 +160,7 @@ You can use Slurm to launch the two jobs and get them to coordinate together in inference.port=${CRITIC_PORT} EOF - srun --het-group=0 -o $CRITIC_OUTFILE -e $CRITIC_ERRFILE --container-image=${CONTAINER} $MOUNTS bash -c "${cmd_critic_inference}" & + srun --no-container-mount-home --het-group=0 -o $CRITIC_OUTFILE -e $CRITIC_ERRFILE --container-image=${CONTAINER} $MOUNTS bash -c "${cmd_critic_inference}" & sleep 30 @@ -216,7 +216,7 @@ You can use Slurm to launch the two jobs and get them to coordinate together in model.rs.top_n_rollouts=1 EOF - srun --het-group=1 -o $PPO_OUTFILE -e $PPO_ERRFILE --container-image=${CONTAINER} $MOUNTS bash -c "${cmd_rs}" & + srun --no-container-mount-home --het-group=1 -o $PPO_OUTFILE -e $PPO_ERRFILE --container-image=${CONTAINER} $MOUNTS bash -c "${cmd_rs}" & wait diff --git a/docs/user-guide/sft.rst b/docs/user-guide/sft.rst index d6beed8d6..335710aed 100644 --- a/docs/user-guide/sft.rst +++ b/docs/user-guide/sft.rst @@ -227,7 +227,7 @@ Now, you will use the data for supervised fine-tuning with NeMo-Aligner. exp_manager.checkpoint_callback_params.monitor=val_loss EOF - srun -o $OUTFILE -e $ERRFILE --container-image=$CONTAINER $MOUNTS bash -c "${cmd}" + srun --no-container-mount-home -o $OUTFILE -e $ERRFILE --container-image=$CONTAINER $MOUNTS bash -c "${cmd}" set +x If using sequence packing, replace the data paths with the paths to your packed datasets. For each packed dataset, you should also set ``packed_sequence=True`` in the config: @@ -391,7 +391,7 @@ Now, you will use the data for supervised fine-tuning with NeMo-Aligner. Compare exp_manager.checkpoint_callback_params.monitor=validation_loss EOF - srun -o $OUTFILE -e $ERRFILE --container-image=$CONTAINER $MOUNTS bash -c "${cmd}" + srun --no-container-mount-home -o $OUTFILE -e $ERRFILE --container-image=$CONTAINER $MOUNTS bash -c "${cmd}" set +x diff --git a/docs/user-guide/spin.rst b/docs/user-guide/spin.rst index 67b9f8089..613886e53 100644 --- a/docs/user-guide/spin.rst +++ b/docs/user-guide/spin.rst @@ -165,7 +165,7 @@ For the following parameters, the ``model.spin.ref_policy_kl_penalty`` correspon model.data.train_ds.max_seq_length=4096 EOF - srun -o $OUTFILE -e $ERRFILE --container-image=$CONTAINER $MOUNTS bash -c "${cmd}" + srun --no-container-mount-home -o $OUTFILE -e $ERRFILE --container-image=$CONTAINER $MOUNTS bash -c "${cmd}" set +x During SPIN training, several metrics will be recorded to WandB for you to monitor, chiefly acc (representing the percentage by which the model's implicit reward for the ground truth response exceeds that of the response generated by the reference policy). diff --git a/examples/mm/stable_diffusion/train_sd_draftp.py b/examples/mm/stable_diffusion/train_sd_draftp.py index bf732a393..367332842 100644 --- a/examples/mm/stable_diffusion/train_sd_draftp.py +++ b/examples/mm/stable_diffusion/train_sd_draftp.py @@ -120,7 +120,7 @@ def main(cfg) -> None: ptl_model.reward_model = reward_model ckpt_callback = add_custom_checkpoint_callback(trainer, ptl_model) - timer = Timer(cfg.exp_manager.get("max_time_per_run", "0:12:00:00")) + timer = Timer(cfg.exp_manager.get("max_time_per_run") if cfg.exp_manager else None) draft_p_trainer = SupervisedTrainer( cfg=cfg.trainer.draftp_sd, diff --git a/examples/mm/stable_diffusion/train_sdxl_draftp.py b/examples/mm/stable_diffusion/train_sdxl_draftp.py index e4a36e1e7..d75177fbb 100644 --- a/examples/mm/stable_diffusion/train_sdxl_draftp.py +++ b/examples/mm/stable_diffusion/train_sdxl_draftp.py @@ -243,7 +243,7 @@ def checkpoint_check_fn(module): torch.distributed.barrier() ckpt_callback = add_custom_checkpoint_callback(trainer, ptl_model) - timer = Timer(cfg.exp_manager.get("max_time_per_run", "0:24:00:00")) + timer = Timer(cfg.exp_manager.get("max_time_per_run") if cfg.exp_manager else None) draft_p_trainer = SupervisedTrainer( cfg=cfg.trainer.draftp_sd, diff --git a/examples/nlp/gpt/conf/gpt_ppo_actor.yaml b/examples/nlp/gpt/conf/gpt_ppo_actor.yaml index efe836e96..e0a5a1045 100644 --- a/examples/nlp/gpt/conf/gpt_ppo_actor.yaml +++ b/examples/nlp/gpt/conf/gpt_ppo_actor.yaml @@ -45,10 +45,8 @@ trainer: enable: False reshard: False # if True then reshard the model into TP only for inference - # TRTLLM preallocates activation memory according to the number of input tokens # By default, assume the max input length is the difference between the model sequence length and the max number of tokens to generate max_input_len: ${subtract:${model.encoder_seq_length}, ${model.ppo.length_params.max_length}} - max_input_tokens: ${multiply:${.max_input_len}, ${model.ppo.rollout_micro_batch_size}} # the seed to use for trt-llm generation seed: ${model.seed} @@ -73,6 +71,9 @@ trainer: model_gbs: ${model.global_batch_size} model_mbs: ${model.micro_batch_size} + # Default set to an ephemeral location within the container + trt_model_dir: /tmp/trt_llm_model + # no need to change these logger: False # logger provided by exp_manager enable_checkpointing: False @@ -130,6 +131,8 @@ pretrained_checkpoint: restore_from_path: null model: + # TODO: document perf implications + # use_tp_pp_dp_mapping: true ppo: # training generation mbs @@ -174,7 +177,6 @@ model: trt_llm: ${trainer.ppo.trt_llm} - #peft peft: peft_scheme: "none" # ["lora", "none"] restore_from_path: null diff --git a/examples/nlp/gpt/train_gpt_dpo.py b/examples/nlp/gpt/train_gpt_dpo.py index aefa0c5ac..49784d485 100644 --- a/examples/nlp/gpt/train_gpt_dpo.py +++ b/examples/nlp/gpt/train_gpt_dpo.py @@ -138,7 +138,7 @@ def main(cfg) -> None: logger.log_hyperparams(OmegaConf.to_container(cfg)) - timer = Timer(cfg.exp_manager.get("max_time_per_run")) + timer = Timer(cfg.exp_manager.get("max_time_per_run") if cfg.exp_manager else None) dpo_trainer = DPOTrainer( cfg=cfg.trainer.dpo, model=ptl_model, diff --git a/examples/nlp/gpt/train_gpt_ppo_actor.py b/examples/nlp/gpt/train_gpt_ppo_actor.py index d0bf49aa8..2ad5c72fc 100644 --- a/examples/nlp/gpt/train_gpt_ppo_actor.py +++ b/examples/nlp/gpt/train_gpt_ppo_actor.py @@ -159,7 +159,7 @@ def main(cfg) -> None: logger.log_hyperparams(OmegaConf.to_container(cfg)) rm_critic = RemoteGPTRMCriticClient(cfg.remote_critic_rm) - timer = Timer(cfg.exp_manager.get("max_time_per_run")) + timer = Timer(cfg.exp_manager.get("max_time_per_run") if cfg.exp_manager else None) batch_iterator_cfg = cfg.trainer.ppo.get("batch_iterator", {}) batch_iterator_cls = get_batch_iterator_cls(batch_iterator_cfg) diff --git a/examples/nlp/gpt/train_gpt_sft.py b/examples/nlp/gpt/train_gpt_sft.py index dd4c48ba4..371c0f5aa 100644 --- a/examples/nlp/gpt/train_gpt_sft.py +++ b/examples/nlp/gpt/train_gpt_sft.py @@ -221,7 +221,7 @@ def main(cfg) -> None: ckpt_callback = add_custom_checkpoint_callback(trainer, ptl_model) logger.log_hyperparams(OmegaConf.to_container(cfg)) - timer = Timer(cfg.exp_manager.get("max_time_per_run")) + timer = Timer(cfg.exp_manager.get("max_time_per_run") if cfg.exp_manager else None) sft_trainer = SupervisedTrainer( cfg=cfg.trainer.sft, diff --git a/examples/nlp/gpt/train_gpt_spin.py b/examples/nlp/gpt/train_gpt_spin.py index c33bf96f5..af95aaf2f 100644 --- a/examples/nlp/gpt/train_gpt_spin.py +++ b/examples/nlp/gpt/train_gpt_spin.py @@ -164,7 +164,7 @@ def main(cfg) -> None: ckpt_callback = add_custom_checkpoint_callback(trainer, ptl_model) logger.log_hyperparams(OmegaConf.to_container(cfg)) - timer = Timer(cfg.exp_manager.get("max_time_per_run")) + timer = Timer(cfg.exp_manager.get("max_time_per_run") if cfg.exp_manager else None) spin_trainer = SPINTrainer( cfg=cfg.trainer.spin, diff --git a/examples/nlp/gpt/train_reward_model.py b/examples/nlp/gpt/train_reward_model.py index 65af15603..35b237403 100644 --- a/examples/nlp/gpt/train_reward_model.py +++ b/examples/nlp/gpt/train_reward_model.py @@ -132,7 +132,7 @@ def main(cfg) -> None: logger.log_hyperparams(OmegaConf.to_container(cfg)) - timer = Timer(cfg.exp_manager.get("max_time_per_run")) + timer = Timer(cfg.exp_manager.get("max_time_per_run") if cfg.exp_manager else None) rm_trainer = SupervisedTrainer( cfg=cfg.trainer.rm, diff --git a/examples/nlp/gpt/train_steerlm2.py b/examples/nlp/gpt/train_steerlm2.py index 62a01b3f4..d52dea8cf 100644 --- a/examples/nlp/gpt/train_steerlm2.py +++ b/examples/nlp/gpt/train_steerlm2.py @@ -259,7 +259,7 @@ def main(cfg) -> None: ckpt_callback = add_custom_checkpoint_callback(trainer, ptl_model) logger.log_hyperparams(OmegaConf.to_container(cfg)) - timer = Timer(cfg.exp_manager.get("max_time_per_run")) + timer = Timer(cfg.exp_manager.get("max_time_per_run") if cfg.exp_manager else None) sft_trainer = SupervisedTrainer( cfg=cfg.trainer.sft, diff --git a/nemo_aligner/__init__.py b/nemo_aligner/__init__.py index 8b54b34cb..4df0c0dc8 100644 --- a/nemo_aligner/__init__.py +++ b/nemo_aligner/__init__.py @@ -12,6 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os +from nemo.utils import logging + +os.environ["DISABLE_TORCH_DEVICE_SET"] = "1" +logging.info( + f"Importing NeMo-Aligner sets DISABLE_TORCH_DEVICE_SET=1 to disable device reassignment within TensorRT-LLM" +) + + from nemo_aligner.package_info import ( __contact_emails__, __contact_names__, diff --git a/nemo_aligner/algorithms/ppo.py b/nemo_aligner/algorithms/ppo.py index afbfd2180..323c18224 100644 --- a/nemo_aligner/algorithms/ppo.py +++ b/nemo_aligner/algorithms/ppo.py @@ -12,10 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -import itertools from collections import UserDict from contextlib import nullcontext -from typing import Dict, List, Optional, Union +from typing import Dict, List, Optional import pandas as pd import torch @@ -28,9 +27,10 @@ from nemo.collections.nlp.data.language_modeling.megatron.data_samplers import MegatronPretrainingRandomSampler from nemo.collections.nlp.modules.common.megatron.utils import get_iterator_k_split from nemo.utils import logging +from nemo_aligner.models.nlp.gpt.megatron_gpt_ppo_actor import MegatronGPTActorModel from nemo_aligner.utils import parallel_state from nemo_aligner.utils.distributed import ( - SyncTimer, + ScopedTimer, all_reduce_dict, masked_global_mean_var, normalize_tensor, @@ -134,7 +134,7 @@ class PPOTrainer: def __init__( self, cfg: DictConfig, - model, + model: MegatronGPTActorModel, optimizer, scheduler, train_dataloader_builder, @@ -194,9 +194,7 @@ def __init__( self.train_df = pd.DataFrame(columns=["step", "prompt", "response", "reward"]) self.val_df = pd.DataFrame(columns=["step", "prompt", "response", "reward"]) - self.timer = SyncTimer( - reduction="mean", sync_cuda=True, buffer_size=1, reduce_op=torch.distributed.ReduceOp.MAX - ) + self.timer = ScopedTimer(reduction="mean", sync_cuda=True, buffer_size=1) def generate_ppo_data(self, rollout_batch): """generate ppo specific data for training @@ -285,7 +283,6 @@ def _run_inference(self, dataloader_builder, consumed_samples, is_validation): reshard_context = trt_llm_reshard_region if self.trtllm_reshard else nullcontext rollout_batches, futures = [], [] - timer_metrics = {} with reshard_context(): # dataloader must be built within the reshard context because it uses DP rank and size @@ -296,20 +293,17 @@ def _run_inference(self, dataloader_builder, consumed_samples, is_validation): # so the DP groups are correct num_microbatches = compute_num_rollout_microbatches(dataloader) - self.timer.start("batch_iterator_init") - batch_iterator = self.batch_iterator_cls( - sampler_iter, num_microbatches, dataloader.dataset, self.collate_fn - ) - timer_metrics["batch_iterator_init"] = self.timer.stop_and_get_time("batch_iterator_init") - - self.timer.start("generate") - for batch in batch_iterator: - rollout_batch = self.model.infer(batch) - rollout_batches.append(rollout_batch) + with self.timer("batch_iterator_init"): + batch_iterator = self.batch_iterator_cls( + sampler_iter, num_microbatches, dataloader.dataset, self.collate_fn + ) - futures.append(self.rm_critic.infer_rm_critic(rollout_batch)) + with self.timer("generate"): + for batch in batch_iterator: + rollout_batch = self.model.infer(batch) + rollout_batches.append(rollout_batch) - timer_metrics["generate"] = self.timer.stop_and_get_time("generate") + futures.append(self.rm_critic.infer_rm_critic(rollout_batch)) unbalanced_local_batch = PPORolloutBatch.from_rollout_batches( rollout_batches, @@ -329,26 +323,23 @@ def _run_inference(self, dataloader_builder, consumed_samples, is_validation): # since we compute the logprobs in nemo we need to disable the resharding batched_response_tokens = balanced_local_batch["response_tokens"] - self.timer.start("logprobs") - rollout_logprobs = self.model.get_inference_log_probs(batched_response_tokens) - balanced_local_batch["logprobs"] = rollout_logprobs - timer_metrics["logprobs"] = self.timer.stop_and_get_time("logprobs") + with self.timer("logprobs"): + rollout_logprobs = self.model.get_inference_log_probs(batched_response_tokens) + balanced_local_batch["logprobs"] = rollout_logprobs compute_init_policy_kl = not is_validation and self.compute_init_policy_kl if compute_init_policy_kl: - self.timer.start("init_logprobs") - rollout_init_logprobs = self.model.get_init_policy_logprobs(batched_response_tokens) - balanced_local_batch["init_logprobs"] = rollout_init_logprobs - timer_metrics["init_logprobs"] = self.timer.stop_and_get_time("init_logprobs") + with self.timer("init_logprobs"): + rollout_init_logprobs = self.model.get_init_policy_logprobs(batched_response_tokens) + balanced_local_batch["init_logprobs"] = rollout_init_logprobs # we send the request in sharded context, so we need to keep this sharding and then undo it with reshard_context(): - self.timer.start("critic_wait") - rm_value_rollout_batches = [] - for future in futures: - rewards, values = future.result() if isinstance(future, FutureResult) else future - rm_value_rollout_batches.append({"rewards": rewards, "values": values}) - timer_metrics["critic_wait"] = self.timer.stop_and_get_time("critic_wait") + with self.timer("critic_wait"): + rm_value_rollout_batches = [] + for future in futures: + rewards, values = future.result() if isinstance(future, FutureResult) else future + rm_value_rollout_batches.append({"rewards": rewards, "values": values}) unbalanced_rm_value_batch = PPORolloutBatch.from_rollout_batches( rm_value_rollout_batches, @@ -369,7 +360,7 @@ def _run_inference(self, dataloader_builder, consumed_samples, is_validation): global_rollout_batch.update(global_rm_value_batch) - return balanced_local_batch, cpu_dict(self.compute_rollout_metrics(global_rollout_batch)), timer_metrics + return balanced_local_batch, cpu_dict(self.compute_rollout_metrics(global_rollout_batch)) def compute_rollout_metrics(self, rollout_batch): table = {} @@ -406,22 +397,18 @@ def compute_rollout_metrics(self, rollout_batch): def run_validation(self): self.model.prepare_for_inference() - _, rollout_metrics, _ = self._run_inference( - self.val_dataloader_builder, consumed_samples=0, is_validation=True - ) + _, rollout_metrics = self._run_inference(self.val_dataloader_builder, consumed_samples=0, is_validation=True) self.model.finish_inference() return rollout_metrics @torch.no_grad() def generate_rollouts(self): - timing_metrics = {} + with self.timer("prepare_for_inference"): + # Timing includes build if first step and refit if step > 1 + self.model.prepare_for_inference() - self.timer.start("prepare_for_inference") - self.model.prepare_for_inference() - timing_metrics["prepare_for_inference"] = self.timer.stop_and_get_time("prepare_for_inference") - - rollout_batch, rollout_metrics, timer_metrics = self._run_inference( + rollout_batch, rollout_metrics = self._run_inference( self.train_dataloader_builder, consumed_samples=self.consumed_samples, is_validation=False ) @@ -429,23 +416,20 @@ def generate_rollouts(self): ppo_rollout_data, ppo_rollout_metrics = self.generate_ppo_data(rollout_batch) - self.timer.start("finish_inference") - self.model.finish_inference() - timing_metrics["finish_inference"] = self.timer.stop_and_get_time("finish_inference") - - timing_metrics.update(timer_metrics) + with self.timer("finish_inference"): + # Timing includes engine unloading if enabled + self.model.finish_inference() return ( ppo_rollout_data, rollout_metrics | ppo_rollout_metrics | {"consumed_samples": self.consumed_samples}, - timing_metrics, + self.timer.consume_durations(), ) def run_training(self, dataloader_iter): self.model.prepare_for_training() for batch in dataloader_iter: - self.timer.start("train_step_time") self.optimizer.zero_grad() self.model.prepare_for_training_step() @@ -466,7 +450,6 @@ def run_training(self, dataloader_iter): metrics["lr"] = lr metrics.update({"loss": loss_mean, "optim_step": self.ppo_optimization_step}) - metrics["train_step_time"] = self.timer.stop_and_get_time("train_step_time") self.logger.log_metrics( metrics, step=self.step, prefix="train_optim/", @@ -510,17 +493,18 @@ def fit(self): critic_train_loop_amount = self.critic_warmup_steps + 1 if self.step == 0 else 1 for _ in range(critic_train_loop_amount): - self.timer.start("rollout_time") clear_memory() - ppo_rollout_data, metrics, timer_metrics = self.generate_rollouts() - timing_metrics["rollout_time"] = self.timer.stop_and_get_time("rollout_time") + with self.timer("rollout_time"): + ppo_rollout_data, metrics, rollout_timer_metrics = self.generate_rollouts() + # Consume rollout_time + timing_metrics.update(self.timer.consume_durations()) # send critic train clear_memory() self.rm_critic.train(ppo_rollout_data) - timer_metrics = all_reduce_dict(timer_metrics, op=torch.distributed.ReduceOp.MAX) - timing_metrics.update(timer_metrics) + rollout_timer_metrics = all_reduce_dict(rollout_timer_metrics, op=torch.distributed.ReduceOp.MAX) + timing_metrics.update(rollout_timer_metrics) # logging table_metrics = metrics.pop("table") @@ -550,11 +534,12 @@ def fit(self): ) # start training clear_memory() - self.timer.start("train_time") - self.run_training(rollout_dataloader_iter) - timing_metrics["train_time"] = self.timer.stop_and_get_time("train_time") + with self.timer("train_time"): + self.run_training(rollout_dataloader_iter) - self.logger.log_metrics(timing_metrics, step=self.step, prefix="timers/") + self.logger.log_metrics( + timing_metrics | self.timer.consume_durations(), step=self.step, prefix="timers/" + ) self.step += 1 @@ -569,9 +554,10 @@ def fit(self): ) if run_val: - self.timer.start("validation_time") - val_metrics = self.run_validation() - timing_metrics["validation_time"] = self.timer.stop_and_get_time("validation_time") + with self.timer("validation_time"): + val_metrics = self.run_validation() + # Note: validation_time is logged one step behind (val step 5 means we've completed step 4) + timing_metrics.update(self.timer.consume_durations()) val_table_metrics = val_metrics.pop("table") diff --git a/nemo_aligner/models/nlp/gpt/megatron_gpt_ppo_actor.py b/nemo_aligner/models/nlp/gpt/megatron_gpt_ppo_actor.py index 13bc9196b..275e02e82 100644 --- a/nemo_aligner/models/nlp/gpt/megatron_gpt_ppo_actor.py +++ b/nemo_aligner/models/nlp/gpt/megatron_gpt_ppo_actor.py @@ -85,7 +85,6 @@ def __init__(self, cfg: DictConfig, trainer: Trainer): model_cfg=self.cfg, max_generation_length=self.cfg.ppo.length_params.get("max_length", 1024), max_input_len=self.cfg.ppo.trt_llm.get("max_input_len", 1024), - max_input_tokens=self.cfg.ppo.trt_llm.get("max_input_tokens", 4096), generation_batch_size=self.cfg.ppo.get("rollout_micro_batch_size", 4), unload_engine_train=self.cfg.ppo.trt_llm.get("unload_engine_train", False), trt_model_type=self.cfg.ppo.trt_llm.get("model_type", "llama"), @@ -98,6 +97,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer): use_greedy=self.cfg.ppo.sampling_params.get("use_greedy", False), tokenizer=self.tokenizer, seed=self.cfg.ppo.trt_llm.get("seed", self.cfg.seed), + trt_model_dir=self.cfg.ppo.get("trt_model_dir", "/tmp/trt_llm_model"), ) # training calls diff --git a/nemo_aligner/utils/distributed.py b/nemo_aligner/utils/distributed.py index 3b93ea8a6..b4e5e64e8 100644 --- a/nemo_aligner/utils/distributed.py +++ b/nemo_aligner/utils/distributed.py @@ -14,6 +14,7 @@ """distributed utils for communicating between different ranks""" +import functools import time import warnings from collections import defaultdict @@ -23,12 +24,14 @@ from typing import Optional import torch +import torch.distributed from megatron.core import tensor_parallel from nemo.utils.timers import NamedTimer from nemo_aligner.utils import parallel_state from nemo_aligner.utils.parallel_state import get_model_parallel_group, get_model_parallel_src_rank from nemo_aligner.utils.ppo_utils import calculate_entropy +from nemo_aligner.utils.utils import deprecated_in_version def rebalance_nd_tensor(tensor, group): @@ -55,6 +58,7 @@ def rebalance_nd_tensor(tensor, group): return output_tensor +@deprecated_in_version("0.7.0", "Please use broadcast_tensor(tensor, src, group, dtype)") def broadcast_2d_tensor(tensor, src, group, dtype=torch.float32): """Broadcast any 2d tensor from the src rank to every other rank in the given group. All the ranks that send or receive data must call this function.""" @@ -82,6 +86,40 @@ def broadcast_2d_tensor(tensor, src, group, dtype=torch.float32): return tensor +def broadcast_tensor(tensor: torch.Tensor | None, src, group, dtype: torch.dtype | None = None): + """ + Broadcast a tensor from the source rank to every other rank in the given group. + All the ranks that send or receive data must call this function. + + Parameters: + - tensor: The tensor to be broadcasted (or None for non source ranks). + - src: The rank of the source tensor. + - group: The process group to use for the broadcast. + - dtype: (Optional) The desired data type to cast the tensor before broadcasting. + + Returns: + - The broadcasted tensor. + """ + + if torch.distributed.get_rank() == src: + tensor = tensor.cuda() + if dtype: + tensor = tensor.to(dtype) + + metadata = [tensor.dtype, tensor.shape] + + torch.distributed.broadcast_object_list(metadata, src, group) + torch.distributed.broadcast(tensor, src, group) + else: + metadata = [None, None] + torch.distributed.broadcast_object_list(metadata, src, group) + + dtype, input_shape = metadata + tensor = torch.empty(input_shape, dtype=dtype, device="cuda") + torch.distributed.broadcast(tensor, src, group) + return tensor + + def broadcast_2d_tensor_within_mp(tensor, dtype=torch.float32): """helper function to broadcast within the model parallel group """ @@ -93,11 +131,36 @@ def broadcast_2d_tensor_within_mp(tensor, dtype=torch.float32): return tensor -def broadcast_2d_tensor_within_pp(tensor, dtype=torch.float32): +@deprecated_in_version("0.7.0", "Please use broadcast_tensor_within_pp(tensor, dtype)") +def broadcast_2d_tensor_within_pp(tensor, dtype=torch.float32, from_last: bool = True): + """ + from_last: True=broadcast from the last PP rank and False=broadcast from first PP rank (default=True) + """ if parallel_state.get_pipeline_model_parallel_world_size() > 1: return broadcast_2d_tensor( tensor, - parallel_state.get_pipeline_model_parallel_last_rank(), + parallel_state.get_pipeline_model_parallel_last_rank() + if from_last + else parallel_state.get_pipeline_model_parallel_first_rank(), + parallel_state.get_pipeline_model_parallel_group(), + dtype=dtype, + ) + + return tensor + + +def broadcast_tensor_within_pp(tensor: torch.Tensor | None, dtype: torch.dtype = None, from_last: bool = True): + """ + tensor: Should be a valid tensor on src rank and None elsewhere + dtype: no dtype means that the dtype is inferred + from_last: True=broadcast from the last PP rank and False=broadcast from first PP rank (default=True) + """ + if parallel_state.get_pipeline_model_parallel_world_size() > 1: + return broadcast_tensor( + tensor, + parallel_state.get_pipeline_model_parallel_last_rank() + if from_last + else parallel_state.get_pipeline_model_parallel_first_rank(), parallel_state.get_pipeline_model_parallel_group(), dtype=dtype, ) @@ -304,22 +367,27 @@ def all_reduce_dict(dictionary, dtype=torch.float32, group=None, op=torch.distri return dict(zip(keys, tensor.tolist())) +@deprecated_in_version("0.7.0", "Consider using ScopedTimer") class SyncTimer(NamedTimer): """Wrapper around NamedTimer to sync across DP ranks for more precise timing """ def __init__(self, *args, **kwargs): + # TODO: double check can delete self.reduce_op = kwargs.pop("reduce_op", torch.distributed.ReduceOp.MAX) super().__init__(*args, **kwargs) + # TODO: double check can delete self.stored_results = defaultdict(list) + # TODO: double check can delete def sync_time(self, list_to_sync): output_tensor = torch.tensor(list_to_sync, dtype=torch.float32, device="cuda") torch.distributed.all_reduce(output_tensor, op=self.reduce_op, group=parallel_state.get_data_parallel_group()) return output_tensor + # TODO: double check can delete def get_synced(self, *args, **kwargs): # time unsynced output = self.get(*args, **kwargs) @@ -331,6 +399,7 @@ def stop_and_get_time(self, name=""): self.stop(name=name) return self.get(name=name) + # TODO: double check can delete def store(self, name=""): """instead of immediately syncing the timing, we'll store it for a sync later on @@ -339,6 +408,7 @@ def store(self, name=""): output = self.get(name=name) self.stored_results[name].append(output) + # TODO: double check can delete def sync_and_consume_over_stored_time(self, name=""): """get the timings we stored, sync them and iterates over them this function consumes the timings (i.e remove them after iterating) @@ -352,6 +422,78 @@ def sync_and_consume_over_stored_time(self, name=""): del self.stored_results[name] +class ScopedTimer: + """ + A thin adapter over the NamedTimer class to help time sections of code + using a context manager. + + This class is useful for tracking timings automatically so you don't need + to manually collect them. You only need to pass the timer around and can + collect the durations in one place, instead of returning and mutating + dictionaries throughout your code. + + The ScopedTimer ensures that durations are logged and consumed properly, + preventing accidental overwriting of previous measurements. + + Usage: + timer = ScopedTimer() + + # All durations are logged in the timer + with timer("step_time"): + with timer("fwd"): + model.fwd() + with timer("bwd"): + model.bwd() + + # Consume all durations and reset internal store + durations = timer.consume_durations() + + # Durations that are not consumed will raise a ValueError + with timer("fwd"): + model.fwd() + with timer("fwd"): + model.fwd() # <-- This will raise an error as timer.consume_durations() + # is not called, meaning the previous measurement is + # still stored. + + Methods: + consume_durations() -> dict[str, float]: + Returns a dictionary of all logged durations and resets the internal log. + + __call__(name: str): + Context manager for timing a section of code. Raises a ValueError if + durations are not consumed before starting a new measurement for the + same name. + + Raises: + ValueError: If attempting to start a new timing section for a name that + already has a recorded duration without consuming the previous + measurement using consume_durations(). + """ + + def __init__(self, *args, **kwargs): + self._timer = NamedTimer(*args, **kwargs) + self._duration_log = {} + + def consume_durations(self) -> dict[str, float]: + durations = self._duration_log + self._duration_log = {} + return durations + + @contextmanager + def __call__(self, name: str): + try: + self._timer.start(name=name) + yield + finally: + self._timer.stop(name=name) + if name in self._duration_log: + raise ValueError( + f"Attempted to store new duration for {name=} before consuming last measurement. Call consume_durations() to consume the last set of measurements." + ) + self._duration_log[name] = self._timer.get(name=name) + + @dataclass class Timer: """Timer to tell us when the time limit is reached diff --git a/nemo_aligner/utils/text_generation_utils.py b/nemo_aligner/utils/text_generation_utils.py index 817f5dc45..ab4fdb501 100644 --- a/nemo_aligner/utils/text_generation_utils.py +++ b/nemo_aligner/utils/text_generation_utils.py @@ -109,11 +109,12 @@ def verify_is_valid_and_clamp_range_( if end_strings is None: end_strings = [] + mask = (0 <= response_tokens) & (response_tokens < tokenizer.vocab_size) + response_tokens.clamp_(min=0, max=tokenizer.vocab_size - 1) + prev = response_tokens[torch.arange(response_tokens.size(0)), response_lengths - 1] is_valid = strategy.end_of_generation_condition(response_tokens, prev, tokenizer.eos_id, end_strings) - mask = (0 <= response_tokens) & (response_tokens < tokenizer.vocab_size) is_valid = is_valid & torch.all(mask, dim=-1) - response_tokens.clamp_(0, tokenizer.vocab_size - 1) return is_valid diff --git a/nemo_aligner/utils/trt_llm.py b/nemo_aligner/utils/trt_llm.py index 7487a4e70..1f879064d 100644 --- a/nemo_aligner/utils/trt_llm.py +++ b/nemo_aligner/utils/trt_llm.py @@ -1,20 +1,31 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import secrets - import tensorrt_llm import torch +import torch.distributed from nemo.export.tensorrt_llm import TensorRTLLM from nemo.export.trt_llm import tensorrt_llm_run from nemo.export.trt_llm.nemo_ckpt_loader.nemo_file import build_tokenizer from nemo.utils import logging -from nemo_aligner.utils import parallel_state -from nemo_aligner.utils.distributed import broadcast_2d_tensor_within_mp -from nemo_aligner.utils.utils import log_memory +from nemo_aligner.utils.distributed import broadcast_2d_tensor_within_mp, broadcast_tensor_within_pp +from nemo_aligner.utils.utils import clear_memory, log_memory try: - from tensorrt_llm.bindings import GptSession + import tensorrt_llm - GptSession.refit_engine # check if TRTLLM Cpp runtime was compiled with engine refitting HAVE_TRTLLM = True except (ImportError, ModuleNotFoundError) as e: logging.info(f"got error message {e} when importing trt-llm dependencies, disabling") @@ -48,10 +59,9 @@ def __init__( repetition_penalty=1.0, max_generation_length=1024, max_input_len=1024, - max_input_tokens=4096, generation_batch_size=4, use_greedy=False, - trt_model_type="GPTForCausalLM", + trt_model_type="llama", seed=None, unload_engine_train=False, reshard_model=False, @@ -68,6 +78,11 @@ def __init__( assert ( tokenizer.pad_id != tokenizer.eos_id ), f"We require tokenizers to have a different {tokenizer.pad_id=} than {tokenizer.eos_id=} when using TRT-LLM. This is to make sure all code goes into the same path and include the eos_id when the response lengths are computed" + assert max_input_len > 0 + assert max_generation_length > 0 + assert ( + max_input_len + max_generation_length <= model_cfg.encoder_seq_length + ), f"We require max_input_len ({max_input_len}) + max_generation_length ({max_generation_length}) <= model_cfg.encoder_seq_length ({model_cfg.encoder_seq_length})" if use_greedy and sample_top_k != 1: logging.warning(f"'use_greedy=True' overrides {sample_top_k=} to 1") @@ -77,7 +92,6 @@ def __init__( self.tokenizer = tokenizer self.max_generation_length = max_generation_length self.max_input_len = max_input_len - self.max_input_tokens = max_input_tokens self.generation_batch_size = generation_batch_size self.unload_engine_train = unload_engine_train self.trt_model_type = trt_model_type @@ -115,9 +129,12 @@ def __init__( ids = append_and_repad_list(ids, self.eos_id, pad_id=0) offsets = append_and_repad_list(offsets, max(offsets) + 1, pad_id=-1) - assert max(offsets) == len(ids), "offset and stop token length are mismatched" - stop_list = torch.as_tensor([ids, offsets], dtype=torch.int32, device=torch.cuda.current_device()).repeat( - self.generation_batch_size, 1, 1 + assert max(offsets) == len(ids), f"offset and stop token length are mismatched ({max(offsets)=} {len(ids)=})" + # TRT-LLM expects stop_list to be a numpy array + stop_list = ( + torch.as_tensor([ids, offsets], dtype=torch.int32, device="cpu") + .repeat(self.generation_batch_size, 1, 1) + .numpy() ) self.sampling_config = tensorrt_llm.runtime.SamplingConfig( @@ -139,7 +156,7 @@ def __init__( def refit(self, model): if not self._trtllm_model_compiled: - log_memory("memory before TRT-LLM engine build") + log_memory("Before TRT-LLM engine build") global_devices = [None for _ in range(torch.distributed.get_world_size())] torch.distributed.all_gather_object(global_devices, torch.cuda.current_device()) gpus_per_node = max(global_devices) + 1 @@ -157,13 +174,16 @@ def refit(self, model): reshard_model=self.reshard_model, ) self._trtllm_model_compiled = True - log_memory("memory after TRT-LLM engine build") + log_memory("After TRT-LLM engine build") else: - log_memory("memory before TRT-LLM engine refit") + log_memory("Before TRT-LLM engine refit") self.trt_llm_exporter.refit(model, self.model_cfg) - log_memory("memory after TRT-LLM engine refit") + log_memory("After TRT-LLM engine refit") - def generate(self, inputs): + def _generate(self, inputs: tuple[torch.Tensor, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]: + """ + Internal API to make it easier to validate raw TRT-LLM outputs + """ prompt_tokens, prompt_lengths = inputs batch_input_ids = [] @@ -179,56 +199,21 @@ def generate(self, inputs): batch_input_ids=batch_input_ids, sampling_config=self.sampling_config, streaming=False ) - # remove beam dim from output_ids: [mbs, beam_dim, sequence len] - output_ids = torch.squeeze(output_dict["output_ids"], dim=1).long() - response_lengths = torch.squeeze(output_dict["sequence_lengths"], dim=1).long() - max_length = response_lengths.max().item() - - # TRTLLM with PP erroneously inserts padding: - # As an example when we have the input: - # [[prompt tok, PAD, PAD], [prompt tok, prompt tok, prompt tok]] - # The output when PP is enabled becomes: - # [[prompt tok, PAD, PAD, resp_tok, resp_tok], [prompt tok, prompt tok, prompt tok, resp_tok, resp_tok]] - # Therefore we need this logic to get rid of the padding in the middle of the tensor. - # Furthermore, TRTLLM only produces valid outputs on the source rank, so we can only process it here - # and rely on the aligner broadcast to get it to the other ranks. Miraculously, the length - # is still correct on the non src ranks - if ( - parallel_state.get_pipeline_model_parallel_world_size() > 1 - and parallel_state.get_model_parallel_src_rank() == torch.distributed.get_rank() - ): - valid_tokens = output_ids != self.pad_id - # we can't just naively use the response length here - # because there are cases where the model generates - # stop strings after it has stopped. so we need to - # be slightly inefficient and then remove the excess later on - valid_token_lengths = valid_tokens.sum(-1, keepdims=True) - max_unpadded_length = valid_token_lengths.max() - assert max_length <= max_unpadded_length, ( - "max unpadded length should be more or equal to max length. This assertion is probably happening because TRT-LLM considered a " - "pad tokens in the response length" - ) - - _output_ids = torch.full( - (response_lengths.size(0), max_unpadded_length), - fill_value=self.pad_id, - dtype=output_ids.dtype, - device=output_ids.device, - ) - - # only fill up to the amount of valid tokens - src_index_mask = ( - torch.arange(max_unpadded_length, device=response_lengths.device).view(1, -1) < valid_token_lengths - ) + # TRTLLM returns the output_ids and sequence_lengths only on the first PP rank, and None otherwise, so we need to broadcast + output_ids = broadcast_tensor_within_pp(output_dict["output_ids"] if output_dict else None, from_last=False) + response_lengths = broadcast_tensor_within_pp( + output_dict["sequence_lengths"] if output_dict else None, from_last=False + ) - _output_ids[src_index_mask] = output_ids[valid_tokens] + # remove beam dim from output_ids: [mbs, beam_dim, sequence len] + output_ids = torch.squeeze(output_ids, dim=1).long() + response_lengths = torch.squeeze(response_lengths, dim=1).long() + return output_ids, response_lengths - invalid_response_mask = torch.arange(max_unpadded_length, device=response_lengths.device).view( - 1, -1 - ) >= response_lengths.view(-1, 1) - _output_ids[invalid_response_mask] = self.pad_id + def generate(self, inputs: tuple[torch.Tensor, torch.Tensor]): - output_ids = _output_ids + output_ids, response_lengths = self._generate(inputs) + max_length = response_lengths.max().item() # Map pad_id to eos_id in case tokenizer does not have a pad_id output_ids[output_ids == self.pad_id] = self.eos_id @@ -242,7 +227,10 @@ def generate(self, inputs): return output - def free(self, force_unload=False): - if force_unload or self.unload_engine_train: - tensorrt_llm_run.tensorrt_llm_worker_context.decoder = None - tensorrt_llm_run.tensorrt_llm_worker_context = tensorrt_llm_run.TensorrtLLMWorkerContext() + def free(self): + if not self.unload_engine_train: + return + log_memory("Before TRT-LLM engine unload") + self.trt_llm_exporter.unload_engine() + clear_memory() + log_memory("After TRT-LLM engine unload") diff --git a/nemo_aligner/utils/utils.py b/nemo_aligner/utils/utils.py index 7701b5e78..8eb73d295 100644 --- a/nemo_aligner/utils/utils.py +++ b/nemo_aligner/utils/utils.py @@ -13,6 +13,7 @@ # limitations under the License. """Misc helper functions""" +import functools import gc import itertools import os @@ -521,4 +522,22 @@ def make_sharded_tensors_from_reference(reference_param, model_param, prefix: st def log_memory(prefix): pyt = torch.cuda.memory_allocated() / (1024 ** 3) el = (torch.cuda.mem_get_info()[1] - torch.cuda.mem_get_info()[0]) / (1024 ** 3) - logging.info(f"Mem Usage | {prefix} | pytorch:{pyt} total_occupied:{el} | memory_other_than_pyt:{el-pyt}") + logging.info(f"Mem Usage (GB) | {prefix} | pytorch:{pyt} total_occupied:{el} | memory_other_than_pyt:{el-pyt}") + + +def deprecated_in_version(version: str, message: str | None = None): + def decorator(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + # Construct the deprecation message + func_name = func.__name__ + warn_message = ( + f"The function '{func_name}' is deprecated and will be removed in version {version}. " + f"{message if message else ''}".strip() + ) + warnings.warn(warn_message, DeprecationWarning, stacklevel=2) + return func(*args, **kwargs) + + return wrapper + + return decorator diff --git a/setup/trtllm.patch b/setup/trtllm.patch index 1f9c3d683..27dacae48 100644 --- a/setup/trtllm.patch +++ b/setup/trtllm.patch @@ -1,126 +1,175 @@ -diff --git a/cpp/include/tensorrt_llm/runtime/gptSession.h b/cpp/include/tensorrt_llm/runtime/gptSession.h -index c94eeb2a4..8fefe33af 100644 ---- a/cpp/include/tensorrt_llm/runtime/gptSession.h -+++ b/cpp/include/tensorrt_llm/runtime/gptSession.h -@@ -32,6 +32,7 @@ +diff --git a/tensorrt_llm/builder.py b/tensorrt_llm/builder.py +index 527f8ccfd..222c69fc6 100644 +--- a/tensorrt_llm/builder.py ++++ b/tensorrt_llm/builder.py +@@ -660,10 +660,14 @@ class EngineConfig: + @classmethod + def from_json_file(cls, config_file): + with open(config_file) as f: +- config = json.load(f) +- return cls(PretrainedConfig.from_dict(config['pretrained_config']), +- BuildConfig.from_dict(config['build_config']), +- config['version']) ++ return cls.from_json_str(f.read()) ++ ++ @classmethod ++ def from_json_str(cls, config_str): ++ config = json.loads(config_str) ++ return cls(PretrainedConfig.from_dict(config['pretrained_config']), ++ BuildConfig.from_dict(config['build_config']), ++ config['version']) + + def to_dict(self): + build_config = self.build_config.to_dict() +@@ -770,6 +774,15 @@ class Engine: - #include - #include -+#include - #include - #include - #include -@@ -220,6 +221,8 @@ public: - void generate(GenerationOutput& outputs, GenerationInput const& inputs, SamplingConfig const& samplingConfig, - std::shared_ptr const generationProfiler = nullptr); + return cls(config, engine_buffer, managed_weights) -+ void refitEngine(std::vector> refit_params); ++ @classmethod ++ def from_buffer(cls, ++ engine_buffer: Union[trt.IHostMemory, bytes], ++ json_config_str: str, ++ rank: int = 0): ++ config = EngineConfig.from_json_str(json_config_str) ++ config.pretrained_config.set_rank(rank) ++ return cls(config, engine_buffer) + - private: - [[nodiscard]] bool useCudaGraphs() - { -diff --git a/cpp/tensorrt_llm/pybind/bindings.cpp b/cpp/tensorrt_llm/pybind/bindings.cpp -index 3e6d704af..efa03e83f 100644 ---- a/cpp/tensorrt_llm/pybind/bindings.cpp -+++ b/cpp/tensorrt_llm/pybind/bindings.cpp + + def get_engine_version(engine_dir: str) -> Union[None, str]: + engine_dir = Path(engine_dir) +diff --git a/tensorrt_llm/runtime/generation.py b/tensorrt_llm/runtime/generation.py +index 983d458b8..af8eceb7f 100755 +--- a/tensorrt_llm/runtime/generation.py ++++ b/tensorrt_llm/runtime/generation.py @@ -15,6 +15,7 @@ - * limitations under the License. - */ -+#include - #include - #include - #include -@@ -44,6 +45,9 @@ - #include "tensorrt_llm/runtime/memoryCounters.h" - #include "tensorrt_llm/runtime/samplingConfig.h" + import copy + import math ++import os + import platform + from collections import Counter + from dataclasses import dataclass, field +@@ -47,6 +48,10 @@ from ..quantization import QuantMode + from .kv_cache_manager import GenerationSequence, KVCacheUpdater + from .session import _scoped_stream -+#include -+#include ++# When variable is set, this will disable torch.cuda.set_device(...) calls ++# Useful in situations where device is already assigned by another library, i.e., megatron. ++DISABLE_TORCH_DEVICE_SET = os.environ.get("DISABLE_TORCH_DEVICE_SET", False) + - namespace py = pybind11; - namespace tb = tensorrt_llm::batch_manager; - namespace tbb = tensorrt_llm::batch_manager::batch_scheduler; -@@ -329,7 +333,23 @@ PYBIND11_MODULE(TRTLLM_PYBIND_MODULE, m) - [](tr::GptSession& self, tpr::GenerationOutput& outputs, tpr::GenerationInput const& inputs, - tr::SamplingConfig const& samplingConfig) - { self.generate(*outputs.toTrtLlm(), *inputs.toTrtLlm(), samplingConfig); }, -- py::arg("outputs"), py::arg("inputs"), py::arg("sampling_config")); -+ py::arg("outputs"), py::arg("inputs"), py::arg("sampling_config")) -+ .def( -+ "refit_engine", -+ [](tr::GptSession& self, std::map refit_params, nvinfer1::DataType dtype) -+ { -+ std::vector> param_map; -+ for (auto param : refit_params) -+ { -+ nvinfer1::Weights trt_weight; -+ trt_weight.type = dtype; -+ trt_weight.count = param.second.numel(); -+ trt_weight.values = param.second.data_ptr(); -+ param_map.push_back({param.first, trt_weight}); -+ } -+ self.refitEngine(param_map); -+ }, -+ py::arg("refit_params"), py::arg("type")); - py::enum_(m, "LlmRequestState") - .value("REQUEST_STATE_UNKNOWN", tb::LlmRequestState_t::REQUEST_STATE_UNKNOWN) -diff --git a/cpp/tensorrt_llm/runtime/gptSession.cpp b/cpp/tensorrt_llm/runtime/gptSession.cpp -index 6e232f85d..81a5ef6ab 100644 ---- a/cpp/tensorrt_llm/runtime/gptSession.cpp -+++ b/cpp/tensorrt_llm/runtime/gptSession.cpp -@@ -1184,6 +1184,11 @@ void GptSession::finalize(SizeType microBatchId) - TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); - } + def decode_words_list(word_dict: List[List[str]], + tokenizer=None, +@@ -247,8 +252,11 @@ class _Runtime(object): + def __prepare(self, mapping: Mapping, engine_buffer): + self.runtime_rank = mapping.rank + local_rank = self.runtime_rank % mapping.gpus_per_node +- torch.cuda.set_device(local_rank) +- CUASSERT(cudart.cudaSetDevice(local_rank)) ++ if DISABLE_TORCH_DEVICE_SET: ++ CUASSERT(cudart.cudaSetDevice(torch.cuda.current_device())) ++ else: ++ torch.cuda.set_device(local_rank) ++ CUASSERT(cudart.cudaSetDevice(local_rank)) -+void GptSession::refitEngine(std::vector> refit_params) -+{ -+ mRuntime->refitEngine(*mLogger, refit_params); -+} -+ - void GptSession::CudaGraphExecutor::create(cudaGraph_t const& graph) - { - TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); -diff --git a/cpp/tensorrt_llm/runtime/tllmRuntime.cpp b/cpp/tensorrt_llm/runtime/tllmRuntime.cpp -index 09261697c..87fe0a303 100644 ---- a/cpp/tensorrt_llm/runtime/tllmRuntime.cpp -+++ b/cpp/tensorrt_llm/runtime/tllmRuntime.cpp -@@ -217,6 +217,24 @@ void TllmRuntime::setOutputTensors(SizeType contextIndex, TensorMap& tensorMap) - } - } + self.runtime = trt.Runtime(logger.trt_logger) + self.engine = self.runtime.deserialize_cuda_engine(engine_buffer) +@@ -535,8 +543,10 @@ class SamplingConfig: + sink_token_length: Optional[int] = field(default=None) + output_sequence_lengths: bool = field(default=False) + return_dict: bool = field(default=False) +- stop_words_list: Optional[torch.Tensor] = field(default=None) +- bad_words_list: Optional[torch.Tensor] = field(default=None) ++ stop_words_list: Optional[Union[list, np.ndarray, ++ torch.Tensor]] = field(default=None) ++ bad_words_list: Optional[Union[list, np.ndarray, ++ torch.Tensor]] = field(default=None) -+void TllmRuntime::refitEngine( -+ nvinfer1::ILogger& logger, std::vector> refit_params) -+{ -+ nvinfer1::ICudaEngine& engine = *(mEngine.get()); -+ TLLM_CHECK_WITH_INFO(engine.isRefittable(), "Tried refitting engine without refit enabled"); -+ -+ nvinfer1::IRefitter* refitter = nvinfer1::createInferRefitter(engine, logger); -+ for (auto param : refit_params) -+ { -+ TLLM_CHECK_WITH_INFO( -+ refitter->setNamedWeights(param.first.c_str(), param.second, nvinfer1::TensorLocation::kHOST), -+ "Failed to refit %s", param.first.c_str()); -+ } -+ TLLM_CHECK_WITH_INFO(refitter->refitCudaEngine(), "Refit failed!"); -+ -+ delete refitter; -+} -+ - CudaStream const& TllmRuntime::getStream() const - { - return *mStream; -diff --git a/cpp/tensorrt_llm/runtime/tllmRuntime.h b/cpp/tensorrt_llm/runtime/tllmRuntime.h -index 51428f6f4..b32a754ca 100644 ---- a/cpp/tensorrt_llm/runtime/tllmRuntime.h -+++ b/cpp/tensorrt_llm/runtime/tllmRuntime.h -@@ -70,6 +70,8 @@ public: + temperature: Union[float, torch.Tensor] = field(default=1.0) + top_k: Union[int, torch.Tensor] = field(default=1) +@@ -698,9 +708,12 @@ class GenerationSession(object): + self._model_config = model_config + self.mapping = mapping + self.runtime = _Runtime(engine_buffer, mapping) +- self.device = torch.device( +- f'cuda:{self.runtime.runtime_rank % mapping.gpus_per_node}') +- torch.cuda.set_device(self.device) ++ if DISABLE_TORCH_DEVICE_SET: ++ self.device = torch.device(f'cuda:{torch.cuda.current_device()}') ++ else: ++ self.device = torch.device( ++ f'cuda:{self.runtime.runtime_rank % mapping.gpus_per_node}') ++ torch.cuda.set_device(self.device) + # dynamic_decoder currently use torch's current stream, so must let TRT enqueue use same stream here + self.stream = stream + if self.stream is None: +diff --git a/tensorrt_llm/runtime/model_runner.py b/tensorrt_llm/runtime/model_runner.py +index d2ba7edfa..e02310c3a 100644 +--- a/tensorrt_llm/runtime/model_runner.py ++++ b/tensorrt_llm/runtime/model_runner.py +@@ -31,10 +31,10 @@ from ..builder import Engine, EngineConfig, get_engine_version + from ..logger import logger + from ..mapping import Mapping + from ..quantization import QuantMode +-from .generation import (ChatGLMGenerationSession, GenerationSession, +- LogitsProcessor, LoraManager, ModelConfig, +- QWenForCausalLMGenerationSession, SamplingConfig, +- StoppingCriteria, to_word_list_format) ++from .generation import (DISABLE_TORCH_DEVICE_SET, ChatGLMGenerationSession, ++ GenerationSession, LogitsProcessor, LoraManager, ++ ModelConfig, QWenForCausalLMGenerationSession, ++ SamplingConfig, StoppingCriteria, to_word_list_format) + + + def get_engine_name(model: str, dtype: str, tp_size: int, pp_size: int, +@@ -554,7 +554,8 @@ class ModelRunner(ModelRunnerMixin): + + if MpiComm.size() > runtime_mapping.gpus_per_node: + assert MpiComm.local_size() == runtime_mapping.gpus_per_node +- torch.cuda.set_device(rank % runtime_mapping.gpus_per_node) ++ if not DISABLE_TORCH_DEVICE_SET: ++ torch.cuda.set_device(rank % runtime_mapping.gpus_per_node) + session = session_cls(model_config, + engine_buffer, + runtime_mapping, +@@ -656,7 +657,8 @@ class ModelRunner(ModelRunnerMixin): + assert model_config.max_medusa_tokens > 0, \ + "medusa_choice is specified but model_config.max_medusa_tokens is 0." - bool executeContext(SizeType contextIndex) const; +- torch.cuda.set_device(rank % runtime_mapping.gpus_per_node) ++ if not DISABLE_TORCH_DEVICE_SET: ++ torch.cuda.set_device(rank % runtime_mapping.gpus_per_node) + session = session_cls(model_config, + engine_buffer, + runtime_mapping, +@@ -840,12 +842,24 @@ class ModelRunner(ModelRunnerMixin): + batch_input_ids, input_lengths = self._prepare_inputs( + batch_input_ids, sampling_config.pad_id) -+ void refitEngine(nvinfer1::ILogger& logger, std::vector> refit_params); +- if sampling_config.bad_words_list is not None: +- sampling_config.bad_words_list = to_word_list_format( +- sampling_config.bad_words_list) +- if sampling_config.stop_words_list is not None: +- sampling_config.stop_words_list = to_word_list_format( +- sampling_config.stop_words_list) ++ def maybe_convert_to_words_list_format( ++ words_list: Optional[Union[list, np.ndarray, torch.Tensor]] ++ ) -> Optional[np.ndarray]: ++ if words_list is None or isinstance(words_list, np.ndarray): ++ return words_list ++ elif isinstance(words_list, torch.Tensor): ++ return words_list.numpy() ++ elif isinstance(words_list, list): ++ return to_word_list_format(words_list) ++ else: ++ raise TypeError( ++ f"Unexpected words_list type={type(words_list)}. Only list, np.ndarray, and torch.Tensor are supported." ++ ) + - CudaStream const& getStream() const; ++ sampling_config.bad_words_list = maybe_convert_to_words_list_format( ++ sampling_config.bad_words_list) ++ sampling_config.stop_words_list = maybe_convert_to_words_list_format( ++ sampling_config.stop_words_list) - BufferManager::CudaStreamPtr getStreamPtr() + if not self.kv_cache_type and sampling_config.max_new_tokens > 1: + raise RuntimeError( diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/conftest.py b/tests/conftest.py index a50564640..a9f49b28f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -12,7 +12,25 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os + import pytest +from omegaconf import DictConfig +from pytorch_lightning import Trainer + +from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer +from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel +from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy +from nemo_aligner.models.nlp.gpt.megatron_gpt_ppo_actor import MegatronGPTActorModel +from nemo_aligner.utils.train_script_utils import init_distributed, resolve_and_create_trainer +from tests.test_mcore_utilities import Utils + +dir_path = os.path.dirname(os.path.abspath(__file__)) +# TODO: This file exists because in cases where TRTLLM MPI communicators are involved, +# the cleanup of the communicators can trigger a segfault at the end of a pytest +# run that gives a false negative when all tests passes. Instead we will use a file +# as a marker that the test has succeeded. +SUCCESS_FILE = os.path.join(dir_path, "PYTEST_SUCCESS") def pytest_addoption(parser): @@ -22,6 +40,7 @@ def pytest_addoption(parser): parser.addoption( "--cpu", action="store_true", help="pass that argument to use CPU during testing (DEFAULT: False = GPU)" ) + parser.addoption("--mpi", action="store_true", default=False, help="Run only MPI tests") def pytest_configure(config): @@ -44,3 +63,356 @@ def run_only_on_device_fixture(request, device): if request.node.get_closest_marker("run_only_on"): if request.node.get_closest_marker("run_only_on").args[0] != device: pytest.skip("skipped on this device: {}".format(device)) + + +@pytest.fixture +def init_model_parallel(): + from tests.test_mcore_utilities import Utils + + def initialize(*args, **kwargs): + Utils.initialize_model_parallel(*args, **kwargs) + + # Yield the initialized function, which is available to the test + yield initialize + + # Teardown: Called when the test ends + Utils.destroy_model_parallel() + + +@pytest.fixture +def llama3_tokenizer(): + return AutoTokenizer("meta-llama/Meta-Llama-3-8b") + + +@pytest.fixture +def dummy_gpt_model(init_model_parallel): + init_model_parallel(1, 1) + + model_cfg = { + "precision": 32, + "micro_batch_size": 4, + "global_batch_size": 8, + "tensor_model_parallel_size": 1, + "pipeline_model_parallel_size": 1, + "resume_from_checkpoint": None, + "encoder_seq_length": 512, + "max_position_embeddings": 512, + "num_layers": 1, + "hidden_size": 128, + "ffn_hidden_size": 512, + "num_attention_heads": 2, + "init_method_std": 0.02, + "hidden_dropout": 0.1, + "kv_channels": None, + "apply_query_key_layer_scaling": True, + "layernorm_epsilon": 1e-5, + "make_vocab_size_divisible_by": 128, + "pre_process": True, + "post_process": True, + "persist_layer_norm": True, + "gradient_as_bucket_view": True, + "tokenizer": {"library": "huggingface", "type": "meta-llama/Meta-Llama-3-8B", "use_fast": True,}, + "native_amp_init_scale": 4294967296, + "native_amp_growth_interval": 1000, + "hysteresis": 2, + "fp32_residual_connection": False, + "fp16_lm_cross_entropy": False, + "megatron_amp_O2": False, + "seed": 1234, + "use_cpu_initialization": False, + "onnx_safe": False, + "apex_transformer_log_level": 30, + "activations_checkpoint_method": None, + "activations_checkpoint_num_layers": 1, + "data": { + "data_prefix": "???", + "index_mapping_dir": None, + "data_impl": "mmap", + "splits_string": "900,50,50", + "seq_length": 512, + "skip_warmup": True, + "num_workers": 2, + "dataloader_type": "single", + "reset_position_ids": False, + "reset_attention_mask": False, + "eod_mask_loss": False, + }, + "optim": { + "name": "fused_adam", + "lr": 2e-4, + "weight_decay": 0.01, + "betas": [0.9, 0.98], + "sched": {"name": "CosineAnnealing", "warmup_steps": 500, "constant_steps": 50000, "min_lr": "2e-5"}, + }, + } + + trainer_cfg = { + "devices": 1, + "num_nodes": 1, + "accelerator": "gpu", + "precision": 32, + "logger": False, + "enable_checkpointing": False, + "use_distributed_sampler": False, + "max_epochs": 1000, + "max_steps": 100000, + "log_every_n_steps": 10, + "val_check_interval": 100, + "limit_val_batches": 50, + "limit_test_batches": 500, + "accumulate_grad_batches": 1, + "gradient_clip_val": 1.0, + } + + strategy = NLPDDPStrategy() + trainer = Trainer(strategy=strategy, **trainer_cfg) + cfg = DictConfig(model_cfg) + model = MegatronGPTModel(cfg=cfg, trainer=trainer) + yield model + + +@pytest.fixture +def dummy_actor_gpt_model_with_pp(): + + model_cfg = { + "precision": "bf16-mixed", + "micro_batch_size": 1, + "global_batch_size": 8, + "tensor_model_parallel_size": 1, + "pipeline_model_parallel_size": 2, + "resume_from_checkpoint": None, + "encoder_seq_length": 8192, + "max_position_embeddings": 8192, + "num_layers": 2, + "hidden_size": 128, + "ffn_hidden_size": 448, + "num_attention_heads": 4, + "init_method_std": 0.01, + "hidden_dropout": 0.0, + "kv_channels": None, + "apply_query_key_layer_scaling": True, + "layernorm_epsilon": 1e-5, + "make_vocab_size_divisible_by": 128, + "pre_process": True, + "post_process": True, + "persist_layer_norm": True, + "gradient_as_bucket_view": True, + "tokenizer": {"library": "huggingface", "type": "meta-llama/Meta-Llama-3-8B", "use_fast": True,}, + "native_amp_init_scale": 4294967296, + "native_amp_growth_interval": 1000, + "hysteresis": 2, + "fp32_residual_connection": False, + "fp16_lm_cross_entropy": False, + "megatron_amp_O2": True, + "seed": 1234, + "use_cpu_initialization": False, + "onnx_safe": False, + "apex_transformer_log_level": 30, + "activations_checkpoint_method": None, + "activations_checkpoint_num_layers": None, + "data": { + "data_impl": "mmap", + "splits_string": "99990,8,2", + "seq_length": 8192, + "skip_warmup": True, + "num_workers": 2, + "dataloader_type": "single", + "reset_position_ids": True, + "reset_attention_mask": True, + "eod_mask_loss": False, + "index_mapping_dir": None, + "data_prefix": [0.99, "/train-data",], + }, + "optim": { + "name": "distributed_fused_adam", + "lr": 0.0001, + "weight_decay": 0.1, + "betas": [0.9, 0.95], + "bucket_cap_mb": 125, + "overlap_grad_sync": True, + "overlap_param_sync": True, + "contiguous_grad_buffer": True, + "contiguous_param_buffer": True, + "sched": { + "name": "CosineAnnealing", + "warmup_steps": 500, + "constant_steps": 0, + "min_lr": 1e-5, + "max_steps": 2, + }, + "grad_sync_dtype": "bf16", + }, + "mcore_gpt": True, + "rampup_batch_size": None, + "virtual_pipeline_model_parallel_size": None, + "context_parallel_size": 1, + "num_query_groups": 2, + "use_scaled_init_method": True, + "attention_dropout": 0.0, + "ffn_dropout": 0.0, + "normalization": "rmsnorm", + "do_layer_norm_weight_decay": False, + "bias": False, + "activation": "fast-swiglu", + "headscale": False, + "transformer_block_type": "pre_ln", + "openai_gelu": False, + "normalize_attention_scores": True, + "position_embedding_type": "rope", + "rotary_percentage": 1.0, + "apply_rope_fusion": True, + "cross_entropy_loss_fusion": True, + "attention_type": "multihead", + "share_embeddings_and_output_weights": False, + "grad_allreduce_chunk_size_mb": 125, + "grad_div_ar_fusion": True, + "gradient_accumulation_fusion": True, + "bias_activation_fusion": True, + "bias_dropout_add_fusion": True, + "masked_softmax_fusion": True, + "sync_batch_comm": False, + "num_micro_batches_with_partial_activation_checkpoints": None, + "activations_checkpoint_layers_per_pipeline": None, + "sequence_parallel": False, + "deterministic_mode": False, + "transformer_engine": True, + "fp8": False, + "fp8_e4m3": False, + "fp8_hybrid": False, + "fp8_margin": 0, + "fp8_interval": 1, + "fp8_amax_history_len": 1024, + "fp8_amax_compute_algo": "max", + "ub_tp_comm_overlap": False, + "use_flash_attention": True, + "gc_interval": 2, + "nsys_profile": { + "enabled": False, + "trace": ["nvtx", "cuda"], + "start_step": 10, + "end_step": 10, + "ranks": [0], + "gen_shape": False, + }, + "dist_ckpt_format": "zarr", + "dist_ckpt_load_on_device": True, + "dist_ckpt_parallel_save": False, + "target": "nemo.collections.nlp.models.language_modeling.megatron_gpt_model.MegatronGPTModel", + "nemo_version": "2.0.0rc1", + "ppo": { + "rollout_micro_batch_size": 8, + "num_rollout_samples": 512, + "forward_micro_batch_size": 2, + "val_rollout_micro_batch_size": 2, + "num_val_samples": 1, + "offload_adam_states": True, + "entropy_bonus": 0.0, + "ratio_eps": 0.2, + "sampling_params": { + "use_greedy": False, + "temperature": 1.0, + "top_k": 0, + "top_p": 1.0, + "repetition_penalty": 1.0, + "add_BOS": False, + "all_probs": False, + "compute_logprob": False, + "end_strings": ["<|endoftext|>", ""], + }, + "length_params": {"max_length": 512, "min_length": 1}, + }, + } + + trainer_cfg = { + "devices": 2, + "num_nodes": 1, + "accelerator": "gpu", + "precision": 32, + "logger": False, + "enable_checkpointing": False, + "use_distributed_sampler": False, + "max_epochs": 1000, + "max_steps": 100000, + "log_every_n_steps": 10, + "val_check_interval": 100, + "limit_val_batches": 50, + "limit_test_batches": 500, + "accumulate_grad_batches": 1, + "gradient_clip_val": 1.0, + "ppo": { + "critic_warmup_steps": 0, + "max_epochs": 1, + "max_steps": -1, + "val_check_interval": 10, + "save_interval": 10, + "gradient_clip_val": 1.0, + "initial_policy_kl_penalty": 0.01, + "use_absolute_kl": True, + "discount_factor": 1.0, + "gae_lambda": 0.95, + "normalize_advantages": True, + "rollout_batch_seq_length": None, + "trt_llm": { + "enable": True, + "reshard": False, + "max_input_len": 256, + "seed": 42, + "model_type": "llama", + "unload_engine_train": True, + }, + }, + } + cfg = { + "trainer": trainer_cfg, + "model": model_cfg, + } + cfg = DictConfig(cfg) + + trainer = resolve_and_create_trainer(cfg, "ppo") + model = MegatronGPTActorModel(cfg=cfg.model, trainer=trainer) + init_distributed(trainer, model, cfg.model.get("transformer_engine", False)) + yield model + Utils.destroy_model_parallel() + + +def pytest_configure(config): + config.addinivalue_line( + "markers", "run_only_on(device): runs the test only on a given device [CPU | GPU]", + ) + config.addinivalue_line("markers", "mpi(reason=None): marks tests as requiring MPI") + + +def pytest_collection_modifyitems(config, items): + run_mpi_tests = config.getoption("--mpi") + + # Skip all mpi tests if --mpi is not provided + if not run_mpi_tests: + skip_mpi = pytest.mark.skip(reason="Skipping MPI test: --mpi option not provided") + for item in items: + if "mpi" in item.keywords: + item.add_marker(skip_mpi) + else: + # If --mpi is provided, only run mpi tests, skip all others + skip_non_mpi = pytest.mark.skip(reason="Skipping non-MPI test: --mpi option provided") + for item in items: + if "mpi" not in item.keywords: + item.add_marker(skip_non_mpi) + + +def pytest_sessionstart(session): + # Remove the file at the start of the session, if it exists + if os.path.exists(SUCCESS_FILE): + os.remove(SUCCESS_FILE) + + +def pytest_sessionfinish(session, exitstatus): + """ whole test run finishes. """ + import torch + + # After the test session completes, destroy the NCCL process group. This suppresses a NCCL warning from pytorch>=2.4 + if torch.distributed.is_initialized(): + torch.distributed.destroy_process_group() + + if exitstatus == 0: + with open(SUCCESS_FILE, "w") as f: + ... diff --git a/tests/functional/ppo.sh b/tests/functional/ppo.sh index d1af6bc31..06c111b34 100644 --- a/tests/functional/ppo.sh +++ b/tests/functional/ppo.sh @@ -10,8 +10,10 @@ export NVTE_APPLY_QK_LAYER_SCALING=1 KL=${KL:-0.03} LR=${LR:-9e-7} RUN_ONLY=${RUN_ONLY:-} -GBS=${GBS:-64} -RESHARD=${RESHARD:-False} +GBS=${GBS:-2} +TP_SIZE=${TP_SIZE:-1} +PP_SIZE=${PP_SIZE:-2} +RESHARD=${RESHARD:-True} RM_NEMO_FILE=${RM_NEMO_FILE} ACTOR_NEMO_FILE=${ACTOR_NEMO_FILE} @@ -95,7 +97,6 @@ if [[ -z "${FAST:-}" ]]; then fi ######################################################################################### -# START HETEROGENEUS JOB 3 CONF_DIR="${GPFS}/examples/nlp/gpt/conf/" CONF_NAME="gpt_ppo_actor" @@ -133,14 +134,14 @@ mpirun -np 2 --allow-run-as-root python -u ${GPFS}/examples/nlp/gpt/train_gpt_pp trainer.ppo.val_check_interval=2 \ ++trainer.ppo.save_interval=2 \ ++model.micro_batch_size=1 \ - ++model.global_batch_size=1 \ - ++model.tensor_model_parallel_size=1 \ - ++model.pipeline_model_parallel_size=2 \ + ++model.global_batch_size=${GBS} \ + ++model.tensor_model_parallel_size=${TP_SIZE} \ + ++model.pipeline_model_parallel_size=${PP_SIZE} \ ++model.ppo.entropy_bonus=0.0 \ ++model.ppo.ratio_eps=0.2 \ ++model.encoder_seq_length=64 \ ++exp_manager.checkpoint_callback_params.save_top_k=10 \ - ++model.ppo.num_rollout_samples=1 \ + ++model.ppo.num_rollout_samples=${GBS} \ ++model.ppo.rollout_micro_batch_size=1 \ ++model.ppo.length_params.max_length=32 \ ++model.ppo.forward_micro_batch_size=1 \ @@ -167,7 +168,7 @@ mpirun -np 2 --allow-run-as-root python -u ${GPFS}/examples/nlp/gpt/train_gpt_pp trainer.ppo.max_steps=3 \ trainer.ppo.trt_llm.model_type=llama \ ++exp_manager=null \ - remote_critic_rm.pad_to_length=$((512+256)) $@ # (match critic) generation + prompt = model.ppo.length_params.max_length + model.ppo.trt_llm.max_input_len ( 512) = self.trtllm_generate.max_generation_length + self.trtllm_generate.max_input_len + remote_critic_rm.pad_to_length=$((512+256)) $@ # (match critic) generation + prompt = model.ppo.length_params.max_length + model.ppo.trt_llm.max_input_len (512) = self.trtllm_generate.max_generation_length + self.trtllm_generate.max_input_len } actor_log_file=$(mktemp /tmp/actor-ppo-log-XXXXXX) diff --git a/tests/functional/test_cases/dpo-llama3 b/tests/functional/test_cases/dpo-llama3 new file mode 100755 index 000000000..8e40e94c8 --- /dev/null +++ b/tests/functional/test_cases/dpo-llama3 @@ -0,0 +1,22 @@ +#!/bin/bash +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +cd $SCRIPT_DIR + +set -eoux pipefail + +PRETRAINED_CHECKPOINT_NEMO_FILE=${ALIGNER_CI_DIR}/checkpoints/tiny-llama3-results-nlayers2-hidden128-ffn448-nhead4-qgroup2-megatron_gpt.nemo \ +bash ../dpo.sh diff --git a/tests/functional/test_cases/dpo-llama3.sh b/tests/functional/test_cases/dpo-llama3.sh deleted file mode 100644 index 2210a8eec..000000000 --- a/tests/functional/test_cases/dpo-llama3.sh +++ /dev/null @@ -1,8 +0,0 @@ -#!/bin/bash -SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) -cd $SCRIPT_DIR - -set -eoux pipefail - -PRETRAINED_CHECKPOINT_NEMO_FILE=${ALIGNER_CI_DIR}/checkpoints/tiny-llama3-results-nlayers2-hidden128-ffn448-nhead4-qgroup2-megatron_gpt.nemo \ -bash ../dpo.sh \ No newline at end of file diff --git a/tests/functional/test_cases/ppo-llama3-pp2-reshard b/tests/functional/test_cases/ppo-llama3-pp2-reshard new file mode 100755 index 000000000..9169b10da --- /dev/null +++ b/tests/functional/test_cases/ppo-llama3-pp2-reshard @@ -0,0 +1,28 @@ +#!/bin/bash + +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +cd $SCRIPT_DIR + +set -eoux pipefail + +GBS=2 \ + TP_SIZE=1 \ + PP_SIZE=2 \ + RESHARD=True \ + RM_NEMO_FILE=${ALIGNER_CI_DIR}/checkpoints/llama3--nlayers4-hidden64-ffn224-dummy_rm-megatron_gpt.nemo \ + ACTOR_NEMO_FILE=${ALIGNER_CI_DIR}/checkpoints/tiny-llama3-results-nlayers2-hidden128-ffn448-nhead4-qgroup2-megatron_gpt.nemo \ + bash ../ppo.sh diff --git a/tests/functional/test_cases/ppo-pp-llama3.sh b/tests/functional/test_cases/ppo-pp-llama3.sh deleted file mode 100644 index 5e071e77d..000000000 --- a/tests/functional/test_cases/ppo-pp-llama3.sh +++ /dev/null @@ -1,9 +0,0 @@ -#!/bin/bash -SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) -cd $SCRIPT_DIR - -set -eoux pipefail - -RM_NEMO_FILE=${ALIGNER_CI_DIR}/checkpoints/llama3--nlayers4-hidden64-ffn224-dummy_rm-megatron_gpt.nemo \ -ACTOR_NEMO_FILE=${ALIGNER_CI_DIR}/checkpoints/tiny-llama3-results-nlayers2-hidden128-ffn448-nhead4-qgroup2-megatron_gpt.nemo \ -bash ../ppo.sh \ No newline at end of file diff --git a/tests/run_mpi_unit.sh b/tests/run_mpi_unit.sh new file mode 100755 index 000000000..e11e5cf10 --- /dev/null +++ b/tests/run_mpi_unit.sh @@ -0,0 +1,34 @@ +#!/bin/bash +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +cd $SCRIPT_DIR + +NUM_GPUS_AVAILABLE=$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l) + +if [[ $NUM_GPUS_AVAILABLE -lt 2 ]]; then + echo "[ERROR]: Unit tests require at least 2 gpus" + exit 1 +fi + +export PYTHONPATH=$(realpath ..):${PYTHONPATH:-} +CUDA_VISIBLE_DEVICES=0,1 mpirun -np 2 --allow-run-as-root pytest .. -rA -s -x -vv --mpi $@ || true + +if [[ -f PYTEST_SUCCESS ]]; then + echo SUCCESS +else + echo FAILURE + exit 1 +fi diff --git a/tests/run_unit.sh b/tests/run_unit.sh index ff91b6926..41216da52 100755 --- a/tests/run_unit.sh +++ b/tests/run_unit.sh @@ -1,4 +1,17 @@ #!/bin/bash +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) cd $SCRIPT_DIR @@ -11,4 +24,11 @@ if [[ $NUM_GPUS_AVAILABLE -lt 2 ]]; then fi export PYTHONPATH=$(realpath ..):${PYTHONPATH:-} -CUDA_VISIBLE_DEVICES=0,1 torchrun --nproc_per_node 2 -m pytest .. -rA -s -x $@ +CUDA_VISIBLE_DEVICES=0,1 torchrun --nproc_per_node 2 -m pytest .. -rA -s -x -vv $@ || true + +if [[ -f PYTEST_SUCCESS ]]; then + echo SUCCESS +else + echo FAILURE + exit 1 +fi diff --git a/tests/test_distributed.py b/tests/test_distributed.py index ca56c27ae..56904b6ac 100644 --- a/tests/test_distributed.py +++ b/tests/test_distributed.py @@ -12,12 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. +import megatron.core.parallel_state as mcore_parallel_state +import numpy as np import pytest import torch +import torch.distributed from megatron.core import tensor_parallel from nemo_aligner.utils import parallel_state from nemo_aligner.utils.distributed import ( + broadcast_tensor_within_pp, calculate_distributed_entropy, from_parallel_logits_to_logprobs, masked_global_mean_var, @@ -228,3 +232,43 @@ def test_distributed_log_probs(self, batch_size, seed, dtype, atol, rtol, higher @pytest.mark.parametrize("batch_size,seed", [(1, 5555), (4, 6666)]) def test_distributed_entropy(self, batch_size, seed): self._run_test(self._test_distributed_entropy, batch_size, seed) + + +@pytest.mark.run_only_on("GPU") +@pytest.mark.parametrize( + "tp_size, pp_size, from_last, shape, dtype, override_dtype", + [ + (1, 2, True, (2,), torch.float32, False), + (1, 2, False, (2, 3), torch.bfloat16, True), + (1, 2, True, (2, 3, 4), torch.float16, False), + (1, 2, False, (2,), torch.int32, True), + ], +) +def test_broadcast_within_pp(init_model_parallel, tp_size, pp_size, from_last, shape, dtype, override_dtype): + init_model_parallel(tensor_model_parallel_size=tp_size, pipeline_model_parallel_size=pp_size) + num_el = np.product(shape) + expected = torch.arange(num_el, dtype=dtype).reshape(shape) + + tensor = None + if ( + from_last + and mcore_parallel_state.get_pipeline_model_parallel_rank() + == mcore_parallel_state.get_pipeline_model_parallel_last_rank() + ): + tensor = expected + elif ( + not from_last + and mcore_parallel_state.get_pipeline_model_parallel_rank() + == mcore_parallel_state.get_pipeline_model_parallel_first_rank() + ): + tensor = expected + + if override_dtype: + # For posterity, this was default behavior when type wasn't set + out_tensor = broadcast_tensor_within_pp(tensor, from_last=from_last, dtype=torch.float32) + assert out_tensor.dtype == torch.float32 + torch.testing.assert_close(out_tensor.to("cpu"), expected.to("cpu").type(torch.float32)) + else: + out_tensor = broadcast_tensor_within_pp(tensor, from_last=from_last) + assert out_tensor.dtype == dtype + torch.testing.assert_close(out_tensor.to("cpu"), expected.to("cpu")) diff --git a/tests/test_mcore_utilities.py b/tests/test_mcore_utilities.py new file mode 100644 index 000000000..96dd3da30 --- /dev/null +++ b/tests/test_mcore_utilities.py @@ -0,0 +1,109 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# NOTE: Copied from https://github.com/NVIDIA/Megatron-LM/blob/main/tests/unit_tests/test_utilities.py +# to avoid having to rely on megatron repo for tests when all we need is this file + +import os +from datetime import timedelta + +import megatron.core.parallel_state as ps +import torch +from torch._C._distributed_c10d import PrefixStore +from torch.distributed import rendezvous +from torch.distributed.distributed_c10d import _store_based_barrier + + +class TestModel(torch.nn.Module): + def __init__(self, input_dim: int, output_dim: int, num_layers: int, bias: bool): + super().__init__() + self.layers = torch.nn.ModuleList([torch.nn.Linear(input_dim, output_dim, bias) for _ in range(num_layers)]) + + +class Utils: + + inited = False + store = None + + @classmethod + @property + def world_size(cls): + """Lazily grab device count""" + return torch.cuda.device_count() + + @classmethod + @property + def rank(cls): + """Lazily grab rank""" + return int(os.environ["LOCAL_RANK"]) + + @staticmethod + def initialize_distributed(): + if not torch.distributed.is_initialized() and Utils.rank >= 0: + print(f"Initializing torch.distributed with rank: {Utils.rank}, world_size: {Utils.world_size}") + torch.cuda.set_device(Utils.rank % torch.cuda.device_count()) + init_method = "tcp://" + master_ip = os.getenv("MASTER_ADDR", "localhost") + master_port = os.getenv("MASTER_PORT", "6000") + init_method += master_ip + ":" + master_port + rendezvous_iterator = rendezvous(init_method, Utils.rank, Utils.world_size, timeout=timedelta(minutes=1)) + store, _, _ = next(rendezvous_iterator) + store.set_timeout(timedelta(minutes=1)) + + # Use a PrefixStore to avoid accidental overrides of keys used by + # different systems (e.g. RPC) in case the store is multi-tenant. + store = PrefixStore("default_pg", store) + Utils.store = store + + torch.distributed.init_process_group( + backend="nccl", world_size=Utils.world_size, rank=Utils.rank, store=store + ) + + torch.distributed.barrier() + Utils.inited = True + + @staticmethod + def set_world_size(world_size=None, rank=None): + Utils.world_size = torch.cuda.device_count() if world_size is None else world_size + if torch.distributed.is_initialized() and Utils.world_size != torch.distributed.get_world_size(): + torch.distributed.destroy_process_group() + + if rank is None: + Utils.rank = int(os.environ["LOCAL_RANK"]) + if Utils.rank >= Utils.world_size: + Utils.rank = -1 + else: + Utils.rank = rank + + @staticmethod + def destroy_model_parallel(): + if not Utils.inited: + return + torch.distributed.barrier() + ps.destroy_model_parallel() + Utils.inited = False + + @staticmethod + def initialize_model_parallel( + tensor_model_parallel_size=1, + pipeline_model_parallel_size=1, + virtual_pipeline_model_parallel_size=None, + **kwargs, + ): + ps.destroy_model_parallel() + Utils.initialize_distributed() + ps.initialize_model_parallel( + tensor_model_parallel_size, pipeline_model_parallel_size, virtual_pipeline_model_parallel_size, **kwargs, + ) + Utils.inited = True diff --git a/tests/test_text_generation_utils.py b/tests/test_text_generation_utils.py index 26b90a53b..0eef72147 100644 --- a/tests/test_text_generation_utils.py +++ b/tests/test_text_generation_utils.py @@ -12,7 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -from nemo_aligner.utils.text_generation_utils import tokenize_batch +import torch + +from nemo_aligner.utils.text_generation_utils import ( + TrackLengthGPTModelTextGenerationStrategy, + tokenize_batch, + verify_is_valid_and_clamp_range_, +) class MockTokenizer: @@ -67,3 +73,42 @@ def test_tokenize_batch_with_sentence_longer_than_max_len(): 10, 10, ], f"expected context_length_tensor to be [10, 10] but got {context_length_tensor.tolist()}" + + +def test_verify_is_valid_and_clamp_range(dummy_gpt_model): + max_gen_length = 8 + + random_gen = [9, 8] # chosen arbitrarily + extra_id_1_ids = dummy_gpt_model.tokenizer.text_to_ids("") + extra_id_2_ids = dummy_gpt_model.tokenizer.text_to_ids("") + eos_id = dummy_gpt_model.tokenizer.eos_id + + # response contains prompt + generation + response_tokens = [ + [1] + random_gen, # doesn't end with an eos + [1, 1] + random_gen + [eos_id], + [1] + random_gen + extra_id_1_ids, + [1, 1] + random_gen + extra_id_1_ids, + [1] + random_gen + extra_id_2_ids, + ] + + # The padding has to be eos_id + response_tokens = torch.nn.utils.rnn.pad_sequence( + [torch.tensor(x) for x in response_tokens], batch_first=True, padding_value=eos_id + ) + + context_lengths = torch.tensor([1, 2, 1, 2, 1]) + generation_lengths = torch.tensor([0, 1, len(extra_id_1_ids), len(extra_id_2_ids), len(extra_id_2_ids)]) + len( + random_gen + ) + response_lengths = context_lengths + generation_lengths + + strategy = TrackLengthGPTModelTextGenerationStrategy(dummy_gpt_model, context_lengths, max_gen_length) + is_end = verify_is_valid_and_clamp_range_( + response_tokens=response_tokens, + response_lengths=response_lengths, + strategy=strategy, + tokenizer=dummy_gpt_model.tokenizer, + end_strings=[""], + ) + assert is_end.tolist() == [False, True, True, True, False] diff --git a/tests/test_trt_llm.py b/tests/test_trt_llm.py new file mode 100644 index 000000000..e008f482c --- /dev/null +++ b/tests/test_trt_llm.py @@ -0,0 +1,84 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest +import torch + +from nemo_aligner.utils import parallel_state +from nemo_aligner.utils.trt_llm import GPTGenerateTRTLLM + + +@pytest.mark.mpi +def test_trtllm_does_not_insert_padding(dummy_actor_gpt_model_with_pp): + trtllm_generate = GPTGenerateTRTLLM( + model_cfg=dummy_actor_gpt_model_with_pp.cfg, + end_strings=[""], + tokenizer=dummy_actor_gpt_model_with_pp.tokenizer, + max_input_len=dummy_actor_gpt_model_with_pp.cfg.encoder_seq_length // 2, + max_generation_length=dummy_actor_gpt_model_with_pp.cfg.encoder_seq_length // 2, + ) + trtllm_generate.refit(dummy_actor_gpt_model_with_pp) + + batch_size = 4 + max_seq_len = dummy_actor_gpt_model_with_pp.cfg.encoder_seq_length + prompt_tokens = torch.ones((batch_size, max_seq_len), dtype=torch.int32) + prompt_lengths = torch.tensor([10, 20, 30, 40]) + + output_ids, response_lengths = trtllm_generate._generate([prompt_tokens, prompt_lengths]) + max_length = response_lengths.max().item() + + # TRTLLM with PP has sometimes erroneously inserts padding: + # As an example when we have the input: + # [[prompt tok, PAD, PAD], [prompt tok, prompt tok, prompt tok]] + # The output when PP is enabled becomes: + # [[prompt tok, PAD, PAD, resp_tok, resp_tok], [prompt tok, prompt tok, prompt tok, resp_tok, resp_tok]] + # Therefore we need this logic to get rid of the padding in the middle of the tensor. + # Furthermore, TRTLLM only produces valid outputs on the source rank, so we can only process it here + # and rely on the aligner broadcast to get it to the other ranks. Miraculously, the length + # is still correct on the non src ranks + if ( + parallel_state.get_pipeline_model_parallel_world_size() > 1 + and parallel_state.get_model_parallel_src_rank() == torch.distributed.get_rank() + ): + valid_tokens = output_ids != trtllm_generate.pad_id + # we can't just naively use the response length here + # because there are cases where the model generates + # stop strings after it has stopped. so we need to + # be slightly inefficient and then remove the excess later on + valid_token_lengths = valid_tokens.sum(-1, keepdims=True) + max_unpadded_length = valid_token_lengths.max() + assert max_length <= max_unpadded_length, ( + "max unpadded length should be more or equal to max length. This assertion is probably happening because TRT-LLM considered a " + "pad tokens in the response length" + ) + + _output_ids = torch.full( + (response_lengths.size(0), max_unpadded_length), + fill_value=trtllm_generate.pad_id, + dtype=output_ids.dtype, + device=output_ids.device, + ) + + # only fill up to the amount of valid tokens + src_index_mask = ( + torch.arange(max_unpadded_length, device=response_lengths.device).view(1, -1) < valid_token_lengths + ) + + _output_ids[src_index_mask] = output_ids[valid_tokens] + + invalid_response_mask = torch.arange(max_unpadded_length, device=response_lengths.device).view( + 1, -1 + ) >= response_lengths.view(-1, 1) + _output_ids[invalid_response_mask] = trtllm_generate.pad_id + + torch.testing.assert_close(output_ids, _output_ids)