diff --git a/turbo_alignment/generators/vllm_chat.py b/turbo_alignment/generators/vllm_chat.py index 5a89bb1..8df843b 100755 --- a/turbo_alignment/generators/vllm_chat.py +++ b/turbo_alignment/generators/vllm_chat.py @@ -51,6 +51,7 @@ def __init__( skip_special_tokens=custom_generation_settings.skip_special_tokens, stop_token_ids=eos_token_id, max_tokens=transformers_settings.max_new_tokens, + logprobs=transformers_settings.logprobs, **beam_search_params, ) self._lora_request = lora_request diff --git a/turbo_alignment/settings/tf/generation.py b/turbo_alignment/settings/tf/generation.py index 14537d9..421411d 100755 --- a/turbo_alignment/settings/tf/generation.py +++ b/turbo_alignment/settings/tf/generation.py @@ -11,3 +11,4 @@ class GeneratorTransformersSettings(ExtraFieldsNotAllowedBaseModel): top_k: int = 50 temperature: float = 1.0 stop_strings: str | list[str] = '' + logprobs: int = 1