Add vllm decoder for model inference (#124)
* add vllm decoder

* Add test for vllm

* Add openbuddy 70b v10.1

* Improve vllm decoder
44670 authored Aug 23, 2023
1 parent 46183df commit 173ce5a
Showing 7 changed files with 5,012 additions and 52 deletions.
4,832 changes: 4,832 additions & 0 deletions results/openbuddy-llama2-70b-v10.1/model_outputs.json

Large diffs are not rendered by default.

12 changes: 12 additions & 0 deletions src/alpaca_eval/decoders/__init__.py
@@ -71,6 +71,18 @@ def get_fn_completions(name: Union[str, Callable]) -> Callable:
        from .jinachat import jina_chat_completions

        return jina_chat_completions

    elif name == "vllm_local_completions":
        try:
            from .vllm_local import vllm_local_completions

            return vllm_local_completions
        except ImportError as e:
            packages = ["vllm", "ray", "transformers"]
            logging.exception(f"You need {packages} to use vllm_local_completions. Error:")
            raise e

    else:
        raise ValueError(f"Unknown decoder: {name}")
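
For context, a minimal sketch of how the new branch is exercised; `get_fn_completions` and the decoder name come from this commit, while the prompt is illustrative and the model is the small checkpoint used in the tests below:

from alpaca_eval.decoders import get_fn_completions

# Resolves the "vllm_local_completions" name to the function imported above.
fn_completions = get_fn_completions("vllm_local_completions")
results = fn_completions(
    prompts=["User: Hi.\nAssistant:"],
    model_name="OpenBuddy/openbuddy-openllama-3b-v10-bf16",
    max_new_tokens=64,
)
print(results["completions"][0])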
82 changes: 82 additions & 0 deletions src/alpaca_eval/decoders/vllm_local.py
@@ -0,0 +1,82 @@
import logging
from typing import Optional, Sequence

import numpy as np
from vllm import LLM, SamplingParams

from .. import constants, utils

__all__ = ["vllm_local_completions"]

# Module-level cache: the vLLM engine is expensive to build, so it is loaded
# once per process and reused across calls.
llm = None
llm_model_name = None


def vllm_local_completions(
    prompts: Sequence[str],
    model_name: str,
    max_new_tokens: int,
    do_sample: bool = False,
    batch_size: int = 1,
    model_kwargs=None,
    cache_dir: Optional[str] = constants.DEFAULT_CACHE_DIR,
    is_fast_tokenizer: bool = True,  # accepted for API compatibility; unused here
    adapters_name: Optional[str] = None,  # accepted for API compatibility; unused here
    **kwargs,
) -> dict[str, list]:
    """Decode locally using a vLLM engine.

    Parameters
    ----------
    prompts : list of str
        Prompts to get completions for.
    model_name : str
        Name of the model (repo on the Hugging Face Hub) or local path to use for decoding.
    max_new_tokens : int
        Maximum number of tokens to generate per prompt.
    do_sample : bool, optional
        Whether to decode with beam search instead of greedy decoding (this decoder
        maps `do_sample=True` to `use_beam_search`).
    batch_size : int, optional
        Number of prompts sent to the engine per `generate` call.
    model_kwargs : dict, optional
        Additional engine options; currently only `tp` (tensor parallel size) is read.
    cache_dir : str, optional
        Directory to use for caching the model.
    kwargs :
        Sampling options (`temperature`, `top_p`, `top_k`) applied to `SamplingParams`.
    """
    global llm, llm_model_name
    model_kwargs = model_kwargs or {}
    tp = model_kwargs.get("tp", 1)
    if llm is None:
        logging.info("vllm: loading model: %s, tp=%d", model_name, tp)
        llm = LLM(model=model_name, tokenizer=model_name, tensor_parallel_size=tp)
        llm_model_name = model_name
    if model_name != llm_model_name:
        raise ValueError("vllm_local_completions can only be used with a single model per process")

    sampling_params = SamplingParams(max_tokens=max_new_tokens)
    if "temperature" in kwargs:
        sampling_params.temperature = kwargs["temperature"]
    if "top_p" in kwargs:
        sampling_params.top_p = kwargs["top_p"]
    if "top_k" in kwargs:
        sampling_params.top_k = kwargs["top_k"]
    if do_sample:
        sampling_params.use_beam_search = True

    completions = []
    with utils.Timer() as t:
        for i in range(0, len(prompts), batch_size):
            batch = prompts[i : i + batch_size]
            outputs = llm.generate(batch, sampling_params)
            for j in range(len(batch)):
                completions.append(outputs[j].outputs[0].text)
    price = [np.nan] * len(completions)
    avg_time = [t.duration / len(prompts)] * len(completions)
    return dict(completions=completions, price_per_example=price, time_per_example=avg_time)
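
The module-level cache means the engine is built on the first call and reused afterwards, which is why a later call with a different model_name raises. A minimal direct invocation, assuming a single GPU; the prompt is illustrative and the model is the small checkpoint from the tests:

from alpaca_eval.decoders.vllm_local import vllm_local_completions

out = vllm_local_completions(
    prompts=["User: Hi.\nAssistant:"],
    model_name="OpenBuddy/openbuddy-openllama-3b-v10-bf16",
    max_new_tokens=32,
    model_kwargs={"tp": 1},      # tensor_parallel_size; 1 is also the default
    temperature=0.7, top_p=1.0,  # forwarded to SamplingParams
)
print(out["completions"], out["time_per_example"])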
@@ -1,55 +1,56 @@
,win_rate,standard_error,n_wins,n_wins_base,n_draws,n_total,mode,avg_length
gpt4,95.27950310559004,0.716281440286153,761,32,12,805,minimal,1365.0
llama-2-70b-chat-hf,92.66169154228857,0.911762258320568,743,57,4,804,minimal,1790.0
claude-2,91.35572139303484,0.9897323784630048,734,69,1,804,minimal,1069.0
openchat-v3.1-13b,89.49004975124379,1.076875474505156,718,83,3,804,community,1484.0
chatgpt,89.36567164179104,1.0789487022114888,716,83,5,804,minimal,827.0
wizardlm-13b-v1.2,89.1656288916563,1.0904254662753898,714,85,4,803,community,1635.0
vicuna-33b-v1.3,88.99253731343283,1.095692216068168,713,86,5,804,verified,1479.0
claude,88.38509316770187,1.1144875403283188,707,89,9,805,minimal,1082.0
humpback-llama2-70b,87.93532338308458,1.1545476754393662,706,96,2,804,community,1822.0
openbuddy-llama2-70b-v10.1,87.67123287671232,1.1508417516577765,701,96,6,803,community,1077.0
openchat-v2-w-13b,87.1268656716418,1.1769197439396015,699,102,3,804,community,1566.0
openbuddy-llama-65b-v8,86.53366583541147,1.2029182403474274,693,107,2,802,community,1162.0
wizardlm-13b-v1.1,86.31840796019901,1.2063217831272972,692,108,4,804,community,1525.0
openchat-v2-13b,84.96894409937889,1.2572979835605944,683,120,2,805,community,1564.0
humpback-llama-65b,83.70646766169155,1.3071034735987248,672,130,2,804,community,1269.0
vicuna-13b-v1.3,82.11180124223603,1.348769957803504,660,143,2,805,verified,1132.0
openbuddy-llama-30b-v7.1,81.54613466334165,1.370658000946423,654,148,0,802,community,968.0
llama-2-13b-chat-hf,81.09452736318407,1.3817573087734825,652,152,0,804,minimal,1513.0
openchat-13b,80.8695652173913,1.3843738653129234,650,153,2,805,community,1632.0
openbuddy-falcon-40b-v9,80.69738480697384,1.3908517976873225,647,154,2,803,community,1089.0
ultralm-13b,80.63511830635119,1.3939556917204066,647,155,1,803,community,1087.0
openchat8192-13b,79.53980099502488,1.4222439886269744,639,164,1,804,community,1664.0
opencoderplus-15b,78.69565217391305,1.440029529188432,632,170,3,805,community,1628.0
vicuna-7b-v1.3,76.8414481897628,1.487520320531845,614,184,3,801,verified,1110.0
wizardlm-13b,75.31094527363184,1.5101858292160824,601,194,9,804,minimal,985.0
jina-chat,74.12718204488779,1.541070307435577,592,205,5,802,community,676.0
airoboros-65b,73.91304347826086,1.5285333061227804,587,202,16,805,community,1512.0
airoboros-33b,73.29192546583852,1.55290318216736,587,212,6,805,community,1514.0
guanaco-65b,71.80124223602485,1.586912361158523,578,227,0,805,minimal,1249.0
llama-2-7b-chat-hf,71.36645962732919,1.593038654706019,574,230,1,805,minimal,1479.0
vicuna-13b,70.43478260869566,1.6069688407799696,566,237,2,805,minimal,1037.0
openbuddy-falcon-7b-v6,70.36114570361146,1.612538056786233,565,238,0,803,community,1152.0
baize-v2-13b,66.95652173913044,1.6565358231309506,538,265,2,805,community,930.0
oasst-rlhf-llama-33b,66.52173913043478,1.6608288428292477,534,268,3,805,minimal,1079.0
minotaur-13b,66.02484472049689,1.6645545328264226,529,271,5,805,community,881.0
guanaco-33b,65.96273291925466,1.67108537053247,531,274,0,805,verified,1311.0
nous-hermes-13b,65.46583850931677,1.669962276077284,524,275,6,805,verified,844.0
vicuna-7b,64.40993788819875,1.6851107260487883,517,285,3,805,verified,1044.0
baize-v2-7b,63.85093167701863,1.6945981855442178,514,291,0,805,community,1127.0
oasst-sft-llama-33b,54.96894409937888,1.7402667933686875,436,356,13,805,verified,748.0
guanaco-13b,52.60869565217391,1.7576690299699242,422,380,3,805,verified,1774.0
text_davinci_003,50.0,0.0,0,0,805,805,minimal,307.0
chatglm2-6b,47.12858926342072,1.7593143221324448,375,421,5,801,community,1027.0
guanaco-7b,46.58385093167702,1.7570464905413992,374,429,2,805,verified,1364.0
falcon-40b-instruct,45.71428571428572,1.7524717060805597,366,435,4,805,minimal,662.0
alpaca-farm-ppo-sim-gpt4-20k,44.099378881987576,1.7399772578861137,350,445,10,805,verified,511.0
pythia-12b-mix-sft,41.86335403726708,1.737637146007538,336,467,2,805,verified,913.0
alpaca-farm-ppo-human,41.24223602484472,1.7271813123250834,328,469,8,805,minimal,803.0
cohere-chat,29.565217391304348,1.5949050483247118,232,561,12,805,community,779.0
cohere,28.385093167701864,1.5717547121761728,221,569,15,805,community,682.0
alpaca-7b,26.459627329192543,1.535711469748,205,584,16,805,minimal,396.0
oasst-sft-pythia-12b,25.962732919254663,1.5261079289535309,201,588,16,805,verified,726.0
falcon-7b-instruct,23.60248447204969,1.4898235369056625,187,612,6,805,verified,478.0
baichuan-13b-chat,21.801242236024844,1.4495247592518703,173,627,5,805,community,1727.0
text_davinci_001,15.17412935323383,1.235107892276849,112,672,20,804,minimal,296.0
@@ -0,0 +1,15 @@
openbuddy-llama2-70b-v10.1:
  prompt_template: "openbuddy-llama2-70b-v10.1/prompt.txt"
  fn_completions: "vllm_local_completions"
  completions_kwargs:
    model_name: "./openbuddy-llama2-70b-v10.1-bf16" # local path
    model_kwargs:
      torch_dtype: 'bfloat16'
      tp: 4 # tensor parallelism; 4 for 4 GPUs
    max_new_tokens: 2000 # with vllm, prompt_tokens + max_new_tokens must be <= model_max_length, otherwise it raises an error and nothing is generated
    temperature: 0.7
    top_p: 1.0
    do_sample: False
    batch_size: 16
  pretty_name: "OpenBuddy-LLaMA2-70b-v10.1"
  link: "https://huggingface.co/OpenBuddyEA"
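
The max_new_tokens comment encodes a hard constraint: vllm rejects a request when prompt tokens plus requested new tokens exceed the model's context window. A minimal pre-flight check, assuming a Hugging Face tokenizer for the same checkpoint (`fits_context` is a hypothetical helper, not part of this commit):

from transformers import AutoTokenizer

def fits_context(prompt: str, model_name: str, max_new_tokens: int, model_max_length: int) -> bool:
    # True when prompt_tokens + max_new_tokens <= model_max_length, i.e. the request is safe.
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    return len(tokenizer(prompt).input_ids) + max_new_tokens <= model_max_length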
@@ -0,0 +1,10 @@
You are a helpful, respectful and honest INTP-T AI Assistant named Buddy. You are talking to a human User.
Always answer as helpfully and logically as possible, while being safe. Please ensure that your responses are socially unbiased and positive, detailed, trustworthy in nature.
You can speak fluently in English.
You have vast knowledge. Your answer should be creative, interesting, long, detailed, but never repetitive.
Never refuse to answer a question or ask questions to the User.

User: Hi.
Assistant: Hi, I'm Buddy, your AI assistant. How can I help you today?
User: {instruction}
Assistant:
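
Before decoding, the {instruction} placeholder above is filled per example; a minimal sketch of the substitution (alpaca_eval has its own prompt utilities, this only shows the idea):

template = open("openbuddy-llama2-70b-v10.1/prompt.txt").read()
prompt = template.format(instruction="What is tensor parallelism?")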
8 changes: 8 additions & 0 deletions tests/integration_tests/test_decoders_integration.py
@@ -61,3 +61,11 @@ def test_huggingface_local_completions_integration():
    results = huggingface_local_completions(prompts, model_name="hf-internal-testing/tiny-random-gpt2")
    assert len(results["completions"]) == len(prompts)
    # nothing to test because the model is random


@pytest.mark.slow
def test_vllm_local_completions_integration():
    from alpaca_eval.decoders.vllm_local import vllm_local_completions

    prompts = _get_formatted_prompts("text_davinci_003")  # no formatting needed
    results = vllm_local_completions(prompts, model_name="OpenBuddy/openbuddy-openllama-3b-v10-bf16", max_new_tokens=100)
    assert len(results["completions"]) == len(prompts)