Skip to content

Commit

Permalink
fix code style
Browse files Browse the repository at this point in the history
  • Loading branch information
plusbang committed Jun 12, 2024
1 parent d8c62e0 commit 1ab38c1
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions python/llm/src/ipex_llm/transformers/pipeline_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def init_pipeline_parallel():
def pipeline_parallel(model, pipeline_parallel_stages):
slice_size = (model.config.num_hidden_layers + pipeline_parallel_stages - 1) // \
pipeline_parallel_stages

local_rank = dist.get_rank()
layer_start = slice_size * local_rank
layer_end = layer_start + min(slice_size, model.config.num_hidden_layers - layer_start)
Expand Down Expand Up @@ -141,7 +141,7 @@ def pipeline_parallel_generate(self,
local_rank = dist.get_rank()
pre_rank = (local_rank - 1) % self.pipeline_parallel_stages
next_rank = (local_rank + 1) % self.pipeline_parallel_stages

_input_ids = None
_past_key_values = None
bs = inputs.shape[0]
Expand All @@ -168,7 +168,7 @@ def pipeline_parallel_generate(self,
dist.send(outputs[0], dst=next_rank)
next_ids = torch.empty((bs, 1), device=f'xpu:{local_rank}', dtype=torch.int64)
dist.broadcast(next_ids, src=self.pipeline_parallel_stages - 1)

_input_ids = next_ids
output_ids = torch.cat([output_ids, next_ids], dim=-1)
_past_key_values = outputs.past_key_values
Expand Down

0 comments on commit 1ab38c1

Please sign in to comment.