Skip to content

Commit

Permalink
optimize logit_bias
Browse files Browse the repository at this point in the history
  • Loading branch information
xu-song committed Feb 15, 2025
1 parent 9206b3d commit 40f5641
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 4 deletions.
3 changes: 2 additions & 1 deletion vllm/engine/llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -2003,7 +2003,8 @@ def _build_logits_processors(
processors = get_openai_logits_processors(
logit_bias=sampling_params.logit_bias,
allowed_token_ids=sampling_params.allowed_token_ids,
tokenizer=tokenizer)
tokenizer=tokenizer,
dtype=self.model_config.dtype)
logits_processors.extend(processors)

# Unset so these don't get passed down to the model
Expand Down
12 changes: 9 additions & 3 deletions vllm/entrypoints/openai/logits_processors.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,19 +43,20 @@ def _get_allowed_token_ids_logits_processor(


def logit_bias_logits_processor(
logit_bias: Dict[int, float],
logit_bias: Dict[str, torch.Tensor],
token_ids: List[int],
logits: torch.Tensor,
) -> torch.Tensor:
for token_id, bias in logit_bias.items():
logits[token_id] += bias
logits.index_add_(0, logit_bias["index"].to(logits.device),
logit_bias["value"].to(logits.device))
return logits


def get_logits_processors(
logit_bias: Optional[Union[Dict[int, float], Dict[str, float]]],
allowed_token_ids: Optional[List[int]],
tokenizer: AnyTokenizer,
dtype: Union[str, torch.dtype],
) -> List[LogitsProcessor]:
logits_processors: List[LogitsProcessor] = []
if logit_bias:
Expand All @@ -77,6 +78,11 @@ def get_logits_processors(
raise ValueError(f"token_id {token_id} in logit_bias contains "
"out-of-vocab token id")

clamped_logit_bias = {
"index": torch.tensor(list(clamped_logit_bias.keys())),

Check failure on line 82 in vllm/entrypoints/openai/logits_processors.py

View workflow job for this annotation

GitHub Actions / pre-commit

Dict entry 0 has incompatible type "str": "Any"; expected "int": "float" [dict-item]

Check failure on line 82 in vllm/entrypoints/openai/logits_processors.py

View workflow job for this annotation

GitHub Actions / pre-commit

Dict entry 0 has incompatible type "str": "Any"; expected "int": "float" [dict-item]

Check failure on line 82 in vllm/entrypoints/openai/logits_processors.py

View workflow job for this annotation

GitHub Actions / pre-commit

Dict entry 0 has incompatible type "str": "Any"; expected "int": "float" [dict-item]

Check failure on line 82 in vllm/entrypoints/openai/logits_processors.py

View workflow job for this annotation

GitHub Actions / pre-commit

Dict entry 0 has incompatible type "str": "Any"; expected "int": "float" [dict-item]
"value": torch.tensor(list(clamped_logit_bias.values()),

Check failure on line 83 in vllm/entrypoints/openai/logits_processors.py

View workflow job for this annotation

GitHub Actions / pre-commit

Dict entry 1 has incompatible type "str": "Any"; expected "int": "float" [dict-item]

Check failure on line 83 in vllm/entrypoints/openai/logits_processors.py

View workflow job for this annotation

GitHub Actions / pre-commit

Dict entry 1 has incompatible type "str": "Any"; expected "int": "float" [dict-item]

Check failure on line 83 in vllm/entrypoints/openai/logits_processors.py

View workflow job for this annotation

GitHub Actions / pre-commit

Dict entry 1 has incompatible type "str": "Any"; expected "int": "float" [dict-item]
dtype=dtype)
}
logits_processors.append(
partial(logit_bias_logits_processor, clamped_logit_bias))

Check failure on line 87 in vllm/entrypoints/openai/logits_processors.py

View workflow job for this annotation

GitHub Actions / pre-commit

Argument 1 to "logit_bias_logits_processor" has incompatible type "dict[int, float]"; expected "dict[str, Any]" [arg-type]

Check failure on line 87 in vllm/entrypoints/openai/logits_processors.py

View workflow job for this annotation

GitHub Actions / pre-commit

Argument 1 to "logit_bias_logits_processor" has incompatible type "dict[int, float]"; expected "dict[str, Any]" [arg-type]

Check failure on line 87 in vllm/entrypoints/openai/logits_processors.py

View workflow job for this annotation

GitHub Actions / pre-commit

Argument 1 to "logit_bias_logits_processor" has incompatible type "dict[int, float]"; expected "dict[str, Any]" [arg-type]

Expand Down

0 comments on commit 40f5641

Please sign in to comment.