From ccc405505871985264770f54113f86961d688b7b Mon Sep 17 00:00:00 2001
From: Zijie Li
Date: Thu, 26 Dec 2024 11:41:37 +0800
Subject: [PATCH] [NPU] Update prompt format for baichuan2 (#12615)

* Update baichuan2.py

* style fix
---
 .../LLM/baichuan2.py | 23 ++++++-------------
 1 file changed, 7 insertions(+), 16 deletions(-)

diff --git a/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/baichuan2.py b/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/baichuan2.py
index 3aae11f8966..92b29506340 100644
--- a/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/baichuan2.py
+++ b/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/baichuan2.py
@@ -26,19 +26,6 @@
 logger = logging.get_logger(__name__)
 
-def get_prompt(message: str, chat_history: list[tuple[str, str]],
-               system_prompt: str) -> str:
-    texts = [f'[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n']
-    # The first user input is _not_ stripped
-    do_strip = False
-    for user_input, response in chat_history:
-        user_input = user_input.strip() if do_strip else user_input
-        do_strip = True
-        texts.append(f'{user_input} [/INST] {response.strip()} [INST] ')
-    message = message.strip() if do_strip else message
-    texts.append(f'{message} [/INST]')
-    return ''.join(texts)
-
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(
         description="Predict Tokens using `generate()` API for npu model"
     )
@@ -108,11 +95,15 @@ def get_prompt(message: str, chat_history: list[tuple[str, str]],
     with torch.inference_mode():
         print("finish to load")
         for i in range(5):
-            prompt = get_prompt(args.prompt, [], system_prompt=DEFAULT_SYSTEM_PROMPT)
-            _input_ids = tokenizer.encode(prompt, return_tensors="pt")
+            messages = [{"role": "system", "content": "You are a helpful assistant."},
+                        {"role": "user", "content": args.prompt}]
+            text = tokenizer.apply_chat_template(messages,
+                                                 tokenize=False,
+                                                 add_generation_prompt=True)
+            _input_ids = tokenizer([text], return_tensors="pt").input_ids
             print("-" * 20, "Input", "-" * 20)
             print("input length:", len(_input_ids[0]))
-            print(prompt)
+            print(args.prompt)
             print("-" * 20, "Output", "-" * 20)
             st = time.time()
             output = model.generate(
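
For reference, a minimal standalone sketch of the prompt flow this patch switches
to: instead of hand-building a Llama-2-style [INST]/<<SYS>> string, the tokenizer's
own chat template renders the model's native prompt format. The checkpoint name
below is an assumption made for illustration (the example script takes the model
path as an argument); any chat model whose tokenizer ships a chat template behaves
the same way.

from transformers import AutoTokenizer

# Assumed checkpoint, for illustration only.
model_path = "baichuan-inc/Baichuan2-7B-Chat"
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

messages = [{"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "What is AI?"}]

# apply_chat_template renders the model's own chat format (special tokens
# included), replacing the hard-coded template in the deleted get_prompt().
text = tokenizer.apply_chat_template(messages,
                                     tokenize=False,
                                     add_generation_prompt=True)
input_ids = tokenizer([text], return_tensors="pt").input_ids

print(text)
print("input length:", len(input_ids[0]))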