Add more docstrings
hebiao064 committed Dec 13, 2024
1 parent ab08ab6 commit 98aa519
Showing 1 changed file with 28 additions and 4 deletions: src/liger_kernel/chunked_loss/kto_loss.py
@@ -18,15 +18,32 @@ def preference_loss_fn(
         beta=0.1,
     ):
         """
-        Paper: https://arxiv.org/abs/2402.01306
+        Implements the Kahneman-Tversky Optimization (KTO) loss function.
+        Paper: "KTO: Model Alignment as Prospect Theory-Guided Optimization"
+        https://arxiv.org/abs/2402.01306
+
+        KTO loss is inspired by prospect theory (https://en.wikipedia.org/wiki/Prospect_theory)
+        from behavioral economics, which models how humans make decisions under uncertainty.
+        The loss function is asymmetric, treating gains and losses differently, similar to
+        human decision-making patterns.
+
         Formula:
         When y is chosen:
         L_KTO = 1 - σ(β * (log[π(x)/π₀(x)] - KL(π||π₀)_y))
         When y is rejected:
         L_KTO = 1 - σ(β * (KL(π||π₀)_y - log[π(x)/π₀(x)]))
+
         Where:
         - σ: Sigmoid function
-        - β: Temperature parameter
-        - KL(π||π₀)_y is KL divergence estimated using the rejected response y
+        - β: Temperature parameter controlling the strength of the preference signal
+        - π(x): Policy (current model)
+        - π₀(x): Reference policy (reference model)
+        - KL(π||π₀)_y: KL divergence estimated using the rejected response y
+
+        The loss encourages the model to:
+        1. Assign higher probability to chosen responses
+        2. Assign lower probability to rejected responses
+        3. Maintain reasonable distance from the reference model
+
         Args:
             chosen_logps: Log probabilities of chosen tokens (batch_size,)
@@ -35,6 +52,12 @@ def preference_loss_fn(
             ref_chosen_logps: Reference log probs of chosen tokens (batch_size,)
             ref_rejected_logps: Reference log probs of rejected tokens (batch_size,)
             beta: Weight for the direct preference loss
+
+        Returns:
+            Tuple of (loss, chosen_rewards, rejected_rewards):
+            - loss: The KTO loss value
+            - chosen_rewards: Reward signals for chosen responses (detached)
+            - rejected_rewards: Reward signals for rejected responses (detached)
         """
         if ref_chosen_logps is None:
             ref_chosen_logps = torch.tensor(0.0, device=chosen_logps.device)
@@ -60,7 +83,8 @@ def preference_loss_fn(
         rejected_rewards = beta * rejected_logratios.detach()
 
         return (
-            losses.sum() / (full_target.shape[0] // 2),
+            # We don't divide by 2 because KTO Loss doesn't need pair-wise examples
+            losses.sum() / (full_target.shape[0]),
             chosen_rewards,
             rejected_rewards,
         )
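
For reference, here is a minimal, self-contained PyTorch sketch of the formula described in the docstring above. The names (kto_preference_loss, policy_logps, ref_logps, kl, is_chosen) are illustrative assumptions for this note, not liger_kernel's actual API, and the KL term is taken as a precomputed, detached estimate as the docstring describes.

import torch

def kto_preference_loss(policy_logps, ref_logps, kl, is_chosen, beta=0.1):
    # log[pi(x) / pi_0(x)] per example
    logratios = policy_logps - ref_logps
    # Reward signals for logging, detached as in the Returns section above
    rewards = beta * logratios.detach()
    # Chosen: 1 - sigmoid(beta * (logratio - KL)); rejected: 1 - sigmoid(beta * (KL - logratio))
    chosen_losses = 1 - torch.sigmoid(beta * (logratios - kl))
    rejected_losses = 1 - torch.sigmoid(beta * (kl - logratios))
    losses = torch.where(is_chosen, chosen_losses, rejected_losses)
    # KTO scores every example independently, so average over the whole batch
    return losses.mean(), rewards

# Toy batch: 3 chosen and 2 rejected responses, no pairing required
policy_logps = torch.tensor([-12.3, -8.1, -15.0, -9.7, -11.2])
ref_logps = torch.tensor([-13.0, -8.5, -14.2, -9.0, -10.8])
kl = torch.tensor(0.2)
is_chosen = torch.tensor([True, True, True, False, False])
loss, rewards = kto_preference_loss(policy_logps, ref_logps, kl, is_chosen)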
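One non-docstring change rides along in the last hunk: the summed loss is now divided by full_target.shape[0] instead of full_target.shape[0] // 2. Dividing by half the combined batch gives a per-pair mean, which fits losses that pair each chosen example with a rejected one; KTO has no such pairing, so, assuming full_target's leading dimension is the combined chosen-plus-rejected batch size, averaging over every example is the natural normalization. For instance, with 3 chosen and 5 rejected responses stacked into a batch of 8, the summed loss was previously divided by 4 and is now divided by 8.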
