Use torch.compile for scaling penalty (sgl-project#3133)
merrymercy authored Jan 26, 2025
1 parent da6f808 commit 27acf63
Showing 3 changed files with 14 additions and 29 deletions.
@@ -1,6 +1,5 @@
 import argparse
 import itertools
-import time
 
 import torch
 import triton
24 changes: 10 additions & 14 deletions python/sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py
@@ -3,11 +3,16 @@
 import torch
 
 from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer, _TokenIDs
-from sglang.srt.utils import is_cuda_available
+from sglang.srt.utils import get_compiler_backend
 
-is_cuda = is_cuda_available()
-if is_cuda:
-    from sgl_kernel import sampling_scaling_penalties
+
+@torch.compile(dynamic=True, backend=get_compiler_backend())
+def apply_scaling_penalties(logits, scaling_penalties):
+    logits[:] = torch.where(
+        logits > 0,
+        logits / scaling_penalties,
+        logits * scaling_penalties,
+    )
 
 
 class BatchedRepetitionPenalizer(_BatchedPenalizer):
@@ -61,16 +66,7 @@ def _cumulate_output_tokens(self, output_ids: _TokenIDs):
         self.cumulated_repetition_penalties[mask] = self.repetition_penalties[mask]
 
     def _apply(self, logits: torch.Tensor) -> torch.Tensor:
-        if is_cuda:
-            return sampling_scaling_penalties(
-                logits, self.cumulated_repetition_penalties
-            )
-        else:
-            return torch.where(
-                logits > 0,
-                logits / self.cumulated_repetition_penalties,
-                logits * self.cumulated_repetition_penalties,
-            )
+        apply_scaling_penalties(logits, self.cumulated_repetition_penalties)
 
     def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor):
         self.repetition_penalties = self.repetition_penalties[indices_tensor_to_keep]
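
For reference, a self-contained sketch of the compiled penalty helper introduced above. The backend string "inductor" stands in for sglang.srt.utils.get_compiler_backend(), and the shapes in the usage example are illustrative assumptions:

import torch


# Positive logits are divided by the penalty, negative logits are multiplied
# by it, and the result is written back in place. backend="inductor" is an
# assumption standing in for get_compiler_backend().
@torch.compile(dynamic=True, backend="inductor")
def apply_scaling_penalties(logits: torch.Tensor, scaling_penalties: torch.Tensor) -> None:
    logits[:] = torch.where(
        logits > 0,
        logits / scaling_penalties,
        logits * scaling_penalties,
    )


if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    logits = torch.randn(4, 32000, device=device)        # (batch, vocab)
    penalties = torch.full((4, 1), 1.2, device=device)   # one penalty per request
    apply_scaling_penalties(logits, penalties)
    print(logits[0, :5])
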
18 changes: 4 additions & 14 deletions python/sglang/srt/sampling/sampling_batch_info.py
@@ -7,14 +7,11 @@
 
 import torch
 
-from sglang.srt.utils import is_cuda_available
-
-is_cuda = is_cuda_available()
-if is_cuda:
-    from sgl_kernel import sampling_scaling_penalties
-
 import sglang.srt.sampling.penaltylib as penaltylib
 from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor
+from sglang.srt.sampling.penaltylib.penalizers.repetition_penalty import (
+    apply_scaling_penalties,
+)
 
 logger = logging.getLogger(__name__)
 
@@ -386,14 +383,7 @@ def apply_logits_bias(self, logits: torch.Tensor):
 
         # repetition
         if self.scaling_penalties is not None:
-            if is_cuda:
-                logits[:] = sampling_scaling_penalties(logits, self.scaling_penalties)
-            else:
-                logits[:] = torch.where(
-                    logits > 0,
-                    logits / self.scaling_penalties,
-                    logits * self.scaling_penalties,
-                )
+            apply_scaling_penalties(logits, self.scaling_penalties)
 
         # Apply regex vocab_mask
         if self.vocab_mask is not None:
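
As a sanity check for this kind of change, a rough micro-benchmark comparing the eager torch.where path (the old non-CUDA branch) against a torch.compile'd version; the shapes, iteration counts, and helper names below are assumptions rather than code from this repository:

import time

import torch


def eager_scaling_penalties(logits, scaling_penalties):
    # Old eager fallback: divide positive logits, multiply negative ones.
    logits[:] = torch.where(
        logits > 0,
        logits / scaling_penalties,
        logits * scaling_penalties,
    )


compiled_scaling_penalties = torch.compile(eager_scaling_penalties, dynamic=True)


def bench(fn, logits, penalties, iters=100):
    # Warm-up calls also trigger compilation for the compiled variant.
    for _ in range(10):
        fn(logits, penalties)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        fn(logits, penalties)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return (time.perf_counter() - start) / iters * 1e3  # ms per call


if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    logits = torch.randn(64, 32000, device=device)
    penalties = torch.full((64, 1), 1.2, device=device)
    print(f"eager:    {bench(eager_scaling_penalties, logits, penalties):.3f} ms")
    print(f"compiled: {bench(compiled_scaling_penalties, logits, penalties):.3f} ms")
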
