Commit

formatting
YannDubs committed Aug 23, 2023
1 parent 1eea6f1 commit 72f1e56
Showing 2 changed files with 18 additions and 22 deletions.
33 changes: 13 additions & 20 deletions src/alpaca_eval/decoders/vllm_local.py
@@ -1,28 +1,24 @@
import logging
-from typing import Optional, Sequence
+from typing import Sequence

import numpy as np
import torch
from vllm import LLM, SamplingParams


-from .. import constants, utils
+from .. import utils

__all__ = ["vllm_local_completions"]

llm = None
llmModelName = None


def vllm_local_completions(
    prompts: Sequence[str],
    model_name: str,
    max_new_tokens: int,
    do_sample: bool = False,
    batch_size: int = 1,
    model_kwargs=None,
-    cache_dir: Optional[str] = constants.DEFAULT_CACHE_DIR,
-    is_fast_tokenizer: bool = True,
-    adapters_name: Optional[str] = None,
    **kwargs,
) -> dict[str, list]:
    """Decode locally using vllm transformers pipeline.
@@ -44,16 +40,13 @@ def vllm_local_completions(
    model_kwargs : dict, optional
        Additional kwargs to pass to from_pretrained.
-    cache_dir : str, optional
-        Directory to use for caching the model.
    kwargs :
        Additional kwargs to pass to `InferenceApi.__call__`.
    """
    global llm, llmModelName
    tp = 1
-    if 'tp' in model_kwargs:
-        tp = model_kwargs['tp']
+    if "tp" in model_kwargs:
+        tp = model_kwargs["tp"]
    if llm is None:
        logging.info("vllm: loading model: %s, tp=%d", model_name, tp)
        llm = LLM(model=model_name, tokenizer=model_name, tensor_parallel_size=tp)
@@ -62,21 +55,21 @@
        assert False, "vllm_local_completions can only be used with a single model"

    sampling_params = SamplingParams(max_tokens=max_new_tokens)
-    if 'temperature' in kwargs:
-        sampling_params.temperature = kwargs['temperature']
-    if 'top_p' in kwargs:
-        sampling_params.top_p = kwargs['top_p']
-    if 'top_k' in kwargs:
-        sampling_params.top_k = kwargs['top_k']
+    if "temperature" in kwargs:
+        sampling_params.temperature = kwargs["temperature"]
+    if "top_p" in kwargs:
+        sampling_params.top_p = kwargs["top_p"]
+    if "top_k" in kwargs:
+        sampling_params.top_k = kwargs["top_k"]
    if do_sample:
        sampling_params.use_beam_search = True
    completions = []
    with utils.Timer() as t:
        for i in range(0, len(prompts), batch_size):
-            batch = prompts[i:i + batch_size]
+            batch = prompts[i : i + batch_size]
            outputs = llm.generate(batch, sampling_params)
            for j in range(0, len(batch)):
                completions.append(outputs[j].outputs[0].text)
    price = [np.nan] * len(completions)
    avg_time = [t.duration / len(prompts)] * len(completions)
-    return dict(completions=completions, price_per_example=price, time_per_example=avg_time)
+    return dict(completions=completions, price_per_example=price, time_per_example=avg_time)
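
For reference, a minimal usage sketch of the function touched above (not part of the commit). It assumes a local GPU with the OpenBuddy checkpoint used in the integration test, and that extra kwargs such as temperature are forwarded to SamplingParams as the diff shows:

# Hypothetical usage sketch (not from the commit): call the decoder directly.
from alpaca_eval.decoders.vllm_local import vllm_local_completions

prompts = ["What is the capital of France?", "Write a haiku about GPUs."]
results = vllm_local_completions(
    prompts,
    model_name="OpenBuddy/openbuddy-openllama-3b-v10-bf16",  # model used in the integration test
    max_new_tokens=100,
    batch_size=2,
    model_kwargs={"tp": 1},  # "tp" sets tensor_parallel_size when the LLM is first created
    temperature=0.7,  # forwarded to SamplingParams, per the diff above
)
print(results["completions"][0], results["time_per_example"][0])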
7 changes: 5 additions & 2 deletions tests/integration_tests/test_decoders_integration.py
@@ -66,6 +66,9 @@ def test_huggingface_local_completions_integration():
@pytest.mark.slow
def test_vllm_local_completions_integration():
    from alpaca_eval.decoders.vllm_local import vllm_local_completions
+
    prompts = _get_formatted_prompts("text_davinci_003")  # nor formatting
-    results = vllm_local_completions(prompts, model_name="OpenBuddy/openbuddy-openllama-3b-v10-bf16", max_new_tokens=100)
-    assert len(results["completions"]) == len(prompts)
+    results = vllm_local_completions(
+        prompts, model_name="OpenBuddy/openbuddy-openllama-3b-v10-bf16", max_new_tokens=100
+    )
+    assert len(results["completions"]) == len(prompts)
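
To exercise this slow-marked test locally, something like the following should work; a sketch assuming pytest's standard -m/-k selection, with the file path taken from the diff header:

# Hypothetical invocation sketch: run only the vllm integration test via pytest's API.
import pytest

pytest.main([
    "tests/integration_tests/test_decoders_integration.py",  # path from the diff header
    "-m", "slow",        # the test is decorated with @pytest.mark.slow
    "-k", "vllm_local",  # selects test_vllm_local_completions_integration
])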
