
Commit

non-batch greedy decoding.
shamanez committed Oct 28, 2024
1 parent 67d270c commit b459177
Showing 1 changed file with 66 additions and 36 deletions.
102 changes: 66 additions & 36 deletions lm_eval/models/vllm_causallms.py
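
For orientation, the sketch below (not code from this repository) illustrates the per-request decoding path the diff introduces: each tokenized request is decoded back to text, re-wrapped in the model's chat template, and generated one prompt at a time with vLLM instead of in a single batched call. The model id, system prompt, and sampling values are illustrative assumptions.

# Minimal sketch of non-batched, chat-templated decoding with vLLM.
# Model id, system prompt, and sampling values are assumptions, not this repo's defaults.
from vllm import LLM, SamplingParams

llm = LLM(model="Qwen/Qwen2.5-7B-Instruct")  # hypothetical model id
tokenizer = llm.get_tokenizer()
sampling_params = SamplingParams(temperature=1, min_p=0.05, max_tokens=2048)

def decode_one(prompt_token_ids, system_prompt="You are a helpful assistant."):
    # Recover the prompt text from the tokenized request.
    prompt = tokenizer.decode(prompt_token_ids, skip_special_tokens=True)
    # Re-wrap it in the model's chat template before generation.
    prompt_with_template = tokenizer.apply_chat_template(
        [{"role": "system", "content": system_prompt},
         {"role": "user", "content": prompt}],
        tokenize=False,
        add_generation_prompt=True,
    )
    # One request at a time: generate() returns a list with a single RequestOutput.
    output = llm.generate(prompt_with_template, sampling_params=sampling_params)
    return output[0].outputs[0].text

requests = [tokenizer.encode("What is 2 + 2?")]  # stand-in for lm-eval's tokenized requests
results = [decode_one(req) for req in requests]
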
@@ -36,7 +36,7 @@

 @register_model("vllm")
 class VLLM(TemplateLM):
-    _DEFAULT_MAX_LENGTH = 2048
+    _DEFAULT_MAX_LENGTH = 4096

     def __init__(
         self,
@@ -62,12 +62,12 @@ def __init__(
         device: str = "cuda",
         data_parallel_size: int = 1,
         lora_local_path: str = None,
-        use_blueberry: bool = True,
+        use_blueberry: bool = False,
         **kwargs,
     ):
         super().__init__()

-        self.use_blueberry = use_blueberry
+        self.use_blueberry = use_blueberry

         if not find_spec("vllm"):
             raise ModuleNotFoundError(
@@ -229,24 +229,7 @@ def _model_generate(
         stop: Optional[List[str]] = None,
         **kwargs,
     ):

-        if self.use_blueberry:
-
-            outputs = []
-            for req in requests:
-
-                output = self.generate_blueberry(
-                    req,
-                    system_prompt="you are a helpful assistant",
-                    use_tqdm=True if self.batch_size == "auto" else False,
-                )
-
-                outputs.append(output)
-
-            return outputs
-
-
-
         if generate:
             kwargs = self.modify_gen_kwargs(kwargs)
             sampling_params = SamplingParams(max_tokens=max_tokens, stop=stop, **kwargs)
@@ -289,11 +272,50 @@ def run_inference_one_model(
                 lora_request=self.lora_request,
             )
         else:
-            outputs = self.model.generate(
-                prompt_token_ids=requests,
-                sampling_params=sampling_params,
-                use_tqdm=True if self.batch_size == "auto" else False,
-            )
+
+            if self.use_blueberry:
+                outputs = []
+                for req in requests:
+
+                    output = self.generate_blueberry(
+                        req,
+                        system_prompt="You are Qwen, created by Alibaba Cloud. You are a helpful assistant.",
+                        use_tqdm=True if self.batch_size == "auto" else False,
+                    )
+
+                    outputs.append(output)
+            else:
+
+                sampling_params = SamplingParams(temperature=1, min_p=0.05, max_tokens=2048)
+
+                outputs = []
+                for req in requests:
+
+                    system_prompt = "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."
+                    # Decode the token IDs into the prompt text
+                    prompt = self.tokenizer.decode(req, skip_special_tokens=True)
+
+                    # Apply the chat template to the prompt text
+                    prompt_with_template = self.tokenizer.apply_chat_template([
+                        {"role": "system", "content": system_prompt},
+                        {"role": "user", "content": prompt}
+                    ], tokenize=False, add_generation_prompt=True)
+
+
+                    # print("==============")
+
+                    # print(prompt_with_template)
+
+
+                    # exit("====")
+
+                    output = self.model.generate(
+                        prompt_with_template,
+                        sampling_params=sampling_params,
+                        #use_tqdm=True if self.batch_size == "auto" else False,
+                    )
+
+                    outputs.append(output[0].outputs[0].text)
         return outputs

     def loglikelihood_rolling(
@@ -425,13 +447,18 @@ def _collate_gen(_requests):

                 if self.use_blueberry:
                     generated_text = output
                     res.append(generated_text)
                 else:
-                    generated_text = output.outputs[0].text
-
-                    self.cache_hook.add_partial(
-                        "generate_until", (context, gen_kwargs), generated_text
-                    )
-                    res.append(generated_text)
+                    #generated_text = output.outputs[0].text
+                    generated_text = output
+                    res.append(generated_text)
+
+                    # print(res)
+                    # exit()
+
+                    # self.cache_hook.add_partial(
+                    #     "generate_until", (context, gen_kwargs), generated_text
+                    # )
                 pbar.update(1)

         pbar.close()
@@ -600,23 +627,26 @@ def generate_blueberry(
         self,
         prompt_token_ids,
         system_prompt: str,
-        use_tqdm=False,
+        use_tqdm=True,
         alpha=0.1,
         max_length=2048,
         temperature=1,
         beam_width=30,
         k_steps=5
     ):
-        system_prompt = "You are a helpful assistant"
+        system_prompt = "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."
         # Decode the token IDs into the prompt text
-        prompt = self.tokenizer.decode(prompt_token_ids, skip_special_tokens=False)
+        prompt = self.tokenizer.decode(prompt_token_ids, skip_special_tokens=True)

         # Apply the chat template to the prompt text
         prompt_with_template = self.tokenizer.apply_chat_template([
             {"role": "system", "content": system_prompt},
             {"role": "user", "content": prompt}
         ], tokenize=False, add_generation_prompt=True)

+        # print(prompt_with_template)
+        # exit("***********")
+
         # Encode the prompt with template back into token IDs
         input_ids = self.tokenizer.encode(prompt_with_template)

@@ -645,7 +675,7 @@ def generate_blueberry(
             outputs = self.model.generate(
                 leaf_prompts,
                 sampling_params=sampling_params,
-                use_tqdm=False
+                use_tqdm=use_tqdm
             )

             for leaf, output in zip(leaves, outputs):
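
A note on the indexing the new code relies on: vLLM's LLM.generate always returns a list of RequestOutput objects, one per prompt and in prompt order, and each RequestOutput holds its completions under .outputs. Hence output[0].outputs[0].text for a single string prompt, and one RequestOutput per element when a list such as leaf_prompts is passed. A small illustrative sketch (not code from this repository), with a placeholder model id:

# Illustrative only: how vLLM's return structure maps to the indexing used in the diff.
from vllm import LLM, SamplingParams

llm = LLM(model="Qwen/Qwen2.5-7B-Instruct")  # hypothetical model id
params = SamplingParams(temperature=1, min_p=0.05, max_tokens=64)

# Single prompt: generate() still returns a list, so the text is at [0].outputs[0].text.
single = llm.generate("Name one prime number.", sampling_params=params)
print(single[0].outputs[0].text)

# List of prompts (like leaf_prompts above): one RequestOutput per prompt, in order.
leaf_prompts = ["Continue: 1, 2, 3,", "Continue: a, b, c,"]
batch = llm.generate(leaf_prompts, sampling_params=params)
for prompt, output in zip(leaf_prompts, batch):
    print(prompt, "->", output.outputs[0].text)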
