Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Add softcapping to preference based fused linear #437

Draft
wants to merge 14 commits into
base: main
Choose a base branch
from
Draft
10 changes: 9 additions & 1 deletion src/liger_kernel/chunked_loss/cpo_loss.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Optional

import torch
import torch.nn.functional as F

Expand Down Expand Up @@ -47,6 +49,7 @@ def forward(
alpha=1.0,
compute_nll_loss=True,
compiled=True,
softcap=None,
):
return LigerFusedLinearPreferenceBase.forward(
ctx,
Expand All @@ -60,12 +63,13 @@ def forward(
beta=beta,
compute_nll_loss=compute_nll_loss,
compiled=compiled,
softcap=softcap,
)

@staticmethod
def backward(ctx, *grad_output):
grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
return *grads, None, None, None, None, None
return *grads, None, None, None, None, None, None


class LigerFusedLinearCPOLoss(torch.nn.Module):
Expand All @@ -80,18 +84,21 @@ def __init__(
alpha: float = 1.0,
compute_nll_loss: bool = True,
compiled: bool = True,
softcap: Optional[float] = None,
):
"""
Args:
ignore_index (int): Index to ignore in the loss.
beta (float): Weight for the odds ratio loss.
softcap (Optional[float]): Soft-capping threshold. If not None, logits are transformed as softcap * tanh(logits / softcap), bounding them to the range (-softcap, +softcap).
"""
super().__init__()
self.ignore_index = ignore_index
self.beta = beta
self.alpha = alpha
self.compute_nll_loss = compute_nll_loss
self.compiled = compiled
self.softcap = softcap

def forward(self, lin_weight, _input, target, bias=None):
return LigerFusedLinearCPOFunction.apply(
Expand All @@ -104,4 +111,5 @@ def forward(self, lin_weight, _input, target, bias=None):
self.alpha,
self.compute_nll_loss,
self.compiled,
self.softcap,
)
10 changes: 9 additions & 1 deletion src/liger_kernel/chunked_loss/dpo_loss.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Optional

import torch
import torch.nn.functional as F

Expand Down Expand Up @@ -67,6 +69,7 @@ def forward(
compute_nll_loss=True,
compiled=True,
use_ref_model=True,
softcap=None,
):
return LigerFusedLinearPreferenceBase.forward(
ctx=ctx,
Expand All @@ -83,12 +86,13 @@ def forward(
ref_input=ref_input,
ref_weight=ref_weight,
ref_bias=ref_bias,
softcap=softcap,
)

@staticmethod
def backward(ctx, *grad_output):
grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
return *grads, None, None, None, None, None, None, None, None
return *grads, None, None, None, None, None, None, None, None, None


class LigerFusedLinearDPOLoss(torch.nn.Module):
Expand All @@ -103,6 +107,7 @@ def __init__(
compute_nll_loss: bool = True,
compiled: bool = True,
use_ref_model: bool = False,
softcap: Optional[float] = None,
):
"""
Args:
Expand All @@ -111,13 +116,15 @@ def __init__(
compute_nll_loss (bool): Whether to compute the NLL loss.
compiled (bool): Whether to use the torch compiled kernel.
use_ref_model (bool): Whether to use a reference model for the DPO loss.
softcap (Optional[float]): Soft-capping threshold. If not None, logits are transformed as softcap * tanh(logits / softcap), bounding them to the range (-softcap, +softcap).
"""
super().__init__()
self.ignore_index = ignore_index
self.beta = beta
self.compute_nll_loss = compute_nll_loss
self.compiled = compiled
self.use_ref_model = use_ref_model
self.softcap = softcap

def forward(
self,
Expand All @@ -142,4 +149,5 @@ def forward(
self.compute_nll_loss,
self.compiled,
self.use_ref_model,
self.softcap,
)
11 changes: 11 additions & 0 deletions src/liger_kernel/chunked_loss/fused_linear_preference.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def forward(
ref_input=None,
ref_weight=None,
ref_bias=None,
softcap=None,
**loss_kwargs,
):
"""
Expand Down Expand Up @@ -61,6 +62,7 @@ def forward(
use_ref_model (bool): Whether to use a reference model for the alignment loss.
ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size).
ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,).
softcap (Optional[float]): Soft-capping threshold. If not None, logits are transformed as softcap * tanh(logits / softcap), bounding them to the range (-softcap, +softcap).
loss_kwargs (dict): Other possible arguments that a loss function might need
"""
# TODO: Tune CHUNK_SIZE to fully utilize the GPU
Expand Down Expand Up @@ -282,11 +284,16 @@ def chunk_forward(
bias=None,
ignore_index=-100,
compute_nll_loss=True,
softcap=None,
):
len_chosen_chunk = target_chunk.shape[0] // 2
logits_chunk = input_chunk @ weight.t()
if bias is not None:
logits_chunk = logits_chunk + bias
if softcap is not None:
logits_chunk = logits_chunk / softcap
logits_chunk = torch.tanh(logits_chunk)
logits_chunk = logits_chunk * softcap
log_probs_chunk = F.log_softmax(logits_chunk.float(), dim=-1)

chosen_nll_loss = 0.0
Expand Down Expand Up @@ -336,6 +343,7 @@ def _compute_loss(
ref_input_chunk=None,
ref_weight=None,
ref_bias=None,
softcap=None,
**loss_kwargs,
):
"""
Expand All @@ -354,6 +362,7 @@ def _compute_loss(
use_ref_model (bool): Whether to use a reference model for the alignment loss.
ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size).
ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,).
softcap (Optional[float]): Soft-capping threshold. If not None, logits are transformed as softcap * tanh(logits / softcap), bounding them to the range (-softcap, +softcap).
loss_kwargs (dict): Additional arguments for the loss function.
"""
(
Expand All @@ -369,6 +378,7 @@ def _compute_loss(
bias=bias,
ignore_index=ignore_index,
compute_nll_loss=compute_nll_loss,
softcap=softcap,
)
chosen_nll_loss = (
chosen_nll_loss
Expand Down Expand Up @@ -396,6 +406,7 @@ def _compute_loss(
ref_bias,
ignore_index=ignore_index,
compute_nll_loss=False, # We don't need NLL loss for the reference model
softcap=softcap,
)
loss_kwargs["ref_chosen_logps"] = ref_chosen_logps
loss_kwargs["ref_rejected_logps"] = ref_rejected_logps
Expand Down
10 changes: 9 additions & 1 deletion src/liger_kernel/chunked_loss/orpo_loss.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Optional

import torch
import torch.nn.functional as F

Expand Down Expand Up @@ -57,6 +59,7 @@ def forward(
beta=0.1,
compute_nll_loss=True,
compiled=True,
softcap=None,
):
return LigerFusedLinearPreferenceBase.forward(
ctx=ctx,
Expand All @@ -69,12 +72,13 @@ def forward(
beta=beta,
compute_nll_loss=compute_nll_loss,
compiled=compiled,
softcap=softcap,
)

@staticmethod
def backward(ctx, *grad_output):
grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
return *grads, None, None, None, None
return *grads, None, None, None, None, None


class LigerFusedLinearORPOLoss(torch.nn.Module):
Expand All @@ -88,17 +92,20 @@ def __init__(
beta: float = 0.1,
compute_nll_loss: bool = True,
compiled: bool = True,
softcap: Optional[float] = None,
):
"""
Args:
ignore_index (int): Index to ignore in the loss.
beta (float): Weight for the odds ratio loss.
softcap (Optional[float]): Soft-capping threshold. If not None, logits are transformed as softcap * tanh(logits / softcap), bounding them to the range (-softcap, +softcap).
"""
super().__init__()
self.ignore_index = ignore_index
self.beta = beta
self.compute_nll_loss = compute_nll_loss
self.compiled = compiled
self.softcap = softcap

def forward(self, lin_weight, _input, target, bias=None):
return LigerFusedLinearORPOFunction.apply(
Expand All @@ -110,4 +117,5 @@ def forward(self, lin_weight, _input, target, bias=None):
self.beta,
self.compute_nll_loss,
self.compiled,
self.softcap,
)
10 changes: 9 additions & 1 deletion src/liger_kernel/chunked_loss/simpo_loss.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Optional

import torch
import torch.nn.functional as F

Expand Down Expand Up @@ -51,6 +53,7 @@ def forward(
compute_nll_loss=False,
compiled=True,
gamma=0.5,
softcap=None,
):
return LigerFusedLinearPreferenceBase.forward(
ctx,
Expand All @@ -65,12 +68,13 @@ def forward(
beta=beta,
compiled=compiled,
gamma=gamma,
softcap=softcap,
)

@staticmethod
def backward(ctx, *grad_output):
grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
return *grads, None, None, None, None, None, None
return *grads, None, None, None, None, None, None, None


class LigerFusedLinearSimPOLoss(torch.nn.Module):
Expand All @@ -86,11 +90,13 @@ def __init__(
compute_nll_loss: bool = True,
compiled: bool = True,
gamma: float = 0.5,
softcap: Optional[float] = None,
):
"""
Args:
ignore_index (int): Index to ignore in the loss.
beta (float): Weight for the odds ratio loss.
softcap (Optional[float]): Soft-capping threshold. If not None, logits are transformed as softcap * tanh(logits / softcap), bounding them to the range (-softcap, +softcap).
"""
super().__init__()
self.ignore_index = ignore_index
Expand All @@ -99,6 +105,7 @@ def __init__(
self.compute_nll_loss = compute_nll_loss
self.compiled = compiled
self.gamma = gamma
self.softcap = softcap

def forward(self, lin_weight, _input, target, bias=None):
return LigerFusedLinearSimPOFunction.apply(
Expand All @@ -112,4 +119,5 @@ def forward(self, lin_weight, _input, target, bias=None):
self.compute_nll_loss,
self.compiled,
self.gamma,
self.softcap,
)
Loading