diff --git a/python/llm/example/GPU/Pipeline-Parallel-Inference/README.md b/python/llm/example/GPU/Pipeline-Parallel-Inference/README.md
index 42e72cc5bea..381546d59b5 100644
--- a/python/llm/example/GPU/Pipeline-Parallel-Inference/README.md
+++ b/python/llm/example/GPU/Pipeline-Parallel-Inference/README.md
@@ -37,6 +37,7 @@ bash run_llama2_13b_arc_2_card.sh
 ##### [meta-llama/Llama-2-13b-chat-hf](https://huggingface.co/meta-llama/Llama-2-13b-chat-hf)
 ```log
 Inference time: xxxx s
+First token cost xxxx s and rest tokens cost average xxxx s
 -------------------- Prompt --------------------
 Once upon a time, there existed a little girl who liked to have adventures. She wanted to go to places and meet new people, and have fun
 -------------------- Output --------------------
diff --git a/python/llm/example/GPU/Pipeline-Parallel-Inference/generate.py b/python/llm/example/GPU/Pipeline-Parallel-Inference/generate.py
index 7e7736d9b8d..5104c7010f0 100644
--- a/python/llm/example/GPU/Pipeline-Parallel-Inference/generate.py
+++ b/python/llm/example/GPU/Pipeline-Parallel-Inference/generate.py
@@ -69,6 +69,7 @@ if local_rank == args.gpu_num - 1:
         output_str = tokenizer.decode(output[0], skip_special_tokens=True)
         print(f'Inference time: {end-st} s')
+        print(f"First token cost {model.first_token_time:.4f} s and rest tokens cost average {model.rest_cost_mean:.4f} s")
         print('-'*20, 'Prompt', '-'*20)
         print(args.prompt)
         print('-'*20, 'Output', '-'*20)
diff --git a/python/llm/src/ipex_llm/transformers/pipeline_parallel.py b/python/llm/src/ipex_llm/transformers/pipeline_parallel.py
index a81f0abc979..d750cc1b7e0 100644
--- a/python/llm/src/ipex_llm/transformers/pipeline_parallel.py
+++ b/python/llm/src/ipex_llm/transformers/pipeline_parallel.py
@@ -21,6 +21,8 @@
 from torch import nn
 import torch.distributed as dist
 import os
+import time
+import numpy as np
 from typing import Callable, List, Optional
 from transformers import GenerationConfig, LogitsProcessorList, StoppingCriteriaList
 
@@ -142,14 +144,23 @@ def pipeline_parallel_generate(self,
     pre_rank = (local_rank - 1) % self.pipeline_parallel_stages
     next_rank = (local_rank + 1) % self.pipeline_parallel_stages
 
+    self.first_token_time = 0
+    self.next_token_time = []
+
     _input_ids = None
     _past_key_values = None
     bs = inputs.shape[0]
     output_ids = inputs.clone()
-    for i in range(max_new_tokens):
+
+    step = 0
+    while True:
+        if step >= max_new_tokens:
+            break
+
         if _input_ids is None:
             _input_ids = inputs
 
+        tic = time.time()
         if local_rank == 0:
             outputs = self(input_ids=_input_ids, inputs_embeds=None,
                            past_key_values=_past_key_values, use_cache=True)
@@ -172,4 +183,13 @@ def pipeline_parallel_generate(self,
         _input_ids = next_ids
         output_ids = torch.cat([output_ids, next_ids], dim=-1)
         _past_key_values = outputs.past_key_values
+        toc = time.time()
+        if step == 0:
+            self.first_token_time = toc - tic
+        else:
+            self.next_token_time.append(toc - tic)
+        step += 1
+    if self.device.type == 'xpu':
+        torch.xpu.synchronize()
+    self.rest_cost_mean = np.mean(self.next_token_time)
     return output_ids
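
For context, the measurement logic this patch adds follows a common first-token vs. rest-token split: step 0 covers the prompt prefill plus the first generated token, while every later step is one steady-state decode iteration, so the two costs are reported separately. Below is a minimal, self-contained sketch of that pattern; `fake_decode_step` and `MAX_NEW_TOKENS` are illustrative stand-ins, not ipex-llm APIs.

```python
# Minimal sketch of the first-token / rest-token timing pattern used in the
# patch above. `fake_decode_step` and MAX_NEW_TOKENS are illustrative
# stand-ins, not part of ipex-llm.
import time
import numpy as np

MAX_NEW_TOKENS = 8

def fake_decode_step():
    # Stands in for one forward pass through the pipeline stages.
    time.sleep(0.01)

first_token_time = 0
next_token_time = []

step = 0
while True:
    if step >= MAX_NEW_TOKENS:
        break
    tic = time.time()
    fake_decode_step()
    toc = time.time()
    if step == 0:
        first_token_time = toc - tic       # prefill + first generated token
    else:
        next_token_time.append(toc - tic)  # one steady-state decode step
    step += 1

rest_cost_mean = np.mean(next_token_time)
print(f"First token cost {first_token_time:.4f} s "
      f"and rest tokens cost average {rest_cost_mean:.4f} s")
```

On a real XPU device, kernels launch asynchronously, which is why the patch calls `torch.xpu.synchronize()` before computing and reporting the averaged cost.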