[NPU] Update prompt format for baichuan2-pipeline (#12625)
lzivan authored Dec 27, 2024
1 parent 34dbdb8 commit 5f04ed7
Showing 1 changed file with 7 additions and 16 deletions.
@@ -25,19 +25,6 @@

 logger = logging.get_logger(__name__)
 
-def get_prompt(message: str, chat_history: list[tuple[str, str]],
-               system_prompt: str) -> str:
-    texts = [f'<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n']
-    # The first user input is _not_ stripped
-    do_strip = False
-    for user_input, response in chat_history:
-        user_input = user_input.strip() if do_strip else user_input
-        do_strip = True
-        texts.append(f'{user_input} [/INST] {response.strip()} </s><s>[INST] ')
-    message = message.strip() if do_strip else message
-    texts.append(f'{message} [/INST]')
-    return ''.join(texts)
-
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(
         description="Predict Tokens using `generate()` API for npu model"
@@ -108,11 +95,15 @@ def get_prompt(message: str, chat_history: list[tuple[str, str]],
     with torch.inference_mode():
         print("finish to load")
         for i in range(3):
-            prompt = get_prompt(args.prompt, [], system_prompt=DEFAULT_SYSTEM_PROMPT)
-            _input_ids = tokenizer.encode(prompt, return_tensors="pt")
+            messages = [{"role": "system", "content": "You are a helpful assistant."},
+                        {"role": "user", "content": args.prompt}]
+            text = tokenizer.apply_chat_template(messages,
+                                                 tokenize=False,
+                                                 add_generation_prompt=True)
+            _input_ids = tokenizer([text], return_tensors="pt").input_ids
             print("-" * 20, "Input", "-" * 20)
             print("input length:", len(_input_ids[0]))
-            print(prompt)
+            print(args.prompt)
             print("-" * 20, "Output", "-" * 20)
             st = time.time()
             output = model.generate(
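For reference, a minimal standalone sketch of the prompt flow this commit switches to, assuming a transformers tokenizer whose config ships a chat template. The checkpoint name and the "What is AI?" prompt are illustrative placeholders, not taken from this commit; the actual script reads the model path and prompt from argparse.

# Sketch of the new prompt flow: render the conversation with the
# tokenizer's own chat template instead of hand-building a Llama-style
# [INST]/<<SYS>> prompt. Assumes the checkpoint defines a chat template.
from transformers import AutoTokenizer

# Illustrative checkpoint name, not from the diff.
tokenizer = AutoTokenizer.from_pretrained("baichuan-inc/Baichuan2-7B-Chat",
                                          trust_remote_code=True)

messages = [{"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "What is AI?"}]

# tokenize=False returns the rendered prompt string; add_generation_prompt=True
# appends the tokens that cue the model to produce the assistant's reply.
text = tokenizer.apply_chat_template(messages,
                                     tokenize=False,
                                     add_generation_prompt=True)
input_ids = tokenizer([text], return_tensors="pt").input_ids
print("input length:", len(input_ids[0]))

Unlike the removed get_prompt helper, which hard-coded the Llama-2 [INST]/<<SYS>> wrapper, this keeps the prompt format in sync with whatever template the Baichuan2 checkpoint ships.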
