From c0f1be6aea5125626e17c5a963729e88f6a99bd4 Mon Sep 17 00:00:00 2001 From: "Wang, Jian4" <61138589+hzjane@users.noreply.github.com> Date: Thu, 30 May 2024 16:40:59 +0800 Subject: [PATCH] Fix pp logic (#11175) * only send no none batch and rank1-n sending first * always send first --- .../Pipeline-Parallel-FastAPI/pipeline_models.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/python/llm/example/GPU/Pipeline-Parallel-FastAPI/pipeline_models.py b/python/llm/example/GPU/Pipeline-Parallel-FastAPI/pipeline_models.py index c96211e4604..8cf19793619 100644 --- a/python/llm/example/GPU/Pipeline-Parallel-FastAPI/pipeline_models.py +++ b/python/llm/example/GPU/Pipeline-Parallel-FastAPI/pipeline_models.py @@ -413,6 +413,10 @@ async def process_step(self, tokenizer, result_dict): cur_batch = None if self.rank == 0: + if self.send_buff is not None: + # logger.info(f"rank: {self.rank}, send: {self.send_buff.shape}") + dist.send(self.send_buff, dst=self.next_rank) + if self.on_going_batches[0] is not None: cur_batch = self.on_going_batches[0] cur_input = None @@ -464,22 +468,20 @@ async def process_step(self, tokenizer, result_dict): if (cur_batch is not None) and cur_batch.stopped: cur_batch = None + if cur_batch is not None: + dist.broadcast_object_list([cur_batch], src=0) + + else: if self.send_buff is not None: # logger.info(f"rank: {self.rank}, send: {self.send_buff.shape}") dist.send(self.send_buff, dst=self.next_rank) - dist.broadcast_object_list([cur_batch], src=0) - - else: + batch_list = [None] dist.broadcast_object_list(batch_list, src=0) cur_batch = batch_list[0] cur_input = None - if self.send_buff is not None: - # logger.info(f"rank: {self.rank}, send: {self.send_buff.shape}") - dist.send(self.send_buff, dst=self.next_rank) - if cur_batch is not None: if cur_batch.stopped: self.clear_batch(cur_batch.batch_id)