Skip to content

Commit

Permalink
feat[dpo]: Allows DPO Dataset to pad to a multiple
Browse files Browse the repository at this point in the history
- Needed for sequence parallel where the sequence length needs to be
  divisible

Signed-off-by: Terry Kong <[email protected]>

missing plumbing

Signed-off-by: Terry Kong <[email protected]>

fix: update other APIs required for MCore dist opt

Signed-off-by: Terry Kong <[email protected]>

feat: more informative error for DPO dataset tokenization failure

Signed-off-by: Terry Kong <[email protected]>
  • Loading branch information
terrykong committed Nov 4, 2024
1 parent 8c4e61a commit 446f7f0
Show file tree
Hide file tree
Showing 20 changed files with 499 additions and 79 deletions.
4 changes: 4 additions & 0 deletions examples/nlp/gpt/conf/gpt_dpo.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ trainer:
devices: 8
accelerator: gpu
precision: bf16
gradient_clip_val: 0.0 # No need to change. Megatron Core optimizer uses this value

# dpo specific args
dpo:
Expand All @@ -17,6 +18,7 @@ trainer:

# how many GBS we loop over
limit_val_batches: 1.0
# TODO: delete once Megatron Core optimizer becomes default
gradient_clip_val: 1.0

# do not change these
Expand Down Expand Up @@ -57,6 +59,7 @@ model:
micro_batch_size: 1
global_batch_size: 64
megatron_amp_O2: True
variable_seq_lengths: ${not:model.data.pad_length_to_multiple_of}

dpo:
# This default value ensures there are no numeric differences beween trained and reference policies when computing log probs.
Expand Down Expand Up @@ -114,6 +117,7 @@ model:
data_impl: jsonl
splits_string: null
seq_length: ${model.encoder_seq_length}
pad_length_to_multiple_of: null # Use if sequence_parallel is enabled to ensure seq_length is divisible by the ...
skip_warmup: True
num_workers: 0
reset_position_ids: False # Reset position ids after end-of-document token
Expand Down
2 changes: 2 additions & 0 deletions examples/nlp/gpt/conf/gpt_kto.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ trainer:
devices: 8
accelerator: gpu
precision: bf16
gradient_clip_val: 0.0 # No need to change. Megatron Core optimizer uses this value

# kto specific args
kto:
Expand All @@ -17,6 +18,7 @@ trainer:

# how many GBS we loop over
limit_val_batches: 1.0
# TODO: delete once Megatron Core optimizer becomes default
gradient_clip_val: 1.0

# do not change these
Expand Down
2 changes: 2 additions & 0 deletions examples/nlp/gpt/conf/gpt_ppo_actor.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ trainer:
devices: 8
accelerator: gpu
precision: bf16
gradient_clip_val: 0.0 # No need to change. Megatron Core optimizer uses this value

ppo:
# How many steps we train warmup the critic for (without training the policy)
Expand All @@ -21,6 +22,7 @@ trainer:
max_steps: -1 # max PPO steps (-1 to go through the whole train set)
val_check_interval: 10
save_interval: ${.val_check_interval}
# TODO: delete once Megatron Core optimizer becomes default
gradient_clip_val: 1.0

# PPO args to generate the data for training
Expand Down
2 changes: 2 additions & 0 deletions examples/nlp/gpt/conf/gpt_ppo_critic.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ trainer:
devices: 8
accelerator: gpu
precision: bf16
gradient_clip_val: 0.0 # No need to change. Megatron Core optimizer uses this value

ppo:
port: 5556
Expand All @@ -15,6 +16,7 @@ trainer:

# used to set the learning rate scheduler
max_steps: 10000
# TODO: delete once Megatron Core optimizer becomes default
gradient_clip_val: 1.0

# a PyTriton parameter to specify
Expand Down
4 changes: 3 additions & 1 deletion examples/nlp/gpt/conf/gpt_rs_actor.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,14 @@ trainer:
devices: 8
accelerator: gpu
precision: bf16
gradient_clip_val: 0.0 # No need to change. Megatron Core optimizer uses this value

