Skip to content

Commit

Permalink
Fix pp logic (#11175)
Browse files Browse the repository at this point in the history
* only send no none batch and rank1-n sending first

* always send first
  • Loading branch information
hzjane authored May 30, 2024
1 parent 4127b99 commit c0f1be6
Showing 1 changed file with 9 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,10 @@ async def process_step(self, tokenizer, result_dict):
cur_batch = None

if self.rank == 0:
if self.send_buff is not None:
# logger.info(f"rank: {self.rank}, send: {self.send_buff.shape}")
dist.send(self.send_buff, dst=self.next_rank)

if self.on_going_batches[0] is not None:
cur_batch = self.on_going_batches[0]
cur_input = None
Expand Down Expand Up @@ -464,22 +468,20 @@ async def process_step(self, tokenizer, result_dict):
if (cur_batch is not None) and cur_batch.stopped:
cur_batch = None

if cur_batch is not None:
dist.broadcast_object_list([cur_batch], src=0)

else:
if self.send_buff is not None:
# logger.info(f"rank: {self.rank}, send: {self.send_buff.shape}")
dist.send(self.send_buff, dst=self.next_rank)
dist.broadcast_object_list([cur_batch], src=0)

else:

batch_list = [None]
dist.broadcast_object_list(batch_list, src=0)

cur_batch = batch_list[0]
cur_input = None

if self.send_buff is not None:
# logger.info(f"rank: {self.rank}, send: {self.send_buff.shape}")
dist.send(self.send_buff, dst=self.next_rank)

if cur_batch is not None:
if cur_batch.stopped:
self.clear_batch(cur_batch.batch_id)
Expand Down

0 comments on commit c0f1be6

Please sign in to comment.