From 441d602de8d453753b38928989b9c6dcfdb4812a Mon Sep 17 00:00:00 2001 From: ShinoharaHare Date: Sun, 5 Jan 2025 04:18:28 +0800 Subject: [PATCH 1/2] =?UTF-8?q?=E4=BF=AE=E6=AD=A3=20CLM=20Loss=20=E5=9C=A8?= =?UTF-8?q?=20GA=20=E5=8F=8A=20DP=20=E6=99=82=E8=A8=88=E7=AE=97=E9=8C=AF?= =?UTF-8?q?=E8=AA=A4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/llm_training/lms/clm/clm.py | 95 +++++++++++++++---- src/llm_training/lms/utils.py | 29 ++++++ src/llm_training/metrics/__init__.py | 1 + src/llm_training/metrics/loss.py | 25 +++++ .../ops/liger_kernel/cross_entropy_op.py | 32 +++++-- 5 files changed, 154 insertions(+), 28 deletions(-) create mode 100644 src/llm_training/metrics/loss.py diff --git a/src/llm_training/lms/clm/clm.py b/src/llm_training/lms/clm/clm.py index 7314292..20ccf49 100644 --- a/src/llm_training/lms/clm/clm.py +++ b/src/llm_training/lms/clm/clm.py @@ -1,6 +1,6 @@ import logging from contextlib import nullcontext -from typing import Any +from typing import TypedDict import torch import torch.nn.functional as F @@ -11,8 +11,9 @@ from llm_training.lightning.strategy import FSDP2Strategy from llm_training.lms.base_lm import BaseLightningModule from llm_training.lms.protos import CausalLMProto -from llm_training.lms.utils import get_model -from llm_training.metrics import ConsumedSamples, ConsumedTokens, Perplexity +from llm_training.lms.utils import DataFetcherProxy, get_model +from llm_training.metrics import (ConsumedSamples, ConsumedTokens, Loss, + Perplexity) from llm_training.models.base_model.base_model import BaseModel from llm_training.ops import shift_labels from llm_training.ops.liger_kernel import cross_entropy @@ -97,6 +98,8 @@ def configure_model(self) -> None: ignore_index=self.config.ignore_index, process_group=process_group ) + + self.loss_metric = Loss(process_group=process_group) self.model = get_model(self.config.model) @@ -110,20 +113,34 @@ def configure_model(self) -> None: def on_fsdp_parallelize_model(self, **kwargs) -> None: self.model.parallelize(**kwargs) - def compute_loss(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: + def compute_loss( + self, + logits: torch.Tensor, + labels: torch.Tensor, + num_tokens_in_batch: int | None = None + ) -> torch.Tensor: + reduction = 'mean' if num_tokens_in_batch is None else 'sum' + if isinstance(self.strategy, FSDP2Strategy) and self.strategy.tp_size > 1: with loss_parallel(): - return F.cross_entropy( + loss = F.cross_entropy( logits.flatten(end_dim=1), labels.flatten(end_dim=1), - ignore_index=self.config.ignore_index + ignore_index=self.config.ignore_index, + reduction=reduction ) + else: + loss = cross_entropy( + logits=logits, + labels=labels, + ignore_index=self.config.ignore_index, + reduction=reduction + ) - return cross_entropy( - logits=logits, - labels=labels, - ignore_index=self.config.ignore_index - ) + if num_tokens_in_batch is not None: + loss /= num_tokens_in_batch / self.trainer.accumulate_grad_batches / self.trainer.world_size + + return loss def backward(self, loss: torch.Tensor, *args, **kwargs) -> None: backward_ctx = nullcontext() @@ -133,7 +150,36 @@ def backward(self, loss: torch.Tensor, *args, **kwargs) -> None: with backward_ctx: super().backward(loss, *args, **kwargs) - def training_step(self, batch: dict[str, torch.Tensor | Any], batch_idx: int) -> torch.Tensor: + def get_next_batches(self, batch: "_Batch", n: int) -> list["_Batch"]: + batches = [] + if n > 0: + batches.append(batch) + batches += [x[0] for x in 
self.data_fetcher.prefetch(n - 1)] + return batches + + def get_num_tokens_in_batch(self, batch: "_Batch", batch_idx: int) -> int | None: + if batch_idx % self.trainer.accumulate_grad_batches == 0: + batches = self.get_next_batches( + batch, + min( + self.trainer.accumulate_grad_batches, self.trainer.num_training_batches - batch_idx) + ) + num_tokens_in_batch = sum([ + x['labels'].ne(self.config.ignore_index).sum() for x in batches + ]) + self._num_tokens_in_batch = self.trainer.strategy.reduce( + num_tokens_in_batch.to(self.device), + reduce_op='sum' + ) + return self._num_tokens_in_batch + + def on_train_epoch_start(self) -> None: + fit_loop = self.trainer.fit_loop + if not isinstance(fit_loop._data_fetcher, DataFetcherProxy): + fit_loop._data_fetcher = DataFetcherProxy(fit_loop._data_fetcher) + self.data_fetcher = fit_loop._data_fetcher + + def training_step(self, batch: "_Batch", batch_idx: int) -> torch.Tensor: labels = shift_labels(batch['labels'], self.config.ignore_index) if self.config.neftune_alpha is not None: @@ -150,14 +196,18 @@ def training_step(self, batch: dict[str, torch.Tensor | Any], batch_idx: int) -> self.log('NEFTune Alpha', self.config.neftune_alpha) self._current_attention_mask = None - loss = self.compute_loss(logits, labels) + loss = self.compute_loss(logits, labels, self.get_num_tokens_in_batch(batch, batch_idx)) - self.log('loss', loss, prog_bar=True, logger=False) - self.log('Loss/Train/Step', loss) + self.loss_metric.update(loss) + if not self.trainer.fit_loop._should_accumulate(): + reduced_loss = self.loss_metric.compute() + self.loss_metric.reset() + self.log('loss', reduced_loss, prog_bar=True, logger=False) + self.log('Loss/Train/Step', reduced_loss) - if self.config.log_perplexity: - self.train_perplexity(loss) - self.log('Perplexity/Train/Step', self.train_perplexity) + if self.config.log_perplexity: + self.train_perplexity(reduced_loss) + self.log('Perplexity/Train/Step', self.train_perplexity) self.consumed_samples.update(labels) self.consumed_tokens.update(labels) @@ -167,7 +217,7 @@ def training_step(self, batch: dict[str, torch.Tensor | Any], batch_idx: int) -> }) return loss - def validation_step(self, batch: dict[str, torch.Tensor | Any], batch_idx: int, dataloader_idx: int = 0): + def validation_step(self, batch: "_Batch", batch_idx: int, dataloader_idx: int = 0): batch_size = batch['input_ids'].size(0) labels = shift_labels(batch['labels'], self.config.ignore_index) outputs = self.model( @@ -186,3 +236,10 @@ def validation_step(self, batch: dict[str, torch.Tensor | Any], batch_idx: int, def get_model(self) -> BaseModel: return self.model + + +class _Batch(TypedDict): + input_ids: torch.Tensor + attention_mask: torch.Tensor + position_ids: torch.Tensor + labels: torch.Tensor diff --git a/src/llm_training/lms/utils.py b/src/llm_training/lms/utils.py index ed7580b..3544c79 100644 --- a/src/llm_training/lms/utils.py +++ b/src/llm_training/lms/utils.py @@ -1,5 +1,7 @@ from typing import Callable +from lightning.pytorch.loops.fetchers import _DataFetcher + from llm_training.lms.model_provider import ModelProvider from llm_training.models.base_model.base_model import BaseModel @@ -10,3 +12,30 @@ def get_model(model_or_provider: ModelType) -> BaseModel: if isinstance(model_or_provider, BaseModel): return model_or_provider return model_or_provider() + + +class DataFetcherProxy: + def __init__(self, data_fetcher: _DataFetcher) -> None: + self.data_fetcher = data_fetcher + self.prefetched_batches = [] + + def __iter__(self): + return 
self.data_fetcher.__iter__() + + def __next__(self): + if self.prefetched_batches: + return self.prefetched_batches.pop(0) + return next(self.data_fetcher) + + def prefetch(self, n: int): + while len(self.prefetched_batches) < n: + x = next(self.data_fetcher.iterator) + self.prefetched_batches.append(x) + return self.prefetched_batches[:n] + + def reset(self): + self.prefetched_batches.clear() + return self.data_fetcher.reset(self) + + def __getattr__(self, name): + return getattr(self.data_fetcher, name) diff --git a/src/llm_training/metrics/__init__.py b/src/llm_training/metrics/__init__.py index 2bbd2f7..e3f31b6 100644 --- a/src/llm_training/metrics/__init__.py +++ b/src/llm_training/metrics/__init__.py @@ -1,3 +1,4 @@ from .consumed_samples import ConsumedSamples from .consumed_tokens import ConsumedTokens +from .loss import Loss from .perplexity import Perplexity diff --git a/src/llm_training/metrics/loss.py b/src/llm_training/metrics/loss.py new file mode 100644 index 0000000..3f06cd6 --- /dev/null +++ b/src/llm_training/metrics/loss.py @@ -0,0 +1,25 @@ +import torch + +from .metric import Metric + + +class Loss(Metric): + is_differentiable: bool = True + higher_is_better: bool = False + full_state_update: bool = False + + loss: torch.Tensor + count: torch.Tensor + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + self.add_state('loss', default=torch.tensor(0.0), dist_reduce_fx='sum') + self.add_state('count', default=torch.tensor(0.0), dist_reduce_fx='sum') + + def update(self, loss: torch.Tensor): + self.loss += loss + self.count += 1 + + def compute(self): + return self.loss / self.count diff --git a/src/llm_training/ops/liger_kernel/cross_entropy_op.py b/src/llm_training/ops/liger_kernel/cross_entropy_op.py index 827399a..93eadb8 100644 --- a/src/llm_training/ops/liger_kernel/cross_entropy_op.py +++ b/src/llm_training/ops/liger_kernel/cross_entropy_op.py @@ -11,10 +11,11 @@ def cross_entropy( logits: torch.Tensor, labels: torch.Tensor, ignore_index: int = -100, - reduction: Literal['mean'] = 'mean' + lse_square_scale: float = 0.0, + label_smoothing: float = 0.0, + reduction: Literal["none", "mean", "sum"] = 'mean', + softcap: float | None = None ) -> torch.Tensor: - assert reduction == 'mean' - if logits.dim() == 3 and labels.dim() == 2: logits = logits.flatten(end_dim=1) labels = labels.flatten(end_dim=1) @@ -23,13 +24,19 @@ def cross_entropy( return F.cross_entropy( logits, labels, - ignore_index=ignore_index + ignore_index=ignore_index, + reduction=reduction, + label_smoothing=label_smoothing ) return LigerCrossEntropyFunction.apply( logits, labels, - ignore_index + ignore_index, + lse_square_scale, + label_smoothing, + reduction, + softcap )[0] @@ -37,11 +44,13 @@ def fused_linear_cross_entropy( hidden_states: torch.Tensor, weight: torch.Tensor, labels: torch.Tensor, + bias: torch.Tensor | None = None, ignore_index: int = -100, - reduction: Literal['mean'] = 'mean' + lse_square_scale: float = 0.0, + label_smoothing: float = 0.0, + reduction: Literal["none", "mean", "sum"] = 'mean', + softcap: float | None = None ) -> torch.Tensor: - assert reduction == 'mean' - if hidden_states.dim() == 3 and labels.dim() == 2: hidden_states = hidden_states.flatten(end_dim=1) labels = labels.flatten(end_dim=1) @@ -50,5 +59,10 @@ def fused_linear_cross_entropy( hidden_states, weight, labels, - ignore_index + bias, + ignore_index, + lse_square_scale, + label_smoothing, + reduction, + softcap ) From 72748a67f1490ec281db212655a5b2581aeab074 Mon Sep 17 00:00:00 2001 From: 
ShinoharaHare Date: Sun, 5 Jan 2025 23:02:50 +0800 Subject: [PATCH 2/2] =?UTF-8?q?=E8=AA=BF=E6=95=B4=E6=90=8D=E5=A4=B1?= =?UTF-8?q?=E8=A8=88=E7=AE=97=E7=9A=84=E7=B4=AF=E7=A9=8D=E6=A2=9D=E4=BB=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/llm_training/lms/clm/clm.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/llm_training/lms/clm/clm.py b/src/llm_training/lms/clm/clm.py index 20ccf49..36fd807 100644 --- a/src/llm_training/lms/clm/clm.py +++ b/src/llm_training/lms/clm/clm.py @@ -161,8 +161,7 @@ def get_num_tokens_in_batch(self, batch: "_Batch", batch_idx: int) -> int | None if batch_idx % self.trainer.accumulate_grad_batches == 0: batches = self.get_next_batches( batch, - min( - self.trainer.accumulate_grad_batches, self.trainer.num_training_batches - batch_idx) + min(self.trainer.accumulate_grad_batches, self.trainer.num_training_batches - batch_idx) ) num_tokens_in_batch = sum([ x['labels'].ne(self.config.ignore_index).sum() for x in batches @@ -188,7 +187,7 @@ def training_step(self, batch: "_Batch", batch_idx: int) -> torch.Tensor: outputs = self.model( input_ids=batch['input_ids'], attention_mask=batch['attention_mask'], - position_ids=batch.get('position_ids', None) + position_ids=batch['position_ids'] ) logits = outputs.logits.float() @@ -199,7 +198,11 @@ def training_step(self, batch: "_Batch", batch_idx: int) -> torch.Tensor: loss = self.compute_loss(logits, labels, self.get_num_tokens_in_batch(batch, batch_idx)) self.loss_metric.update(loss) - if not self.trainer.fit_loop._should_accumulate(): + epoch_loop = self.trainer.fit_loop.epoch_loop + if ( + epoch_loop._accumulated_batches_reached() + or epoch_loop._num_ready_batches_reached() + ): reduced_loss = self.loss_metric.compute() self.loss_metric.reset() self.log('loss', reduced_loss, prog_bar=True, logger=False)
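
Note on the fix above, with an illustrative sketch that is not part of either patch: the core change replaces the per-micro-batch mean cross entropy with a sum-reduced loss normalized by the total number of non-ignored label tokens across the whole gradient-accumulation (GA) window and all data-parallel (DP) ranks, with that token count gathered by prefetching the remaining micro-batches through DataFetcherProxy and all-reducing it via strategy.reduce. The division by num_tokens_in_batch / accumulate_grad_batches / world_size in compute_loss compensates for Lightning later scaling the accumulated loss by 1/accumulate_grad_batches and for DDP averaging gradients across ranks, so the effective objective is the sum of per-token losses divided by the total number of valid tokens. The single-process sketch below shows why the naive mean-of-means differs from this token-weighted loss; the helper names (naive_ga_loss, token_weighted_ga_loss, micro_batches) are hypothetical and do not exist in the repository.

import torch
import torch.nn.functional as F

IGNORE_INDEX = -100


def naive_ga_loss(micro_batches):
    # Old behaviour: mean-reduced loss per micro-batch, averaged over the
    # accumulation window. Micro-batches with few valid tokens get over-weighted.
    losses = [
        F.cross_entropy(
            logits.flatten(end_dim=1),
            labels.flatten(end_dim=1),
            ignore_index=IGNORE_INDEX
        )
        for logits, labels in micro_batches
    ]
    return torch.stack(losses).mean()


def token_weighted_ga_loss(micro_batches):
    # Patched behaviour, single-process view: sum-reduced loss divided by the
    # total number of non-ignored tokens in the whole accumulation window.
    # This equals the mean loss of one large batch holding all micro-batches.
    total_loss = sum(
        F.cross_entropy(
            logits.flatten(end_dim=1),
            labels.flatten(end_dim=1),
            ignore_index=IGNORE_INDEX,
            reduction='sum'
        )
        for logits, labels in micro_batches
    )
    total_tokens = sum(labels.ne(IGNORE_INDEX).sum() for _, labels in micro_batches)
    return total_loss / total_tokens


if __name__ == '__main__':
    torch.manual_seed(0)
    vocab_size = 8
    full_labels = torch.randint(0, vocab_size, (1, 10))
    short_labels = torch.full((1, 10), IGNORE_INDEX)
    short_labels[0, :2] = torch.randint(0, vocab_size, (2,))  # only 2 valid tokens
    micro_batches = [
        (torch.randn(1, 10, vocab_size), full_labels),
        (torch.randn(1, 10, vocab_size), short_labels),
    ]
    # The two results differ whenever micro-batches hold different numbers of
    # valid (non-ignored) label tokens; only the second matches full-batch training.
    print(naive_ga_loss(micro_batches), token_weighted_ga_loss(micro_batches))

In the patches themselves the same normalization is applied incrementally, one micro-batch at a time, rather than over a materialized window: each micro-batch contributes its sum-reduced loss scaled by accumulate_grad_batches * world_size / num_tokens_in_batch, and the Loss metric then reduces those values only once the accumulation boundary is reached.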