rs:
max_epochs: 1
max_steps: -1 # max rs steps (-1 to go through the whole train set)
val_check_interval: 10
save_interval: ${.val_check_interval}
# TODO: delete once Megatron Core optimizer becomes default
gradient_clip_val: 1.0

# pick up from the model
Expand Down Expand Up @@ -177,4 +179,4 @@ model:
# define fields from the base model's config that should be ignored when merging with this config.
overwrite_base_config:
data:
data_prefix: True
data_prefix: True
2 changes: 2 additions & 0 deletions examples/nlp/gpt/conf/gpt_sft.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ trainer:
devices: 1
accelerator: gpu
precision: bf16
gradient_clip_val: 0.0 # No need to change. Megatron Core optimizer uses this value

sft:
max_epochs: 1
Expand All @@ -15,6 +16,7 @@ trainer:
limit_train_batches: 1.0

limit_val_batches: 1.0
# TODO: delete once Megatron Core optimizer becomes default
gradient_clip_val: 1.0

# can be used to register any custom metrics that require token-by-token generation
Expand Down
2 changes: 2 additions & 0 deletions examples/nlp/gpt/conf/gpt_spin.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ trainer:
devices: 8
accelerator: gpu
precision: bf16-mixed
gradient_clip_val: 0.0 # No need to change. Megatron Core optimizer uses this value

# spin specific args
spin:
Expand All @@ -18,6 +19,7 @@ trainer:

# how many GBS we loop over
limit_val_batches: 1.0
# TODO: delete once Megatron Core optimizer becomes default
gradient_clip_val: 1.0

# do not change these
Expand Down
2 changes: 2 additions & 0 deletions examples/nlp/gpt/conf/training_rm.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ trainer:
devices: 8
accelerator: gpu
precision: bf16
gradient_clip_val: 0.0 # No need to change. Megatron Core optimizer uses this value

# rm specific args
rm:
Expand All @@ -20,6 +21,7 @@ trainer:
# set to float for a percentage
# of the validation dataset
limit_val_batches: 1.0
# TODO: delete once Megatron Core optimizer becomes default
gradient_clip_val: 1.0

# do not change these
Expand Down
3 changes: 3 additions & 0 deletions examples/nlp/gpt/train_gpt_dpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@

OmegaConf.register_new_resolver("multiply", lambda x, y: x * y, replace=True)
OmegaConf.register_new_resolver("int_div", lambda x, y: x // y, replace=True)
OmegaConf.register_new_resolver("not", lambda x: not x, replace=True)

mp.set_start_method("spawn", force=True)

Expand Down Expand Up @@ -110,6 +111,7 @@ def main(cfg) -> None:
reset_position_ids=cfg.model.data.get("reset_position_ids", False),
reset_attention_mask=cfg.model.data.get("reset_attention_mask", False),
eod_mask_loss=cfg.model.data.get("eod_mask_loss", False),
pad_length_to_multiple_of=cfg.model.data.get("pad_length_to_multiple_of", None),
),
)

Expand All @@ -127,6 +129,7 @@ def main(cfg) -> None:
reset_position_ids=cfg.model.data.get("reset_position_ids", False),
reset_attention_mask=cfg.model.data.get("reset_attention_mask", False),
eod_mask_loss=cfg.model.data.get("eod_mask_loss", False),
pad_length_to_multiple_of=cfg.model.data.get("pad_length_to_multiple_of", None),
),
use_random_sampler=False,
)
Expand Down
2 changes: 1 addition & 1 deletion nemo_aligner/algorithms/critic_server_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,7 @@ def run_training(self, tokens=None, returns=None, prev_values=None, mask=None):
grad_norm = grad_norm.item() if torch.is_tensor(grad_norm) else grad_norm
lr = self.optimizer.param_groups[0]["lr"]

self.optimizer.step()
self.optimizer.step(closure=None)
self.scheduler.step()

if grad_norm is not None:
Expand Down
45 changes: 39 additions & 6 deletions nemo_aligner/algorithms/dpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import torch
from omegaconf.dictconfig import DictConfig
from tqdm import tqdm
import math

