diff --git a/src/liger_kernel/chunked_loss/kto_loss.py b/src/liger_kernel/chunked_loss/kto_loss.py
index b0f6886b8..59d59cbf9 100644
--- a/src/liger_kernel/chunked_loss/kto_loss.py
+++ b/src/liger_kernel/chunked_loss/kto_loss.py
@@ -18,15 +18,32 @@ def preference_loss_fn(
         beta=0.1,
     ):
         """
-        Paper: https://arxiv.org/abs/2402.01306
+        Implements the Kahneman-Tversky Optimization (KTO) loss function.
+        Paper: "KTO: Model Alignment as Prospect Theory-Guided Optimization"
+        https://arxiv.org/abs/2402.01306
+
+        KTO loss is inspired by prospect theory (https://en.wikipedia.org/wiki/Prospect_theory)
+        from behavioral economics, which models how humans make decisions under uncertainty.
+        The loss function is asymmetric, treating gains and losses differently,
+        mirroring patterns observed in human decision-making.
 
         Formula:
+        When y is chosen: L_KTO = 1 - σ(β * (log[π(x)/π₀(x)] - KL(π||π₀)_y))
+        When y is rejected:
+        L_KTO = 1 - σ(β * (KL(π||π₀)_y - log[π(x)/π₀(x)]))
 
         Where:
         - σ: Sigmoid function
-        - β: Temperature parameter
-        - KL(π||π₀)_y is KL divergence estimated using the rejected response y
+        - β: Temperature parameter controlling the strength of the preference signal
+        - π(x): Policy (current model)
+        - π₀(x): Reference policy (reference model)
+        - KL(π||π₀)_y: KL divergence estimated using the rejected response y
+
+        The loss encourages the model to:
+        1. Assign higher probability to chosen responses
+        2. Assign lower probability to rejected responses
+        3. Maintain a reasonable distance from the reference model
 
         Args:
             chosen_logps: Log probabilities of chosen tokens (batch_size,)
@@ -35,6 +52,12 @@ def preference_loss_fn(
             ref_chosen_logps: Reference log probs of chosen tokens (batch_size,)
             ref_rejected_logps: Reference log probs of rejected tokens (batch_size,)
             beta: Weight for the direct preference loss
+
+        Returns:
+            Tuple of (loss, chosen_rewards, rejected_rewards):
+            - loss: The KTO loss value
+            - chosen_rewards: Reward signals for chosen responses (detached)
+            - rejected_rewards: Reward signals for rejected responses (detached)
         """
         if ref_chosen_logps is None:
             ref_chosen_logps = torch.tensor(0.0, device=chosen_logps.device)
@@ -60,7 +83,8 @@ def preference_loss_fn(
         rejected_rewards = beta * rejected_logratios.detach()
 
         return (
-            losses.sum() / (full_target.shape[0] // 2),
+            # We don't divide by 2 because KTO loss doesn't require paired examples
+            losses.sum() / (full_target.shape[0]),
             chosen_rewards,
             rejected_rewards,
         )
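
For reviewers who want to sanity-check the formula added to the docstring and the normalization change in the last hunk, here is a minimal standalone sketch of the loss being described. It is an illustration under stated assumptions, not the Liger kernel code path: the function name `kto_preference_loss`, the clamped batch-mean KL estimate, and the toy inputs are assumptions made only for this example.

```python
import torch


def kto_preference_loss(chosen_logps, rejected_logps, ref_chosen_logps, ref_rejected_logps, beta=0.1):
    # log[π(x)/π₀(x)] for chosen and rejected responses
    chosen_logratios = chosen_logps - ref_chosen_logps
    rejected_logratios = rejected_logps - ref_rejected_logps

    # Placeholder estimate of KL(π||π₀)_y from the rejected responses (batch mean,
    # clamped at zero and detached). The estimator the kernel actually uses may differ.
    kl = torch.clamp(rejected_logratios.mean(), min=0).detach()

    # Asymmetric, prospect-theory-style terms from the docstring:
    #   chosen:   1 - σ(β * (log-ratio - KL))
    #   rejected: 1 - σ(β * (KL - log-ratio))
    chosen_losses = 1 - torch.sigmoid(beta * (chosen_logratios - kl))
    rejected_losses = 1 - torch.sigmoid(beta * (kl - rejected_logratios))
    losses = torch.cat([chosen_losses, rejected_losses])

    chosen_rewards = beta * chosen_logratios.detach()
    rejected_rewards = beta * rejected_logratios.detach()

    # KTO examples are unpaired, so normalize by the total number of examples
    # instead of dividing by 2 (the same reasoning as the final hunk above).
    loss = losses.sum() / losses.numel()
    return loss, chosen_rewards, rejected_rewards


# Toy usage: per-sequence summed log-probabilities are negative scalars, one per example.
chosen = -torch.rand(4) * 10
rejected = -torch.rand(4) * 10
ref_chosen = -torch.rand(4) * 10
ref_rejected = -torch.rand(4) * 10
loss, chosen_rewards, rejected_rewards = kto_preference_loss(chosen, rejected, ref_chosen, ref_rejected)
print(loss.item(), chosen_rewards.shape, rejected_rewards.shape)
```

Note that the kernel itself divides by `full_target.shape[0]`; `losses.numel()` plays the analogous role here only because this toy example has one target row per example.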