From 1ab38c17e7a1adbf27ac0d7ea3954818470ae821 Mon Sep 17 00:00:00 2001
From: plusbang
Date: Thu, 13 Jun 2024 00:31:10 +0800
Subject: [PATCH] fix code style

---
 python/llm/src/ipex_llm/transformers/pipeline_parallel.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/python/llm/src/ipex_llm/transformers/pipeline_parallel.py b/python/llm/src/ipex_llm/transformers/pipeline_parallel.py
index 3859e0e4cfd..dcef0104746 100644
--- a/python/llm/src/ipex_llm/transformers/pipeline_parallel.py
+++ b/python/llm/src/ipex_llm/transformers/pipeline_parallel.py
@@ -75,7 +75,7 @@ def init_pipeline_parallel():
 def pipeline_parallel(model, pipeline_parallel_stages):
     slice_size = (model.config.num_hidden_layers + pipeline_parallel_stages - 1) // \
         pipeline_parallel_stages
-    
+
     local_rank = dist.get_rank()
     layer_start = slice_size * local_rank
     layer_end = layer_start + min(slice_size, model.config.num_hidden_layers - layer_start)
@@ -141,7 +141,7 @@ def pipeline_parallel_generate(self,
     local_rank = dist.get_rank()
     pre_rank = (local_rank - 1) % self.pipeline_parallel_stages
     next_rank = (local_rank + 1) % self.pipeline_parallel_stages
-    
+
     _input_ids = None
     _past_key_values = None
     bs = inputs.shape[0]
@@ -168,7 +168,7 @@ def pipeline_parallel_generate(self,
         dist.send(outputs[0], dst=next_rank)
         next_ids = torch.empty((bs, 1), device=f'xpu:{local_rank}', dtype=torch.int64)
         dist.broadcast(next_ids, src=self.pipeline_parallel_stages - 1)
-    
+
         _input_ids = next_ids
         output_ids = torch.cat([output_ids, next_ids], dim=-1)
         _past_key_values = outputs.past_key_values
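
Note on the arithmetic in the first hunk, for reviewers: `slice_size` is a ceiling division, so every pipeline stage owns at most `slice_size` decoder layers and the final stage absorbs any shortfall. A minimal standalone sketch of that partitioning (plain Python, no distributed runtime; `num_hidden_layers = 32` and `stages = 3` are made-up example values, not taken from the patch):

```python
# Sketch of the layer partitioning used in pipeline_parallel().
# Standalone illustration only; the values below are hypothetical.
num_hidden_layers = 32
stages = 3

# Ceiling division: each stage owns at most slice_size layers.
slice_size = (num_hidden_layers + stages - 1) // stages

for local_rank in range(stages):
    layer_start = slice_size * local_rank
    # The min(...) clamp keeps the last stage from reading past the
    # final layer when the split is uneven.
    layer_end = layer_start + min(slice_size, num_hidden_layers - layer_start)
    print(f"rank {local_rank}: layers [{layer_start}, {layer_end})")

# Expected output:
# rank 0: layers [0, 11)
# rank 1: layers [11, 22)
# rank 2: layers [22, 32)
```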
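
The second and third hunks sit inside the ring-style token relay of `pipeline_parallel_generate`: each stage sends its outputs to `next_rank`, and the last stage broadcasts the freshly produced token ids so every rank can extend `output_ids` in lockstep. Below is a hedged sketch of that communication pattern only, using the CPU `gloo` backend instead of the patch's `xpu` devices so it runs anywhere; `world_size = 2`, the hidden size of 4, and the argmax "sampling" are illustrative stand-ins, not the library's actual logic:

```python
# Sketch of the send/broadcast relay pattern; values are hypothetical.
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def run(local_rank: int, world_size: int):
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group("gloo", rank=local_rank, world_size=world_size)

    pre_rank = (local_rank - 1) % world_size
    next_rank = (local_rank + 1) % world_size
    bs = 1

    if local_rank == 0:
        # First stage: pretend these are hidden states and pass them on.
        hidden = torch.randn(bs, 4)
        dist.send(hidden, dst=next_rank)
    else:
        hidden = torch.empty(bs, 4)
        dist.recv(hidden, src=pre_rank)

    # The last stage "samples" the next token; every rank receives it.
    next_ids = torch.empty((bs, 1), dtype=torch.int64)
    if local_rank == world_size - 1:
        next_ids = hidden.argmax(dim=-1, keepdim=True)
    dist.broadcast(next_ids, src=world_size - 1)

    print(f"rank {local_rank} got next_ids {next_ids.tolist()}")
    dist.destroy_process_group()


if __name__ == "__main__":
    world_size = 2
    mp.spawn(run, args=(world_size,), nprocs=world_size)
```

The broadcast from the final stage is what lets every rank append the same `next_ids` to its local copy of the running sequence, mirroring the `torch.cat([output_ids, next_ids], dim=-1)` step in the patched file.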