Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor Temperature Scaling in Distillation Loss #444

Merged
merged 2 commits into from
Jan 1, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 12 additions & 7 deletions src/liger_kernel/chunked_loss/fused_linear_distillation.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,15 @@

class LigerFusedLinearDistillationBase(torch.autograd.Function):
@abstractmethod
def distillation_loss_fn(student_logits, teacher_logits, temperature):
def distillation_loss_fn(
student_logits,
teacher_logits,
):
"""
Compute distillation loss.
Args:
student_logits (torch.Tensor): Raw logits of student tokens. Shape: (batch_size * seq_len, vocab_size).
teacher_logits (torch.Tensor): Raw logits of teacher tokens. Shape: (batch_size * seq_len, vocab_size).
student_logits (torch.Tensor): Raw (temperature-scaled) logits of student tokens. Shape: (batch_size * seq_len, vocab_size).
teacher_logits (torch.Tensor): Raw (temperature-scaled) logits of teacher tokens. Shape: (batch_size * seq_len, vocab_size).
"""
raise NotImplementedError("Distillation loss function must be implemented.")

Expand Down Expand Up @@ -65,7 +68,6 @@ def _compute_loss(
distillation_loss_fn=None,
full_target=None,
ignore_index=-100,
temperature=1.0,
weight_hard_loss=0.5,
weight_soft_loss=0.5,
compute_ce_loss=True,
Expand Down Expand Up @@ -107,7 +109,7 @@ def _compute_loss(

hard_loss /= full_target.shape[0]

soft_loss = distillation_loss_fn(student_logits_chunk, teacher_logits_chunk, temperature)
soft_loss = distillation_loss_fn(student_logits_chunk, teacher_logits_chunk)
soft_loss /= full_target.shape[0]

loss = weight_hard_loss * hard_loss + weight_soft_loss * soft_loss
Expand Down Expand Up @@ -147,10 +149,11 @@ def forward(
teacher_bias (torch.Tensor, optional): Teacher bias tensor. Shape: (vocab_size,).
loss_fn (callable): Loss function to compute the loss on a chunk of input/target.
chunk_size (int): Size of a chunk.
compute_ce_loss (bool): Whether to compute CE loss.
ignore_index (int): Index to ignore for loss computation.
weight_hard_loss (float): Weight for hard/task loss.
weight_soft_loss (float): Weight for soft/distillation loss.
compute_ce_loss (bool): Whether to compute CE loss.
temperature (float): Temperature to control the input probability distribution. Default: `1.0` (i.e. no scale)
compiled (bool): Whether to use torch compile for chunk accumulation.
loss_kwargs (dict): Other possible arguments that a loss function might need
"""
Expand All @@ -168,7 +171,6 @@ def forward(
weight_hard_loss=weight_hard_loss,
weight_soft_loss=weight_soft_loss,
compute_ce_loss=compute_ce_loss,
temperature=temperature,
**loss_kwargs,
)

Expand Down Expand Up @@ -223,6 +225,9 @@ def accumulate_chunk(student_input_chunk, teacher_input_chunk, target_chunk):
if compiled:
accumulate_chunk = torch.compile(accumulate_chunk)

student_input /= temperature
teacher_input /= temperature

num_chunks = max(1, student_input.shape[0] // CHUNK_SIZE)
_student_input_chunks = torch.chunk(student_input, chunks=num_chunks, dim=0)
_teacher_input_chunks = torch.chunk(teacher_input, chunks=num_chunks, dim=0)
Expand Down
Loading