From 1baa3efe0e5a61ea64b1ef1a949adcde2d29adee Mon Sep 17 00:00:00 2001 From: Xiangyu Tian <109123695+xiangyuT@users.noreply.github.com> Date: Fri, 2 Aug 2024 12:06:59 +0800 Subject: [PATCH] Optimizations for Pipeline Parallel Serving (#11702) Optimizations for Pipeline Parallel Serving --- .../ipex_llm/serving/fastapi/api_server.py | 2 +- .../transformers/pipeline_parallel.py | 139 +++++++++++------- 2 files changed, 86 insertions(+), 55 deletions(-) diff --git a/python/llm/src/ipex_llm/serving/fastapi/api_server.py b/python/llm/src/ipex_llm/serving/fastapi/api_server.py index 5109c822a49..0cc12bf35e9 100644 --- a/python/llm/src/ipex_llm/serving/fastapi/api_server.py +++ b/python/llm/src/ipex_llm/serving/fastapi/api_server.py @@ -250,7 +250,7 @@ async def generate_stream(inputs_request: InputsRequest): request_id = str(uuid.uuid4()) + "stream" await local_model.waiting_requests.put((request_id, inputs_request)) while True: - await asyncio.sleep(0) + await asyncio.sleep(0.1) cur_streamer = local_model.streamer.get(request_id, None) if cur_streamer is not None: if inputs_request.req_type == 'completion': diff --git a/python/llm/src/ipex_llm/transformers/pipeline_parallel.py b/python/llm/src/ipex_llm/transformers/pipeline_parallel.py index de51cbca5ca..8202b3ee0ee 100644 --- a/python/llm/src/ipex_llm/transformers/pipeline_parallel.py +++ b/python/llm/src/ipex_llm/transformers/pipeline_parallel.py @@ -25,7 +25,7 @@ import os import time import numpy as np -from typing import Callable, List, Optional, Union, Tuple +from typing import Callable, List, Optional, Union, Tuple, Any from types import SimpleNamespace import transformers from transformers import GenerationConfig, LogitsProcessorList, StoppingCriteriaList @@ -37,6 +37,7 @@ import asyncio import uuid import threading +import pickle try: from pydantic import BaseModel except ImportError: @@ -513,6 +514,8 @@ def __init__(self, checkpoint, rank, world_size, low_bit, max_num_seqs, max_pref self.max_prefilled_seqs = max_prefilled_seqs self.partial_output_dict = {} + self.stream_tasks = {} + def load_model(self, model_path, world_size, low_bit='sym_int4'): from ipex_llm.transformers import AutoModelForCausalLM, AutoModel try: @@ -683,16 +686,14 @@ def model_step(self, input, cur_batch): _output = output[0] if _output.dtype != self.dtype: _output = _output.to(self.dtype) - return _output, cur_batch else: if cur_batch.partial_prefilling > 0 and \ cur_batch.prefilled_index == cur_batch.batch_size: _output = self.partial_output_dict.pop(cur_id, None) cur_batch.partial_prefilling = 0 - return _output, cur_batch else: _output = torch.argmax(output.logits[:, -1:, :], dim=-1) - return _output, cur_batch + return _output, cur_batch def is_initialized(self): return True @@ -738,14 +739,71 @@ def clear_batch(self, cur_id): self.is_finish.pop(cur_id, None) self.partial_output_dict.pop(cur_id, None) + async def wait_stream_output(self, cur_id): + cur_task = self.stream_tasks.pop(cur_id, None) + if cur_task is not None: + await cur_task + + def get_printable_text(self, cur_text, request_id): + if cur_text.endswith("\n"): + printable_text = cur_text[self.print_len[request_id]:] + self.token_cache[request_id] = [] + self.print_len[request_id] = 0 + elif len(cur_text) > 0 and _is_chinese_char(ord(cur_text[-1])): + printable_text = cur_text[self.print_len[request_id]:] + self.print_len[request_id] += len(printable_text) + self.token_cache[request_id] = [] + self.print_len[request_id] = 0 + else: + r_index = cur_text.rfind(" ") + 1 + if r_index > 
self.print_len[request_id]: + printable_text = cur_text[self.print_len[request_id]: r_index] + self.token_cache[request_id] = self.token_cache[request_id][-1:] + self.print_len[request_id] = 0 + else: + printable_text = cur_text[self.print_len[request_id]: r_index] + return printable_text + + async def stream_output(self, cur_batch, tokenizer, next_ids): + cur_id = cur_batch.batch_id + cur_cached_ids = [] + _stream_tasks = [] + for index, request_id in enumerate(cur_batch.request_ids): + if not self.is_finish.get(request_id, False): + if self.token_cache.get(request_id, None) is None: + self.token_cache[request_id] = [] + self.print_len[request_id] = 0 + self.token_cache[request_id].extend(next_ids[index].tolist()) + cur_cached_ids.append(self.token_cache[request_id]) + + for index, request_id in enumerate(cur_batch.request_ids): + if not self.is_finish.get(request_id, False): + remain = cur_batch.max_tokens - len(self.tokens[cur_id]) + + if self.streamer.get(request_id, None) is None: + self.streamer[request_id] = asyncio.Queue() + + # Currently ignore eos for benchmark + # if next_ids[index].int() == tokenizer.eos_token_id: + # remain = 0 + # self.is_finish[request_id] = True + + cur_text = tokenizer.decode(self.token_cache[request_id]) + printable_text = self.get_printable_text(cur_text, request_id) + + if remain > 0: + _stream_tasks.append(self.streamer[request_id].put((remain, printable_text))) + else: + printable_text = printable_text + cur_text[self.print_len[request_id]:] + self.token_cache.pop(request_id, None) + self.print_len.pop(request_id, None) + _stream_tasks.append(self.streamer[request_id].put((remain, printable_text))) + await asyncio.gather(*_stream_tasks) + async def process_step(self, tokenizer, result_dict): cur_batch = None - + torch.xpu.synchronize(self.device) if self.rank == 0: - if self.send_buff is not None: - # logger.info(f"send {self.rank} {self.send_buff.shape}") - dist.send(self.send_buff, dst=self.next_rank) - if self.on_going_batches[0] is not None: cur_batch = self.on_going_batches[0] cur_input = None @@ -773,13 +831,13 @@ async def process_step(self, tokenizer, result_dict): # logger.info(f"recv {self.rank} {next_ids.shape}") dist.recv(next_ids, src=self.pre_rank) + torch.xpu.synchronize(self.device) if cur_batch.partial_prefilling > 0: cur_input = self.input_ids_dict[cur_batch.batch_id] else: if self.tokens.get(cur_id, None) is None: self.tokens[cur_id] = [] - if len(next_ids.shape) == 1: next_ids = next_ids.unsqueeze(0) self.tokens[cur_id].append(next_ids) @@ -788,44 +846,14 @@ async def process_step(self, tokenizer, result_dict): cur_batch.input_len = 1 cur_batch.prompt_lengths = [x + 1 for x in cur_batch.prompt_lengths] - for index, request_id in enumerate(cur_batch.request_ids): - - if not self.is_finish.get(request_id, False): - remain = cur_batch.max_tokens - len(self.tokens[cur_id]) - - if self.streamer.get(request_id, None) is None: - self.streamer[request_id] = asyncio.Queue() - - # Currently ignore eos for benchmark - # if next_ids[index].int() == tokenizer.eos_token_id: - # remain = 0 - # self.is_finish[request_id] = True - - if self.token_cache.get(request_id, None) is None: - self.token_cache[request_id] = [] - self.print_len[request_id] = 0 - self.token_cache[request_id].extend(next_ids[index].tolist()) - - text = tokenizer.decode(self.token_cache[request_id]) - if text.endswith("\n"): - printable_text = text[self.print_len[request_id]:] - self.token_cache[request_id] = [] - self.print_len[request_id] = 0 - elif len(text) > 0 and 
_is_chinese_char(ord(text[-1])): - printable_text = text[self.print_len[request_id]:] - self.print_len[request_id] += len(printable_text) - else: - r_index = text.rfind(" ") + 1 - printable_text = text[self.print_len[request_id]: r_index] - self.print_len[request_id] += len(printable_text) - - if remain > 0: - await self.streamer[request_id].put((remain, printable_text)) - else: - printable_text = printable_text + text[self.print_len[request_id]:] - self.token_cache.pop(request_id, None) - self.print_len.pop(request_id, None) - await self.streamer[request_id].put((remain, printable_text)) + pre_task = self.stream_tasks.get(cur_id) + if pre_task is not None: + await pre_task + del self.stream_tasks[cur_id] + cur_task = asyncio.create_task( + self.stream_output(cur_batch, tokenizer, next_ids) + ) + self.stream_tasks[cur_id] = cur_task if len(self.tokens[cur_id]) >= cur_batch.max_tokens: # Finish a batch @@ -841,6 +869,7 @@ async def process_step(self, tokenizer, result_dict): next_token = (cur_times[-1] - cur_times[1]) / (len(self.tokens[cur_id]) - 1) logger.info(f"First token latency: {first_token}, " f"next token latency: {next_token}") + await self.wait_stream_output(cur_id) self.clear_batch(cur_id) cur_batch.stopped = True else: @@ -850,15 +879,12 @@ async def process_step(self, tokenizer, result_dict): if cur_batch is not None: cur_batch = self.prepare_batch(cur_batch) dist.broadcast_object_list([cur_batch], src=0) + else: + await asyncio.sleep(0) else: - if self.send_buff is not None: - # logger.info(f"send {self.rank} {self.send_buff.shape}") - dist.send(self.send_buff, dst=self.next_rank) - batch_list = [None] dist.broadcast_object_list(batch_list, src=0) - cur_batch = batch_list[0] cur_input = None @@ -882,10 +908,15 @@ async def process_step(self, tokenizer, result_dict): ) # logger.info(f"recv {self.rank} {cur_input.shape}") dist.recv(cur_input, src=self.pre_rank) + torch.xpu.synchronize(self.device) output, cur_batch = self.model_step(cur_input, cur_batch) - self.send_buff = output + torch.xpu.synchronize(self.device) + if self.send_buff is not None: + self.send_buff.wait() + if output is not None: + self.send_buff = dist.isend(output, dst=self.next_rank) if self.rank == 0: self.on_going_batches[:-1] = self.on_going_batches[1:]
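
A few notes on the techniques used in this patch, each with a minimal standalone sketch. These sketches are illustrations under stated assumptions, not code from ipex-llm.

The streaming logic factored into get_printable_text follows the incremental-detokenization rule used by transformers' TextStreamer: token ids are cached per request, the whole cache is re-decoded every step, and only the text up to the last space is emitted so multi-token words are never split across chunks (a trailing newline, and CJK characters via _is_chinese_char, flush immediately). A sketch of that rule, where ToyTokenizer is a toy stand-in for a real HF tokenizer and the caches are not keyed by request_id as they are in the patch:

    class ToyTokenizer:
        def decode(self, pieces):
            return "".join(pieces)           # pretend the cached ids are text pieces

    def printable_chunk(cache, print_len, new_pieces, tokenizer):
        cache.extend(new_pieces)
        text = tokenizer.decode(cache)
        if text.endswith("\n"):              # a newline flushes everything buffered
            chunk = text[print_len:]
            cache.clear()
            print_len = 0
        else:                                # otherwise stop at the last word boundary
            r_index = text.rfind(" ") + 1
            chunk = text[print_len:r_index]
            print_len += len(chunk)
        return chunk, print_len

    tok = ToyTokenizer()
    cache, print_len, out = [], 0, ""
    for piece in ["Hel", "lo ", "wor", "ld", "!\n"]:
        chunk, print_len = printable_chunk(cache, print_len, [piece], tok)
        out += chunk                         # "Hello " arrives first, "world!\n" at the end
    print(repr(out))                         # 'Hello world!\n'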
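
The rewritten process_step no longer detokenizes and fills the per-request queues inline. It schedules stream_output as an asyncio task keyed by batch id, awaits the previous task for the same batch before starting a new one, and drains the last task in wait_stream_output before clear_batch, so decoding the next token overlaps with streaming the previous one. A sketch of that scheduling pattern, where fake_decode_step and fake_stream_output are illustrative stand-ins rather than ipex-llm APIs:

    import asyncio

    stream_tasks = {}                        # batch_id -> in-flight streaming task
    queues = {}                              # batch_id -> asyncio.Queue read by the API layer

    async def fake_stream_output(batch_id, token):
        # Stand-in for stream_output: detokenize and push chunks to per-request queues.
        await queues.setdefault(batch_id, asyncio.Queue()).put(token)

    async def fake_decode_step(step):
        # Stand-in for model_step plus the dist send/recv around it.
        await asyncio.sleep(0.01)
        return 1000 + step

    async def process_step(batch_id, step):
        token = await fake_decode_step(step)
        prev = stream_tasks.get(batch_id)
        if prev is not None:
            await prev                       # keep streamed chunks ordered per batch
            del stream_tasks[batch_id]
        stream_tasks[batch_id] = asyncio.create_task(
            fake_stream_output(batch_id, token))

    async def main():
        for step in range(5):
            await process_step("batch-0", step)
        last = stream_tasks.pop("batch-0", None)
        if last is not None:
            await last                       # wait_stream_output equivalent before clearing
        q = queues["batch-0"]
        while not q.empty():
            print(q.get_nowait())            # 1000 .. 1004

    asyncio.run(main())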
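
On the communication side, the blocking dist.send at the top of process_step is replaced by a dist.isend issued right after model_step; the returned handle is kept in self.send_buff and waited on just before the next output is sent, and torch.xpu.synchronize brackets the transfer so device compute has finished before the buffer goes on the wire. A CPU/gloo sketch of the same wait-then-isend pattern; the 2-process setup, the port number, and fake_model_step are illustrative assumptions:

    import os
    import torch
    import torch.distributed as dist
    import torch.multiprocessing as mp

    def fake_model_step(step):
        # Stand-in for the real pipeline-stage forward pass.
        return torch.full((4,), float(step))

    def run(rank, world_size):
        os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
        os.environ.setdefault("MASTER_PORT", "29511")
        dist.init_process_group("gloo", rank=rank, world_size=world_size)
        if rank == 0:
            send_handle, in_flight = None, None  # plays the role of self.send_buff
            for step in range(4):
                out = fake_model_step(step)      # this compute overlaps the previous isend
                if send_handle is not None:
                    send_handle.wait()           # previous non-blocking send must finish first
                send_handle, in_flight = dist.isend(out, dst=1), out
            send_handle.wait()
        else:
            buf = torch.empty(4)
            for step in range(4):
                dist.recv(buf, src=0)            # the downstream stage still blocks on recv
                print(f"rank 1 received step {int(buf[0])}")
        dist.destroy_process_group()

    if __name__ == "__main__":
        mp.spawn(run, args=(2,), nprocs=2)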