diff --git a/docs/zh_cn/acceleration/train_extreme_long_sequence.rst b/docs/zh_cn/acceleration/train_extreme_long_sequence.rst
index 65b364ad8..89efcaa3f 100644
--- a/docs/zh_cn/acceleration/train_extreme_long_sequence.rst
+++ b/docs/zh_cn/acceleration/train_extreme_long_sequence.rst
@@ -219,7 +219,7 @@ The sequence parallel design in XTuner references DeepSpeed's `DeepSpeed Ul
 - Data sampler adapted for sequence parallelism (SequenceParallelSampler)
 - Data padding and splitting (pad_for_sequence_parallel, split_for_sequence_parallel)
 - Attention adapted for sequence parallelism (dispatch_modules)
-- reduce the loss to correctly log the training loss (reduce_sequence_parallel_loss)
+- rescale the loss so that the backward gradients under sequence parallelism match data parallelism (DP) (rescale_sp_loss)
 
 Distributed Environment Initialization
 ----------------------------------------
@@ -303,20 +303,16 @@ XTuner provides the dispatch_modules interface to modify the model's attention computation
 .. tip::
     The above process is implemented in ``xtuner/model/sft.py``.
 
-Reduce Loss
+Rescale Loss
 -------------
 
-This API is not required for training correctness, but logging the training loss is very useful for monitoring the training state.
+Because the number of tokens that contribute to the loss differs across sp ranks, naively averaging the gradients of different ranks during data-parallel (DP) gradient synchronization is incorrect under sequence parallelism (SP). XTuner provides the `rescale_sp_loss` API to keep the parameter gradients in the sequence-parallel setting consistent with the data-parallel setting.
 
 .. code-block:: python
 
-    from xtuner.parallel.sequence import reduce_sequence_parallel_loss
+    from xtuner.parallel.sequence import rescale_sp_loss, get_sequence_parallel_group
     outputs = llm(input_ids=input_ids, labels=labels, **kwargs)
-    num_tokens_per_rank = (labels != -100).sum()
-    # Suppose sequence parallel world size equals to 4,
-    # losses on rank0, rank1, rank2, rank3 are different.
-    loss = reduce_sequence_parallel_loss(outputs.loss, num_tokens_per_rank)
-    # After loss reduction, losses on rank0, rank1, rank2, rank3 are the same.
+    sp_group = get_sequence_parallel_group()
+    rescaled_loss = rescale_sp_loss(outputs.loss, labels, sp_group)
 
 .. tip::
     The above process is implemented in ``xtuner/model/sft.py``.
diff --git a/docs/zh_cn/user_guides/sequence_parallel.md b/docs/zh_cn/user_guides/sequence_parallel.md
index ce4beed64..f3c9c72c1 100644
--- a/docs/zh_cn/user_guides/sequence_parallel.md
+++ b/docs/zh_cn/user_guides/sequence_parallel.md
@@ -95,7 +95,7 @@ model = dict(
 - Data padding (pad_for_sequence_parallel)
 - Data splitting (split_for_sequence_parallel)
 - Attention adapted for sequence parallelism (dispatch_modules)
-- reduce the loss to correctly log the training loss (reduce_sequence_parallel_loss)
+- rescale the loss so that the backward gradients under sequence parallelism match data parallelism (DP) (rescale_sp_loss)
 
 ### Sequence Parallel Distributed Environment Initialization
 
@@ -176,16 +176,12 @@ dispatch_modules(model)
 
 ### Reduce Loss to Correctly Print the Training Loss
 
-This API is not required for training correctness, but logging the training loss is very useful for monitoring the training state.
+Because the number of tokens that contribute to the loss differs across sp ranks, naively averaging the gradients of different ranks during data-parallel (DP) gradient synchronization is incorrect under sequence parallelism (SP). XTuner provides the `rescale_sp_loss` API to keep the parameter gradients in the sequence-parallel setting consistent with the data-parallel setting.
 
 ```python
-from xtuner.parallel.sequence import reduce_sequence_parallel_loss
+from xtuner.parallel.sequence import rescale_sp_loss, get_sequence_parallel_group
 outputs = llm(input_ids=input_ids, labels=labels, **kwargs)
-num_tokens_per_rank = (labels != -100).sum()
-# Suppose sequence parallel world size equals to 4,
-# losses on rank0, rank1, rank2, rank3 are different.
-loss = reduce_sequence_parallel_loss(outputs.loss, num_tokens_per_rank)
-# After loss reduction, losses on rank0, rank1, rank2, rank3 are the same.
+sp_group = get_sequence_parallel_group()
+rescaled_loss = rescale_sp_loss(outputs.loss, labels, sp_group)
 ```
 
 The above process is implemented in xtuner/model/sft.py.
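
The rescaling rule described in both documents is easiest to see with concrete numbers. Below is a minimal single-process sketch (not part of the patch): four sequence-parallel ranks are simulated with made-up mean losses and active-token counts, and each loss is weighted with the same `active_tokens / global_active_tokens * sp_world_size` factor that `rescale_sp_loss` applies.

```python
import torch

# Made-up per-rank values: mean loss and number of active (label != -100)
# tokens on each of four simulated sequence-parallel ranks.
losses = torch.tensor([2.0, 1.0, 4.0, 3.0])
tokens = torch.tensor([10., 50., 20., 20.])
sp_world_size = len(losses)

# Reference: the token-weighted loss the model would compute without SP.
global_loss = (losses * tokens).sum() / tokens.sum()

# Plain averaging (what DP gradient sync effectively does) is wrong whenever
# the token counts differ across sp ranks.
naive_avg = losses.mean()

# rescale_sp_loss weighting: active_tokens / global_active_tokens * world_size.
weights = tokens / tokens.sum() * sp_world_size
rescaled_avg = (losses * weights).mean()

print(naive_avg.item(), rescaled_avg.item(), global_loss.item())
# 2.5  2.1  2.1  -> averaging the rescaled losses recovers the global loss
```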
diff --git a/xtuner/configs/internlm/internlm2_5_chat_20b/internlm2_5_chat_20b_alpaca_e3.py b/xtuner/configs/internlm/internlm2_5_chat_20b/internlm2_5_chat_20b_alpaca_e3.py
index f67fc1a22..c8dc9e8f3 100644
--- a/xtuner/configs/internlm/internlm2_5_chat_20b/internlm2_5_chat_20b_alpaca_e3.py
+++ b/xtuner/configs/internlm/internlm2_5_chat_20b/internlm2_5_chat_20b_alpaca_e3.py
@@ -1,14 +1,11 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-import torch
 from datasets import load_dataset
 from mmengine.dataset import DefaultSampler
 from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
                             LoggerHook, ParamSchedulerHook)
 from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
-from peft import LoraConfig
 from torch.optim import AdamW
-from transformers import (AutoModelForCausalLM, AutoTokenizer,
-                          BitsAndBytesConfig)
+from transformers import AutoModelForCausalLM, AutoTokenizer
 
 from xtuner.dataset import process_hf_dataset
 from xtuner.dataset.collate_fns import default_collate_fn
diff --git a/xtuner/model/sft.py b/xtuner/model/sft.py
index 522950489..eb4128e22 100644
--- a/xtuner/model/sft.py
+++ b/xtuner/model/sft.py
@@ -15,7 +15,8 @@
 from xtuner.parallel.sequence import (get_sequence_parallel_group,
                                       get_sequence_parallel_world_size,
-                                      reduce_sequence_parallel_loss,
+                                      reduce_sp_loss_for_debug,
+                                      rescale_sp_loss,
                                       split_for_sequence_parallel)
 from xtuner.registry import BUILDER
 from .modules import dispatch_modules
@@ -79,7 +80,6 @@ def __init__(self,
                  tokenizer=None,
                  max_position_embeddings=None):
         super().__init__()
-
         self.llm = self.build_llm_from_cfg(llm, use_varlen_attn,
                                            max_position_embeddings)
 
@@ -88,7 +88,6 @@ def __init__(self,
                 tokenizer = BUILDER.build(tokenizer)
             smart_tokenizer_and_embedding_resize(tokenizer, self.llm)
 
-        self.llm.config.use_cache = False
         if use_activation_checkpointing:
             # For backward compatibility
             if hasattr(self.llm, 'enable_input_require_grads'):
@@ -116,6 +115,8 @@ def __init__(self,
         # the sequence.
         self.use_varlen_attn = use_varlen_attn
 
+        self.debug_sp = False
+
     def build_llm_from_cfg(self, llm_cfg, use_varlen_attn,
                            max_position_embeddings):
         # For forward
@@ -288,11 +289,17 @@ def _compute_sequence_parallel_loss(self, data):
         data = self._split_for_sequence_parallel(data)
         outputs = self.llm(**data)
         labels = data['labels']
-        num_tokens = (labels != -100).sum()
+
         sp_group = get_sequence_parallel_group()
-        loss = reduce_sequence_parallel_loss(outputs.loss, num_tokens,
-                                             sp_group)
-        return {'loss': loss}
+        loss = rescale_sp_loss(outputs.loss, labels, sp_group)
+        output = {'loss': loss}
+        if self.debug_sp:
+            reduced_loss = reduce_sp_loss_for_debug(outputs.loss, labels,
+                                                    sp_group)
+            # the string `loss` must not be part of any key in the output dict
+            # https://github.com/open-mmlab/mmengine/blob/main/mmengine/model/base_model/base_model.py#L174 # noqa: E501
+            output['reduced_l'] = reduced_loss
+        return output
 
     def compute_loss(self, data, data_samples=None):
         if get_sequence_parallel_world_size() > 1:
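
For reference, a usage sketch of the new debug switch. Only `debug_sp` and the `'loss'` / `'reduced_l'` output keys come from the diff above; the config path is a placeholder and the build call follows the usual mmengine/XTuner pattern.

```python
from mmengine.config import Config

from xtuner.registry import BUILDER

# Placeholder config path; any sequence-parallel SFT config would do.
cfg = Config.fromfile('path/to/your_sp_config.py')
model = BUILDER.build(cfg.model)

# The rescaled loss is always returned under 'loss' and drives backward.
# Setting debug_sp additionally returns the token-weighted loss averaged over
# the sp group under 'reduced_l' (the key deliberately avoids the substring
# `loss`, see the comment in _compute_sequence_parallel_loss).
model.debug_sp = True
```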
diff --git a/xtuner/parallel/sequence/__init__.py b/xtuner/parallel/sequence/__init__.py
index 6e2992f78..01aaafd6d 100644
--- a/xtuner/parallel/sequence/__init__.py
+++ b/xtuner/parallel/sequence/__init__.py
@@ -9,7 +9,7 @@
                    split_forward_gather_backward)
 from .data_collate import (pad_cumulative_len_for_sequence_parallel,
                            pad_for_sequence_parallel)
-from .reduce_loss import reduce_sequence_parallel_loss
+from .reduce_loss import reduce_sp_loss_for_debug, rescale_sp_loss
 from .sampler import SequenceParallelSampler
 from .setup_distributed import (get_data_parallel_group,
                                 get_data_parallel_rank,
@@ -31,11 +31,12 @@
     'init_sequence_parallel', 'get_sequence_parallel_group',
     'get_sequence_parallel_world_size', 'get_sequence_parallel_rank',
     'get_data_parallel_group', 'get_data_parallel_world_size',
-    'get_data_parallel_rank', 'reduce_sequence_parallel_loss', 'init_dist',
+    'get_data_parallel_rank', 'init_dist',
     'all_to_all', 'gather_for_sequence_parallel',
     'split_forward_gather_backward', 'gather_forward_split_backward',
     'get_inner_sequence_parallel_group', 'get_inner_sequence_parallel_rank',
     'get_inner_sequence_parallel_world_size', 'init_inner_sequence_parallel',
     'is_inner_sequence_parallel_initialized',
-    'pad_cumulative_len_for_sequence_parallel'
+    'pad_cumulative_len_for_sequence_parallel', 'rescale_sp_loss',
+    'reduce_sp_loss_for_debug'
 ]
diff --git a/xtuner/parallel/sequence/reduce_loss.py b/xtuner/parallel/sequence/reduce_loss.py
index fb37242a3..7ece7fd67 100644
--- a/xtuner/parallel/sequence/reduce_loss.py
+++ b/xtuner/parallel/sequence/reduce_loss.py
@@ -1,34 +1,55 @@
+import copy
+
 import torch
 import torch.distributed as dist
 
 from .setup_distributed import get_sequence_parallel_group
 
 
-class _ReduceLoss(torch.autograd.Function):
+def rescale_sp_loss(loss_per_sp_rank,
+                    labels_per_sp_rank,
+                    sp_group: dist.ProcessGroup = None,
+                    ignore_index=-100):
+    if sp_group is None:
+        sp_group = get_sequence_parallel_group()
+
+    if (sp_group is None) or (dist.get_world_size(sp_group) == 1):
+        return loss_per_sp_rank
+
+    shift_labels = labels_per_sp_rank[..., 1:].view(-1)
+    active_tokens = (shift_labels != ignore_index).long().sum()
+    global_active_tokens = copy.deepcopy(active_tokens)
+    dist.all_reduce(global_active_tokens, group=sp_group)
+    loss_weight = active_tokens / global_active_tokens * dist.get_world_size(
+        group=sp_group)
 
-    @staticmethod
-    def forward(ctx, mean_loss, loss_scale, process_group):
-        ctx.mode = process_group
-        if loss_scale == 0:
-            # convert nan to 0 just for logging
-            mean_loss = torch.nan_to_num(mean_loss)
-        loss_sum = mean_loss * loss_scale
-        dist.all_reduce(loss_sum, group=process_group)
-        dist.all_reduce(loss_scale, group=process_group)
-        loss = loss_sum / loss_scale
-        return loss
+    if active_tokens == 0:
+        # convert nan to 0 just for logging
+        loss_per_sp_rank = torch.nan_to_num(loss_per_sp_rank)
 
-    @staticmethod
-    def backward(ctx, grad_output):
-        return grad_output, None, None
+    return loss_per_sp_rank * loss_weight
 
 
-def reduce_sequence_parallel_loss(mean_loss,
-                                  loss_scale,
-                                  sp_group: dist.ProcessGroup = None):
-    if dist.get_world_size(sp_group) == 1:
-        return mean_loss
+def reduce_sp_loss_for_debug(loss_per_sp_rank,
+                             labels_per_sp_rank,
+                             sp_group: dist.ProcessGroup = None,
+                             ignore_index=-100):
+    # Reduce the loss across the sp group to check whether the training loss
+    # differs between ranks when using sp. This function is only for debugging.
     if sp_group is None:
-        # avoid bc breaking
         sp_group = get_sequence_parallel_group()
-    return _ReduceLoss.apply(mean_loss, loss_scale, sp_group)
+
+    if (sp_group is None) or (dist.get_world_size(sp_group) == 1):
+        return loss_per_sp_rank
+
+    shift_labels = labels_per_sp_rank[..., 1:].view(-1)
+    active_tokens = (shift_labels != ignore_index).long().sum()
+    if active_tokens == 0:
+        # convert nan to 0 just for logging
+        loss_per_sp_rank = torch.nan_to_num(loss_per_sp_rank)
+
+    loss_sum = loss_per_sp_rank * active_tokens
+    global_active_tokens = copy.deepcopy(active_tokens)
+    dist.all_reduce(loss_sum, group=sp_group)
+    dist.all_reduce(global_active_tokens, group=sp_group)
+    return loss_sum / global_active_tokens
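
The gradient claim behind `rescale_sp_loss` can be checked without any process group. The sketch below (not part of the patch) simulates two sp ranks with toy parameters: averaging the gradients of the rescaled per-rank losses, which is what DP synchronization effectively does, matches the gradient of the token-weighted global loss.

```python
import torch

w = torch.tensor([1.0, 2.0], requires_grad=True)  # toy "parameters"

# Per-rank mean losses (functions of w) and made-up active-token counts.
loss_rank0, tokens0 = (w * 3.0).sum(), 10
loss_rank1, tokens1 = (w * w).sum(), 30
total_tokens, sp_world_size = tokens0 + tokens1, 2

# Reference gradient: token-weighted loss over the full (un-split) sequence.
ref = (loss_rank0 * tokens0 + loss_rank1 * tokens1) / total_tokens
ref_grad = torch.autograd.grad(ref, w, retain_graph=True)[0]

# Rescale each rank's loss as rescale_sp_loss does, then average the gradients.
scaled0 = loss_rank0 * (tokens0 / total_tokens * sp_world_size)
scaled1 = loss_rank1 * (tokens1 / total_tokens * sp_world_size)
g0 = torch.autograd.grad(scaled0, w, retain_graph=True)[0]
g1 = torch.autograd.grad(scaled1, w)[0]
avg_grad = (g0 + g1) / sp_world_size

assert torch.allclose(avg_grad, ref_grad)
```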