diff --git a/src/llm_training/lms/clm/clm.py b/src/llm_training/lms/clm/clm.py
index 7314292..36fd807 100644
--- a/src/llm_training/lms/clm/clm.py
+++ b/src/llm_training/lms/clm/clm.py
@@ -1,6 +1,6 @@
 import logging
 from contextlib import nullcontext
-from typing import Any
+from typing import TypedDict
 
 import torch
 import torch.nn.functional as F
@@ -11,8 +11,9 @@
 from llm_training.lightning.strategy import FSDP2Strategy
 from llm_training.lms.base_lm import BaseLightningModule
 from llm_training.lms.protos import CausalLMProto
-from llm_training.lms.utils import get_model
-from llm_training.metrics import ConsumedSamples, ConsumedTokens, Perplexity
+from llm_training.lms.utils import DataFetcherProxy, get_model
+from llm_training.metrics import (ConsumedSamples, ConsumedTokens, Loss,
+                                  Perplexity)
 from llm_training.models.base_model.base_model import BaseModel
 from llm_training.ops import shift_labels
 from llm_training.ops.liger_kernel import cross_entropy
@@ -97,6 +98,8 @@ def configure_model(self) -> None:
                 ignore_index=self.config.ignore_index,
                 process_group=process_group
             )
+
+        self.loss_metric = Loss(process_group=process_group)
 
         self.model = get_model(self.config.model)
 
@@ -110,20 +113,34 @@ def configure_model(self) -> None:
     def on_fsdp_parallelize_model(self, **kwargs) -> None:
         self.model.parallelize(**kwargs)
 
-    def compute_loss(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
+    def compute_loss(
+        self,
+        logits: torch.Tensor,
+        labels: torch.Tensor,
+        num_tokens_in_batch: int | None = None
+    ) -> torch.Tensor:
+        reduction = 'mean' if num_tokens_in_batch is None else 'sum'
+
         if isinstance(self.strategy, FSDP2Strategy) and self.strategy.tp_size > 1:
             with loss_parallel():
-                return F.cross_entropy(
+                loss = F.cross_entropy(
                     logits.flatten(end_dim=1),
                     labels.flatten(end_dim=1),
-                    ignore_index=self.config.ignore_index
+                    ignore_index=self.config.ignore_index,
+                    reduction=reduction
                 )
+        else:
+            loss = cross_entropy(
+                logits=logits,
+                labels=labels,
+                ignore_index=self.config.ignore_index,
+                reduction=reduction
+            )
 
-        return cross_entropy(
-            logits=logits,
-            labels=labels,
-            ignore_index=self.config.ignore_index
-        )
+        if num_tokens_in_batch is not None:
+            loss /= num_tokens_in_batch / self.trainer.accumulate_grad_batches / self.trainer.world_size
+
+        return loss
 
     def backward(self, loss: torch.Tensor, *args, **kwargs) -> None:
         backward_ctx = nullcontext()
@@ -133,7 +150,35 @@ def backward(self, loss: torch.Tensor, *args, **kwargs) -> None:
         with backward_ctx:
             super().backward(loss, *args, **kwargs)
 
-    def training_step(self, batch: dict[str, torch.Tensor | Any], batch_idx: int) -> torch.Tensor:
+    def get_next_batches(self, batch: "_Batch", n: int) -> list["_Batch"]:
+        batches = []
+        if n > 0:
+            batches.append(batch)
+            batches += [x[0] for x in self.data_fetcher.prefetch(n - 1)]
+        return batches
+
+    def get_num_tokens_in_batch(self, batch: "_Batch", batch_idx: int) -> int | None:
+        if batch_idx % self.trainer.accumulate_grad_batches == 0:
+            batches = self.get_next_batches(
+                batch,
+                min(self.trainer.accumulate_grad_batches, self.trainer.num_training_batches - batch_idx)
+            )
+            num_tokens_in_batch = sum([
+                x['labels'].ne(self.config.ignore_index).sum() for x in batches
+            ])
+            self._num_tokens_in_batch = self.trainer.strategy.reduce(
+                num_tokens_in_batch.to(self.device),
+                reduce_op='sum'
+            )
+        return self._num_tokens_in_batch
+
+    def on_train_epoch_start(self) -> None:
+        fit_loop = self.trainer.fit_loop
+        if not isinstance(fit_loop._data_fetcher, DataFetcherProxy):
+            fit_loop._data_fetcher = DataFetcherProxy(fit_loop._data_fetcher)
+        self.data_fetcher = fit_loop._data_fetcher
+
+    def training_step(self, batch: "_Batch", batch_idx: int) -> torch.Tensor:
         labels = shift_labels(batch['labels'], self.config.ignore_index)
 
         if self.config.neftune_alpha is not None:
@@ -142,7 +187,7 @@ def training_step(self, batch: dict[str, torch.Tensor | Any], batch_idx: int) ->
         outputs = self.model(
             input_ids=batch['input_ids'],
             attention_mask=batch['attention_mask'],
-            position_ids=batch.get('position_ids', None)
+            position_ids=batch['position_ids']
         )
         logits = outputs.logits.float()
 
@@ -150,14 +195,22 @@ def training_step(self, batch: dict[str, torch.Tensor | Any], batch_idx: int) ->
             self.log('NEFTune Alpha', self.config.neftune_alpha)
             self._current_attention_mask = None
 
-        loss = self.compute_loss(logits, labels)
-
-        self.log('loss', loss, prog_bar=True, logger=False)
-        self.log('Loss/Train/Step', loss)
-
-        if self.config.log_perplexity:
-            self.train_perplexity(loss)
-            self.log('Perplexity/Train/Step', self.train_perplexity)
+        loss = self.compute_loss(logits, labels, self.get_num_tokens_in_batch(batch, batch_idx))
+
+        self.loss_metric.update(loss)
+        epoch_loop = self.trainer.fit_loop.epoch_loop
+        if (
+            epoch_loop._accumulated_batches_reached()
+            or epoch_loop._num_ready_batches_reached()
+        ):
+            reduced_loss = self.loss_metric.compute()
+            self.loss_metric.reset()
+            self.log('loss', reduced_loss, prog_bar=True, logger=False)
+            self.log('Loss/Train/Step', reduced_loss)
+
+            if self.config.log_perplexity:
+                self.train_perplexity(reduced_loss)
+                self.log('Perplexity/Train/Step', self.train_perplexity)
 
         self.consumed_samples.update(labels)
         self.consumed_tokens.update(labels)
@@ -167,7 +220,7 @@ def training_step(self, batch: dict[str, torch.Tensor | Any], batch_idx: int) ->
         })
         return loss
 
-    def validation_step(self, batch: dict[str, torch.Tensor | Any], batch_idx: int, dataloader_idx: int = 0):
+    def validation_step(self, batch: "_Batch", batch_idx: int, dataloader_idx: int = 0):
         batch_size = batch['input_ids'].size(0)
         labels = shift_labels(batch['labels'], self.config.ignore_index)
         outputs = self.model(
@@ -186,3 +239,10 @@ def validation_step(self, batch: dict[str, torch.Tensor | Any], batch_idx: int,
 
     def get_model(self) -> BaseModel:
         return self.model
+
+
+class _Batch(TypedDict):
+    input_ids: torch.Tensor
+    attention_mask: torch.Tensor
+    position_ids: torch.Tensor
+    labels: torch.Tensor
diff --git a/src/llm_training/lms/utils.py b/src/llm_training/lms/utils.py
index ed7580b..3544c79 100644
--- a/src/llm_training/lms/utils.py
+++ b/src/llm_training/lms/utils.py
@@ -1,5 +1,7 @@
 from typing import Callable
 
+from lightning.pytorch.loops.fetchers import _DataFetcher
+
 from llm_training.lms.model_provider import ModelProvider
 from llm_training.models.base_model.base_model import BaseModel
 
@@ -10,3 +12,30 @@ def get_model(model_or_provider: ModelType) -> BaseModel:
     if isinstance(model_or_provider, BaseModel):
         return model_or_provider
     return model_or_provider()
+
+
+class DataFetcherProxy:
+    def __init__(self, data_fetcher: _DataFetcher) -> None:
+        self.data_fetcher = data_fetcher
+        self.prefetched_batches = []
+
+    def __iter__(self):
+        return self.data_fetcher.__iter__()
+
+    def __next__(self):
+        if self.prefetched_batches:
+            return self.prefetched_batches.pop(0)
+        return next(self.data_fetcher)
+
+    def prefetch(self, n: int):
+        while len(self.prefetched_batches) < n:
+            x = next(self.data_fetcher.iterator)
+            self.prefetched_batches.append(x)
+        return self.prefetched_batches[:n]
+
+    def reset(self):
+        self.prefetched_batches.clear()
+        return self.data_fetcher.reset(self)
+
+    def __getattr__(self, name):
+        return getattr(self.data_fetcher, name)
diff --git a/src/llm_training/metrics/__init__.py b/src/llm_training/metrics/__init__.py
index 2bbd2f7..e3f31b6 100644
--- a/src/llm_training/metrics/__init__.py
+++ b/src/llm_training/metrics/__init__.py
@@ -1,3 +1,4 @@
 from .consumed_samples import ConsumedSamples
 from .consumed_tokens import ConsumedTokens
+from .loss import Loss
 from .perplexity import Perplexity
diff --git a/src/llm_training/metrics/loss.py b/src/llm_training/metrics/loss.py
new file mode 100644
index 0000000..3f06cd6
--- /dev/null
+++ b/src/llm_training/metrics/loss.py
@@ -0,0 +1,25 @@
+import torch
+
+from .metric import Metric
+
+
+class Loss(Metric):
+    is_differentiable: bool = True
+    higher_is_better: bool = False
+    full_state_update: bool = False
+
+    loss: torch.Tensor
+    count: torch.Tensor
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+
+        self.add_state('loss', default=torch.tensor(0.0), dist_reduce_fx='sum')
+        self.add_state('count', default=torch.tensor(0.0), dist_reduce_fx='sum')
+
+    def update(self, loss: torch.Tensor):
+        self.loss += loss
+        self.count += 1
+
+    def compute(self):
+        return self.loss / self.count
diff --git a/src/llm_training/ops/liger_kernel/cross_entropy_op.py b/src/llm_training/ops/liger_kernel/cross_entropy_op.py
index 827399a..93eadb8 100644
--- a/src/llm_training/ops/liger_kernel/cross_entropy_op.py
+++ b/src/llm_training/ops/liger_kernel/cross_entropy_op.py
@@ -11,10 +11,11 @@ def cross_entropy(
     logits: torch.Tensor,
     labels: torch.Tensor,
     ignore_index: int = -100,
-    reduction: Literal['mean'] = 'mean'
+    lse_square_scale: float = 0.0,
+    label_smoothing: float = 0.0,
+    reduction: Literal["none", "mean", "sum"] = 'mean',
+    softcap: float | None = None
 ) -> torch.Tensor:
-    assert reduction == 'mean'
-
     if logits.dim() == 3 and labels.dim() == 2:
         logits = logits.flatten(end_dim=1)
         labels = labels.flatten(end_dim=1)
@@ -23,13 +24,19 @@
         return F.cross_entropy(
             logits,
             labels,
-            ignore_index=ignore_index
+            ignore_index=ignore_index,
+            reduction=reduction,
+            label_smoothing=label_smoothing
         )
 
     return LigerCrossEntropyFunction.apply(
         logits,
         labels,
-        ignore_index
+        ignore_index,
+        lse_square_scale,
+        label_smoothing,
+        reduction,
+        softcap
     )[0]
 
 
@@ -37,11 +44,13 @@ def fused_linear_cross_entropy(
     hidden_states: torch.Tensor,
     weight: torch.Tensor,
     labels: torch.Tensor,
+    bias: torch.Tensor | None = None,
     ignore_index: int = -100,
-    reduction: Literal['mean'] = 'mean'
+    lse_square_scale: float = 0.0,
+    label_smoothing: float = 0.0,
+    reduction: Literal["none", "mean", "sum"] = 'mean',
+    softcap: float | None = None
 ) -> torch.Tensor:
-    assert reduction == 'mean'
-
     if hidden_states.dim() == 3 and labels.dim() == 2:
         hidden_states = hidden_states.flatten(end_dim=1)
         labels = labels.flatten(end_dim=1)
@@ -50,5 +59,10 @@
         hidden_states,
         weight,
         labels,
-        ignore_index
+        bias,
+        ignore_index,
+        lse_square_scale,
+        label_smoothing,
+        reduction,
+        softcap
     )
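
Note on the loss scaling in compute_loss (illustrative sketch, not part of the patch): with reduction='sum', each micro-batch loss is divided by the number of non-ignored tokens in the whole accumulation window across all ranks, then multiplied back by accumulate_grad_batches and world_size so that the averaging applied by gradient accumulation and data-parallel reduction cancels out and the optimizer sees a true per-token mean over the global batch. The standalone example below uses invented names (normalized_micro_batch_loss, token_losses) and a single process; it only demonstrates the arithmetic.

import torch


def normalized_micro_batch_loss(
    summed_loss: torch.Tensor,   # sum of per-token CE over this micro-batch
    num_tokens_in_window: int,   # non-ignored tokens in the full accumulation window, all ranks
    accumulate_grad_batches: int,
    world_size: int,
) -> torch.Tensor:
    # The trainer divides each micro-batch loss by accumulate_grad_batches, and
    # data-parallel reduction averages gradients over world_size ranks.  Dividing
    # by (num_tokens / accumulate / world_size) here means that after both of
    # those averages the accumulated loss equals sum(CE) / num_tokens, i.e. a
    # per-token mean over the whole global batch.
    return summed_loss / (num_tokens_in_window / accumulate_grad_batches / world_size)


if __name__ == '__main__':
    torch.manual_seed(0)
    # Two micro-batches with very different numbers of valid tokens.
    token_losses = [torch.rand(100), torch.rand(10)]
    num_tokens = sum(t.numel() for t in token_losses)
    accumulate, world = 2, 1

    # What the trainer effectively accumulates: each returned loss divided by
    # accumulate_grad_batches, summed over the micro-batches.
    accumulated = sum(
        normalized_micro_batch_loss(t.sum(), num_tokens, accumulate, world) / accumulate
        for t in token_losses
    )
    per_token_mean = torch.cat(token_losses).mean()
    print(torch.allclose(accumulated, per_token_mean))  # True: matches the global per-token mean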
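
Similarly, a minimal sketch of the idea behind DataFetcherProxy.prefetch: wrap an iterator so the module can look at the upcoming micro-batches (to count their tokens up front) without changing the order in which the training loop consumes them. PeekableIterator and peek are hypothetical names for this illustration, not Lightning or repo APIs.

from collections.abc import Iterator
from typing import Any


class PeekableIterator:
    """Wrap an iterator so callers can look ahead without consuming items."""

    def __init__(self, it: Iterator[Any]) -> None:
        self._it = it
        self._buffer: list[Any] = []

    def __iter__(self) -> "PeekableIterator":
        return self

    def __next__(self) -> Any:
        # Serve previously peeked items first, preserving stream order.
        if self._buffer:
            return self._buffer.pop(0)
        return next(self._it)

    def peek(self, n: int) -> list[Any]:
        # Pull items from the underlying iterator until n are buffered.
        while len(self._buffer) < n:
            self._buffer.append(next(self._it))
        return self._buffer[:n]


if __name__ == '__main__':
    it = PeekableIterator(iter(range(5)))
    print(it.peek(3))   # [0, 1, 2] -- inspected, not consumed
    print(list(it))     # [0, 1, 2, 3, 4] -- the stream is unchanged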