Fix incorrect loss computation with gradient accumulation and distributed training #16

Open · wants to merge 5 commits into base: main
104 changes: 82 additions & 22 deletions src/llm_training/lms/clm/clm.py
@@ -1,6 +1,6 @@
import logging
from contextlib import nullcontext
from typing import Any
from typing import TypedDict

import torch
import torch.nn.functional as F
@@ -11,8 +11,9 @@
from llm_training.lightning.strategy import FSDP2Strategy
from llm_training.lms.base_lm import BaseLightningModule
from llm_training.lms.protos import CausalLMProto
from llm_training.lms.utils import get_model
from llm_training.metrics import ConsumedSamples, ConsumedTokens, Perplexity
from llm_training.lms.utils import DataFetcherProxy, get_model
from llm_training.metrics import (ConsumedSamples, ConsumedTokens, Loss,
Perplexity)
from llm_training.models.base_model.base_model import BaseModel
from llm_training.ops import shift_labels
from llm_training.ops.liger_kernel import cross_entropy
@@ -97,6 +98,8 @@ def configure_model(self) -> None:
ignore_index=self.config.ignore_index,
process_group=process_group
)

self.loss_metric = Loss(process_group=process_group)

self.model = get_model(self.config.model)

@@ -110,20 +113,34 @@ def configure_model(self) -> None:
def on_fsdp_parallelize_model(self, **kwargs) -> None:
self.model.parallelize(**kwargs)

def compute_loss(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
def compute_loss(
self,
logits: torch.Tensor,
labels: torch.Tensor,
num_tokens_in_batch: int | None = None
) -> torch.Tensor:
reduction = 'mean' if num_tokens_in_batch is None else 'sum'

if isinstance(self.strategy, FSDP2Strategy) and self.strategy.tp_size > 1:
with loss_parallel():
return F.cross_entropy(
loss = F.cross_entropy(
logits.flatten(end_dim=1),
labels.flatten(end_dim=1),
ignore_index=self.config.ignore_index
ignore_index=self.config.ignore_index,
reduction=reduction
)
else:
loss = cross_entropy(
logits=logits,
labels=labels,
ignore_index=self.config.ignore_index,
reduction=reduction
)

return cross_entropy(
logits=logits,
labels=labels,
ignore_index=self.config.ignore_index
)
if num_tokens_in_batch is not None:
loss /= num_tokens_in_batch / self.trainer.accumulate_grad_batches / self.trainer.world_size

return loss

def backward(self, loss: torch.Tensor, *args, **kwargs) -> None:
backward_ctx = nullcontext()
@@ -133,7 +150,35 @@ def backward(self, loss: torch.Tensor, *args, **kwargs) -> None:
with backward_ctx:
super().backward(loss, *args, **kwargs)

def training_step(self, batch: dict[str, torch.Tensor | Any], batch_idx: int) -> torch.Tensor:
def get_next_batches(self, batch: "_Batch", n: int) -> list["_Batch"]:
batches = []
if n > 0:
batches.append(batch)
batches += [x[0] for x in self.data_fetcher.prefetch(n - 1)]
return batches

def get_num_tokens_in_batch(self, batch: "_Batch", batch_idx: int) -> int | None:
if batch_idx % self.trainer.accumulate_grad_batches == 0:
batches = self.get_next_batches(
batch,
min(self.trainer.accumulate_grad_batches, self.trainer.num_training_batches - batch_idx)
)
num_tokens_in_batch = sum([
x['labels'].ne(self.config.ignore_index).sum() for x in batches
])
self._num_tokens_in_batch = self.trainer.strategy.reduce(
num_tokens_in_batch.to(self.device),
reduce_op='sum'
)
return self._num_tokens_in_batch

def on_train_epoch_start(self) -> None:
fit_loop = self.trainer.fit_loop
if not isinstance(fit_loop._data_fetcher, DataFetcherProxy):
fit_loop._data_fetcher = DataFetcherProxy(fit_loop._data_fetcher)
self.data_fetcher = fit_loop._data_fetcher

def training_step(self, batch: "_Batch", batch_idx: int) -> torch.Tensor:
labels = shift_labels(batch['labels'], self.config.ignore_index)

if self.config.neftune_alpha is not None:
@@ -142,22 +187,30 @@ def training_step(self, batch: dict[str, torch.Tensor | Any], batch_idx: int) ->
outputs = self.model(
input_ids=batch['input_ids'],
attention_mask=batch['attention_mask'],
position_ids=batch.get('position_ids', None)
position_ids=batch['position_ids']
)
logits = outputs.logits.float()

if self.config.neftune_alpha is not None:
self.log('NEFTune Alpha', self.config.neftune_alpha)
self._current_attention_mask = None

loss = self.compute_loss(logits, labels)

self.log('loss', loss, prog_bar=True, logger=False)
self.log('Loss/Train/Step', loss)

if self.config.log_perplexity:
self.train_perplexity(loss)
self.log('Perplexity/Train/Step', self.train_perplexity)
loss = self.compute_loss(logits, labels, self.get_num_tokens_in_batch(batch, batch_idx))

self.loss_metric.update(loss)
epoch_loop = self.trainer.fit_loop.epoch_loop
if (
epoch_loop._accumulated_batches_reached()
or epoch_loop._num_ready_batches_reached()
):
reduced_loss = self.loss_metric.compute()
self.loss_metric.reset()
self.log('loss', reduced_loss, prog_bar=True, logger=False)
self.log('Loss/Train/Step', reduced_loss)

if self.config.log_perplexity:
self.train_perplexity(reduced_loss)
self.log('Perplexity/Train/Step', self.train_perplexity)

self.consumed_samples.update(labels)
self.consumed_tokens.update(labels)
@@ -167,7 +220,7 @@ def training_step(self, batch: dict[str, torch.Tensor | Any], batch_idx: int) ->
})
return loss

def validation_step(self, batch: dict[str, torch.Tensor | Any], batch_idx: int, dataloader_idx: int = 0):
def validation_step(self, batch: "_Batch", batch_idx: int, dataloader_idx: int = 0):
batch_size = batch['input_ids'].size(0)
labels = shift_labels(batch['labels'], self.config.ignore_index)
outputs = self.model(
@@ -186,3 +239,10 @@ def validation_step(self, batch: dict[str, torch.Tensor | Any], batch_idx: int,

def get_model(self) -> BaseModel:
return self.model


class _Batch(TypedDict):
input_ids: torch.Tensor
attention_mask: torch.Tensor
position_ids: torch.Tensor
labels: torch.Tensor
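
Note on the normalization in compute_loss: with `reduction='sum'` each micro-batch returns a token-level loss sum, Lightning then divides the returned loss by `accumulate_grad_batches`, and the data-parallel all-reduce averages gradients across ranks. Dividing the sum by `num_tokens_in_batch / accumulate_grad_batches / world_size` (where `num_tokens_in_batch` is the all-reduced token count for the whole accumulation window) therefore makes the accumulated gradient match the gradient of the global mean loss over all non-ignored tokens. A minimal sketch of the arithmetic, with made-up numbers standing in for real losses and token counts:

```python
import torch

# Hypothetical numbers, for illustration only.
world_size = 2                # data-parallel ranks
accumulate_grad_batches = 4   # micro-batches per optimizer step

# Per-rank, per-micro-batch token-level loss sums and token counts.
loss_sums = torch.rand(world_size, accumulate_grad_batches) * 100
token_counts = torch.randint(50, 200, (world_size, accumulate_grad_batches)).float()

# Reference: mean loss over every non-ignored token in the whole optimizer step.
global_mean = loss_sums.sum() / token_counts.sum()

# What the PR returns from each micro-batch:
#   loss_sum / (num_tokens_in_batch / accumulate_grad_batches / world_size),
# where num_tokens_in_batch is the all-reduced token count of the window.
num_tokens_in_batch = token_counts.sum()
returned = loss_sums / (num_tokens_in_batch / accumulate_grad_batches / world_size)

# Lightning divides by accumulate_grad_batches before backward; the
# data-parallel all-reduce then averages over ranks.
effective = (returned / accumulate_grad_batches).sum(dim=1).mean(dim=0)

torch.testing.assert_close(effective, global_mean)
```
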
29 changes: 29 additions & 0 deletions src/llm_training/lms/utils.py
@@ -1,5 +1,7 @@
from typing import Callable

from lightning.pytorch.loops.fetchers import _DataFetcher

from llm_training.lms.model_provider import ModelProvider
from llm_training.models.base_model.base_model import BaseModel

@@ -10,3 +12,30 @@ def get_model(model_or_provider: ModelType) -> BaseModel:
if isinstance(model_or_provider, BaseModel):
return model_or_provider
return model_or_provider()


class DataFetcherProxy:
def __init__(self, data_fetcher: _DataFetcher) -> None:
self.data_fetcher = data_fetcher
self.prefetched_batches = []

def __iter__(self):
return self.data_fetcher.__iter__()

def __next__(self):
if self.prefetched_batches:
return self.prefetched_batches.pop(0)
return next(self.data_fetcher)

def prefetch(self, n: int):
while len(self.prefetched_batches) < n:
x = next(self.data_fetcher.iterator)
self.prefetched_batches.append(x)
return self.prefetched_batches[:n]

def reset(self):
self.prefetched_batches.clear()
return self.data_fetcher.reset(self)

def __getattr__(self, name):
return getattr(self.data_fetcher, name)
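
For context, `get_num_tokens_in_batch` only works because the proxy lets the module peek at the remaining micro-batches of an accumulation window while the training loop still receives every batch in order. A self-contained toy sketch of that peek-then-replay behaviour (the small classes below are stand-ins for Lightning's `_DataFetcher` and the `DataFetcherProxy` added here, not the real implementations):

```python
class ToyFetcher:
    """Stand-in for Lightning's _DataFetcher: exposes .iterator and __next__."""
    def __init__(self, batches):
        self.iterator = iter(batches)

    def __next__(self):
        return next(self.iterator)


class ToyProxy:
    """Same idea as DataFetcherProxy: buffer prefetched items, replay them on __next__."""
    def __init__(self, fetcher):
        self.fetcher = fetcher
        self.prefetched = []

    def __next__(self):
        if self.prefetched:
            return self.prefetched.pop(0)
        return next(self.fetcher)

    def prefetch(self, n):
        while len(self.prefetched) < n:
            self.prefetched.append(next(self.fetcher.iterator))
        return self.prefetched[:n]


proxy = ToyProxy(ToyFetcher(['b0', 'b1', 'b2', 'b3']))
assert proxy.prefetch(2) == ['b0', 'b1']   # peek at the upcoming accumulation window
assert next(proxy) == 'b0'                 # the loop still sees every batch, in order
assert next(proxy) == 'b1'
assert next(proxy) == 'b2'
```
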
1 change: 1 addition & 0 deletions src/llm_training/metrics/__init__.py
@@ -1,3 +1,4 @@
from .consumed_samples import ConsumedSamples
from .consumed_tokens import ConsumedTokens
from .loss import Loss
from .perplexity import Perplexity
25 changes: 25 additions & 0 deletions src/llm_training/metrics/loss.py
@@ -0,0 +1,25 @@
import torch

from .metric import Metric


class Loss(Metric):
is_differentiable: bool = True
higher_is_better: bool = False
full_state_update: bool = False

loss: torch.Tensor
count: torch.Tensor

def __init__(self, **kwargs):
super().__init__(**kwargs)

self.add_state('loss', default=torch.tensor(0.0), dist_reduce_fx='sum')
self.add_state('count', default=torch.tensor(0.0), dist_reduce_fx='sum')

def update(self, loss: torch.Tensor):
self.loss += loss
self.count += 1

def compute(self):
return self.loss / self.count
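
The metric keeps a running sum and count, and `dist_reduce_fx='sum'` means both states are summed across processes at compute time, so the value logged once per optimizer step is the mean micro-batch loss over the whole accumulation window and all ranks. A single-process sketch of the intended update/compute/reset pattern, using `torchmetrics.Metric` as an assumed stand-in for the project's own `Metric` base class:

```python
import torch
from torchmetrics import Metric  # stand-in for llm_training.metrics.metric.Metric


class MeanLoss(Metric):
    higher_is_better = False

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.add_state('loss', default=torch.tensor(0.0), dist_reduce_fx='sum')
        self.add_state('count', default=torch.tensor(0.0), dist_reduce_fx='sum')

    def update(self, loss: torch.Tensor) -> None:
        self.loss += loss
        self.count += 1

    def compute(self) -> torch.Tensor:
        return self.loss / self.count


metric = MeanLoss()
for micro_loss in (torch.tensor(2.0), torch.tensor(4.0)):  # one accumulation window
    metric.update(micro_loss)
print(metric.compute())  # tensor(3.) -- logged once per optimizer step
metric.reset()           # cleared before the next window
```
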
32 changes: 23 additions & 9 deletions src/llm_training/ops/liger_kernel/cross_entropy_op.py
@@ -11,10 +11,11 @@ def cross_entropy(
logits: torch.Tensor,
labels: torch.Tensor,
ignore_index: int = -100,
reduction: Literal['mean'] = 'mean'
lse_square_scale: float = 0.0,
label_smoothing: float = 0.0,
reduction: Literal["none", "mean", "sum"] = 'mean',
softcap: float | None = None
) -> torch.Tensor:
assert reduction == 'mean'

if logits.dim() == 3 and labels.dim() == 2:
logits = logits.flatten(end_dim=1)
labels = labels.flatten(end_dim=1)
@@ -23,25 +24,33 @@
return F.cross_entropy(
logits,
labels,
ignore_index=ignore_index
ignore_index=ignore_index,
reduction=reduction,
label_smoothing=label_smoothing
)

return LigerCrossEntropyFunction.apply(
logits,
labels,
ignore_index
ignore_index,
lse_square_scale,
label_smoothing,
reduction,
softcap
)[0]


def fused_linear_cross_entropy(
hidden_states: torch.Tensor,
weight: torch.Tensor,
labels: torch.Tensor,
bias: torch.Tensor | None = None,
ignore_index: int = -100,
reduction: Literal['mean'] = 'mean'
lse_square_scale: float = 0.0,
label_smoothing: float = 0.0,
reduction: Literal["none", "mean", "sum"] = 'mean',
softcap: float | None = None
) -> torch.Tensor:
assert reduction == 'mean'

if hidden_states.dim() == 3 and labels.dim() == 2:
hidden_states = hidden_states.flatten(end_dim=1)
labels = labels.flatten(end_dim=1)
@@ -50,5 +59,10 @@ def fused_linear_cross_entropy(
hidden_states,
weight,
labels,
ignore_index
bias,
ignore_index,
lse_square_scale,
label_smoothing,
reduction,
softcap
)
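
With `reduction='sum'` now accepted, callers can normalize the token-level loss sum themselves, which is what compute_loss does above. A minimal sketch of that pattern using plain `torch.nn.functional.cross_entropy` (the eager fallback path); the shapes and the masked prompt positions are hypothetical:

```python
import torch
import torch.nn.functional as F

ignore_index = -100
batch_size, seq_len, vocab = 2, 8, 32

logits = torch.randn(batch_size, seq_len, vocab)
labels = torch.randint(0, vocab, (batch_size, seq_len))
labels[:, :2] = ignore_index  # e.g. prompt tokens masked out of the loss

# Sum of per-token losses over the non-ignored positions.
loss_sum = F.cross_entropy(
    logits.flatten(end_dim=1),
    labels.flatten(end_dim=1),
    ignore_index=ignore_index,
    reduction='sum',
)

# Dividing by the token count reproduces the per-token mean loss
# (in the PR, the divisor is the global count across ranks and micro-batches).
num_tokens = labels.ne(ignore_index).sum()
torch.testing.assert_close(
    loss_sum / num_tokens,
    F.cross_entropy(
        logits.flatten(end_dim=1),
        labels.flatten(end_dim=1),
        ignore_index=ignore_index,
    ),
)
```
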