Skip to content

Commit

Permalink
Add support for logit_bias outside of server api. Closes abetlen#827
Browse files Browse the repository at this point in the history
  • Loading branch information
abetlen committed Nov 21, 2023
1 parent c21edb6 commit 07e47f5
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 38 deletions.
25 changes: 25 additions & 0 deletions llama_cpp/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -1327,6 +1327,7 @@ def _create_completion(
stopping_criteria: Optional[StoppingCriteriaList] = None,
logits_processor: Optional[LogitsProcessorList] = None,
grammar: Optional[LlamaGrammar] = None,
logit_bias: Optional[Dict[int, float]] = None,
) -> Union[
Iterator[CreateCompletionResponse], Iterator[CreateCompletionStreamResponse]
]:
Expand Down Expand Up @@ -1355,6 +1356,28 @@ def _create_completion(
)
model_name: str = model if model is not None else self.model_path

# NOTE: This likely doesn't work correctly for the first token in the prompt
# because of the extra space added to the start of the prompt_tokens
if logit_bias is not None:
    # Keys may arrive as strings (the chat/server entry points declare
    # Dict[str, float]), so normalise to int token ids and float biases once.
    logit_bias_map = {int(k): float(v) for k, v in logit_bias.items()}

    def logit_bias_processor(
        input_ids: npt.NDArray[np.intc],
        scores: npt.NDArray[np.single],
    ) -> npt.NDArray[np.single]:
        """Return a copy of ``scores`` with each configured bias added.

        ``input_ids`` is accepted to match the logits-processor call
        signature but is not used here.
        """
        # Work on a copy so the array handed in is left unmodified.
        new_scores = np.copy(scores)
        for input_id, score in logit_bias_map.items():
            new_scores[input_id] = score + scores[input_id]
        return new_scores

    _logit_bias_processor = LogitsProcessorList([logit_bias_processor])
    if logits_processor is None:
        logits_processor = _logit_bias_processor
    else:
        # list.extend() mutates in place and returns None, so assigning its
        # result would discard every processor. Build a new list instead;
        # this also leaves the caller's list untouched, so reusing it across
        # calls does not accumulate bias processors.
        logits_processor = LogitsProcessorList(
            [*logits_processor, *_logit_bias_processor]
        )

if self.verbose:
self._ctx.reset_timings()

Expand Down Expand Up @@ -1963,6 +1986,7 @@ def create_chat_completion(
model: Optional[str] = None,
logits_processor: Optional[LogitsProcessorList] = None,
grammar: Optional[LlamaGrammar] = None,
logit_bias: Optional[Dict[str, float]] = None,
) -> Union[
CreateChatCompletionResponse, Iterator[CreateChatCompletionStreamResponse]
]:
Expand Down Expand Up @@ -2011,6 +2035,7 @@ def create_chat_completion(
model=model,
logits_processor=logits_processor,
grammar=grammar,
logit_bias=logit_bias,
)

def __getstate__(self):
Expand Down
3 changes: 3 additions & 0 deletions llama_cpp/llama_chat_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def __call__(
model: Optional[str] = None,
logits_processor: Optional[llama.LogitsProcessorList] = None,
grammar: Optional[llama.LlamaGrammar] = None,
logit_bias: Optional[Dict[str, float]] = None,
**kwargs, # type: ignore
) -> Union[
llama_types.CreateChatCompletionResponse,
Expand Down Expand Up @@ -308,6 +309,7 @@ def basic_create_chat_completion(
model: Optional[str] = None,
logits_processor: Optional[llama.LogitsProcessorList] = None,
grammar: Optional[llama.LlamaGrammar] = None,
logit_bias: Optional[Dict[str, float]] = None,
**kwargs, # type: ignore
) -> Union[
llama_types.CreateChatCompletionResponse,
Expand Down Expand Up @@ -350,6 +352,7 @@ def basic_create_chat_completion(
model=model,
logits_processor=logits_processor,
grammar=grammar,
logit_bias=logit_bias,
)
return _convert_completion_to_chat(completion_or_chunks, stream=stream)

Expand Down
54 changes: 16 additions & 38 deletions llama_cpp/server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -646,36 +646,16 @@ class CreateCompletionRequest(BaseModel):
}


def make_logit_bias_processor(
def _logit_bias_tokens_to_input_ids(
llama: llama_cpp.Llama,
logit_bias: Dict[str, float],
logit_bias_type: Optional[Literal["input_ids", "tokens"]],
):
if logit_bias_type is None:
logit_bias_type = "input_ids"

to_bias: Dict[int, float] = {}
if logit_bias_type == "input_ids":
for input_id, score in logit_bias.items():
input_id = int(input_id)
to_bias[input_id] = score

elif logit_bias_type == "tokens":
for token, score in logit_bias.items():
token = token.encode("utf-8")
for input_id in llama.tokenize(token, add_bos=False, special=True):
to_bias[input_id] = score

def logit_bias_processor(
input_ids: npt.NDArray[np.intc],
scores: npt.NDArray[np.single],
) -> npt.NDArray[np.single]:
new_scores = np.copy(scores) # Does it make sense to copy the whole array or can we just overwrite the original one?
for input_id, score in to_bias.items():
new_scores[input_id] = score + scores[input_id]
return new_scores

return logit_bias_processor
) -> Dict[str, float]:
to_bias: Dict[str, float] = {}
for token, score in logit_bias.items():
token = token.encode("utf-8")
for input_id in llama.tokenize(token, add_bos=False, special=True):
to_bias[str(input_id)] = score
return to_bias


@router.post(
Expand All @@ -694,17 +674,16 @@ async def create_completion(
exclude = {
"n",
"best_of",
"logit_bias",
"logit_bias_type",
"user",
}
kwargs = body.model_dump(exclude=exclude)

if body.logit_bias is not None:
kwargs["logits_processor"] = llama_cpp.LogitsProcessorList(
[
make_logit_bias_processor(llama, body.logit_bias, body.logit_bias_type),
]
kwargs["logit_bias"] = (
_logit_bias_tokens_to_input_ids(llama, body.logit_bias)
if body.logit_bias_type == "tokens"
else body.logit_bias
)

if body.grammar is not None:
Expand Down Expand Up @@ -851,17 +830,16 @@ async def create_chat_completion(
) -> llama_cpp.ChatCompletion:
exclude = {
"n",
"logit_bias",
"logit_bias_type",
"user",
}
kwargs = body.model_dump(exclude=exclude)

if body.logit_bias is not None:
kwargs["logits_processor"] = llama_cpp.LogitsProcessorList(
[
make_logit_bias_processor(llama, body.logit_bias, body.logit_bias_type),
]
kwargs["logit_bias"] = (
_logit_bias_tokens_to_input_ids(llama, body.logit_bias)
if body.logit_bias_type == "tokens"
else body.logit_bias
)

if body.grammar is not None:
Expand Down

0 comments on commit 07e47f5

Please sign in to comment.