Skip to content

Commit

Permalink
refine
Browse files Browse the repository at this point in the history
  • Loading branch information
xiangyuT committed Jul 16, 2024
1 parent 85c6e23 commit 69aed71
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions python/llm/src/ipex_llm/transformers/pipeline_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,7 +462,8 @@ class BatchTask(BaseModel):

def make_attention_mask(prompt_lengths, device):
max_length = max(prompt_lengths)
attention_mask = torch.zeros((len(prompt_lengths), max_length), dtype=torch.int64, device=device)
attention_mask = torch.zeros((len(prompt_lengths), max_length),
dtype=torch.int64, device=device)
for i, length in enumerate(prompt_lengths):
attention_mask[i, max_length - length:] = 1
return attention_mask
Expand Down Expand Up @@ -661,7 +662,7 @@ def model_step(self, input, cur_batch):

if self.pp_config.is_tail:
_pre_output = self.partial_output_dict.get(cur_id, None)
tmp_output = output.logits.to(self.dtype)
tmp_output = output.logits
tmp_output = torch.argmax(tmp_output[:, -1:, :], dim=-1)
if _pre_output is None:
_pre_output = tmp_output
Expand All @@ -674,8 +675,10 @@ def model_step(self, input, cur_batch):
self.past_key_values_dict[cur_id] = _past_key_values
torch.xpu.synchronize()
if not self.pp_config.is_tail:
# TODO: check this .to()
return output[0].to(self.dtype), cur_batch
_output = output[0]
if _output.dtype != self.dtype:
_output = _output.to(self.dtype)
return _output, cur_batch
else:
if cur_batch.partial_prefilling > 0 and \
cur_batch.prefilled_index == cur_batch.batch_size:
Expand Down

0 comments on commit 69aed71

Please sign in to comment.