from nemo.collections.nlp.data.language_modeling.megatron.megatron_batch_samplers import (
MegatronPretrainingRandomBatchSampler,
Expand All @@ -30,7 +31,14 @@
from nemo_aligner.utils.utils import clear_memory


def dpo_custom_collate(batch, eos_id, reset_position_ids=False, reset_attention_mask=False, eod_mask_loss=False):
def dpo_custom_collate(
batch: list,
eos_id: int,
reset_position_ids: bool = False,
reset_attention_mask: bool = False,
eod_mask_loss: bool = False,
pad_length_to_multiple_of: int | None = None,
):
chosen_tokens = [item["chosen"] for item in batch]
rejected_tokens = [item["rejected"] for item in batch]
chosen_lengths = torch.LongTensor([item["chosen_length"] for item in batch])
Expand All @@ -45,8 +53,32 @@ def dpo_custom_collate(batch, eos_id, reset_position_ids=False, reset_attention_
chosen_labels = torch.nn.utils.rnn.pad_sequence(chosen_labels, batch_first=True, padding_value=-100)
rejected_labels = torch.nn.utils.rnn.pad_sequence(rejected_labels, batch_first=True, padding_value=-100)

if pad_length_to_multiple_of:
chosen_padded_length = (
math.ceil(chosen_tokens.shape[1] / pad_length_to_multiple_of) * pad_length_to_multiple_of
)
rejected_padded_length = (
math.ceil(rejected_tokens.shape[1] / pad_length_to_multiple_of) * pad_length_to_multiple_of
)
chosen_tokens = torch.nn.functional.pad(
chosen_tokens, (0, chosen_padded_length - chosen_tokens.shape[1]), mode="constant", value=eos_id
)
rejected_tokens = torch.nn.functional.pad(
rejected_tokens, (0, rejected_padded_length - rejected_tokens.shape[1]), mode="constant", value=eos_id
)
chosen_labels = torch.nn.functional.pad(
chosen_labels, (0, chosen_padded_length - chosen_labels.shape[1]), mode="constant", value=eos_id
)
rejected_labels = torch.nn.functional.pad(
rejected_labels, (0, rejected_padded_length - rejected_labels.shape[1]), mode="constant", value=eos_id
)

attention_mask, _, position_ids = get_ltor_masks_and_position_ids(
chosen_tokens, eos_id, reset_position_ids, reset_attention_mask, eod_mask_loss,
chosen_tokens,
eos_id,
reset_position_ids,
reset_attention_mask,
eod_mask_loss,
)
assert attention_mask.ndim == 4, "attention_mask is incorrect shape for dpo_custom_collate"
if attention_mask.shape[0] == 1:
Expand All @@ -70,8 +102,7 @@ def dpo_custom_collate(batch, eos_id, reset_position_ids=False, reset_attention_


class DPOTrainer:
"""Trainer to coordinate DPO training
"""
"""Trainer to coordinate DPO training"""

def __init__(
self,
Expand Down Expand Up @@ -172,7 +203,7 @@ def train_single_step(self, global_batch):
grad_norm = grad_norm.item() if torch.is_tensor(grad_norm) else grad_norm
lr = self.optimizer.param_groups[0]["lr"]

self.optimizer.step()
self.optimizer.step(closure=None)
self.scheduler.step()

trainer_metrics = {}
Expand Down Expand Up @@ -233,7 +264,9 @@ def fit(self):
metrics["step_time"] = train_step_time
metrics["epoch"] = self.epoch + 1
self.logger.log_metrics(
metrics, step=self.step, prefix="train/",
metrics,
step=self.step,
prefix="train/",
)
metrics = {f"train_{k}": v for k, v in metrics.items()}

Expand Down
2 changes: 1 addition & 1 deletion nemo_aligner/algorithms/ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,7 +440,7 @@ def run_training(self, dataloader_iter):
grad_norm = grad_norm.item() if torch.is_tensor(grad_norm) else grad_norm
lr = self.optimizer.param_groups[0]["lr"]

self.optimizer.step()
self.optimizer.step(closure=None)
self.scheduler.step()

if grad_norm is not None:
Expand Down
2 changes: 1 addition & 1 deletion nemo_aligner/algorithms/rs.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,7 @@ def run_training(self, dataloader_iter):
grad_norm = grad_norm.item() if torch.is_tensor(grad_norm) else grad_norm
lr = self.optimizer.param_groups[0]["lr"]

self.optimizer.step()
self.optimizer.step(closure=None)
self.scheduler.step()

if grad_norm is not None:
Expand Down
2 changes: 1 addition & 1 deletion nemo_aligner/algorithms/spin.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ def train_single_step(self, global_batch):
grad_norm = grad_norm.item() if torch.is_tensor(grad_norm) else grad_norm
lr = self.optimizer.param_groups[0]["lr"]

self.optimizer.step()
self.optimizer.step(closure=None)
self.scheduler.step()

trainer_metrics = {}
Expand Down
2 changes: 1 addition & 1 deletion nemo_aligner/algorithms/supervised.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def train_single_step(self, batch):
grad_norm = grad_norm.item() if torch.is_tensor(grad_norm) else grad_norm
lr = self.optimizer.param_groups[0]["lr"]

self.optimizer.step()
self.optimizer.step(closure=None)
self.scheduler.step()

trainer_metrics = {}
Expand Down
28 changes: 22 additions & 6 deletions nemo_aligner/data/nlp/builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

import numpy as np
import torch
import torch.utils.data
from omegaconf.dictconfig import DictConfig

from nemo.collections.nlp.data.language_modeling.megatron.base_dataset_utils import (
Expand Down Expand Up @@ -97,7 +98,15 @@ def _build_dataset(current_data_prefix, current_num_samples):


def build_train_valid_test_datasets(
cls, cfg, data_prefix, data_impl, splits_string, train_valid_test_num_samples, seq_length, seed, tokenizer,
cls,
cfg,
data_prefix,
data_impl,
splits_string,
train_valid_test_num_samples,
seq_length,
seed,
tokenizer,
):
if isinstance(data_prefix, DictConfig):
assert (
Expand All @@ -123,7 +132,7 @@ def build_train_valid_test_datasets(
cfg=cfg,
data_prefix=data_prefix["validation"],
data_impl=data_impl,
num_samples=int(train_valid_test_num_samples[0]),
num_samples=int(train_valid_test_num_samples[1]),
seq_length=seq_length,
seed=seed,
tokenizer=tokenizer,
Expand All @@ -134,7 +143,7 @@ def build_train_valid_test_datasets(
cfg=cfg,
data_prefix=data_prefix["test"],
data_impl=data_impl,
num_samples=int(train_valid_test_num_samples[0]),
num_samples=int(train_valid_test_num_samples[2]),
seq_length=seq_length,
seed=seed,
tokenizer=tokenizer,
Expand Down Expand Up @@ -202,7 +211,15 @@ def build_train_valid_test_datasets(


def _build_train_valid_test_datasets(
cls, cfg, data_prefix, data_impl, splits_string, train_valid_test_num_samples, seq_length, seed, tokenizer,
cls,
cfg,
data_prefix,
data_impl,
splits_string,
train_valid_test_num_samples,
seq_length,
seed,
tokenizer,
):
"""Build train, valid, and test datasets."""

Expand Down Expand Up @@ -318,8 +335,7 @@ def build_sft_dataset(data_cfg, tokenizer, num_samples, answer_only_loss=True, i


def collate_with_pad_to_max_batch(max_seqlen, tokenizer_eos_id, cfg, generate_masks_and_position_ids=True):
"""collate function that pads each sequence to the max in the batch
"""
"""collate function that pads each sequence to the max in the batch"""
return partial(
collate_with_batch_max_sequence_length,
response_token_length=max_seqlen,
Expand Down
Loading

0 comments on commit 446f7f0

Please sign in to comment.